@@ -197,8 +197,17 @@ class Trainer:
         # return torch.cat([global_min, global_max])
         # return [-0.5,]  # this is wrong

     def train_epoch_stage1(self, epoch: int):
-        total_loss = 0.0  # initialize the total loss
-        for step, surf_points in enumerate(self.data['surf_ncs']):  # defines the step variable
+        total_loss = 0.0
+        total_loss_details = {
+            "manifold": 0.0,
+            "normals": 0.0
+        }
+        accumulated_loss = 0.0  # new: accumulates the loss across several steps
+
+        # new: zero the gradients once at the start of the epoch
+        self.optimizer.zero_grad()
+
+        for step, surf_points in enumerate(self.data['surf_ncs']):
             points = torch.tensor(surf_points, device=self.device)
             gt_sdf = torch.zeros(points.shape[0], device=self.device)
             normals = None
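Note: the rewrite above works because `backward()` accumulates into `.grad` rather than overwriting it, so gradients can be collected across several steps before a single `optimizer.step()`. Dividing each step's loss by the window size keeps the effective update equal to one large-batch step. A minimal, self-contained sketch of that identity (the model, optimizer, and window size here are illustrative, not taken from this repository):

    import torch

    model = torch.nn.Linear(3, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    K = 4  # accumulation window; the diff reads this from self.config.train.accumulation_steps

    opt.zero_grad()
    for x in [torch.randn(8, 3) for _ in range(K)]:
        loss = model(x).pow(2).mean()
        (loss / K).backward()  # grads sum across backward() calls, so scaling by 1/K
                               # matches a single backward() over all K batches
    opt.step()
    opt.zero_grad()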
				
@@ -211,14 +220,13 @@ class Trainer:
             # --- forward pass ---
-            self.optimizer.zero_grad()
             pred_sdf = self.model.forward_training_volumes(points, step)
             logger.debug(f"pred_sdf:{pred_sdf}")

             if self.debug_mode:
                 # --- inspect the forward-pass output ---
                 logger.gpu_memory_stats("after forward pass")
-
             # --- compute the loss ---
+            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
             loss_details = {}
             try:
                 if args.use_normal:
                     loss, loss_details = self.loss_manager.compute_loss(
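Note: pre-initializing `loss` to NaN is a sentinel: if the `try` block fails before assigning a real value, the later finiteness checks trip instead of training silently on garbage. `check_tensor` is project code not shown in this diff; judging from the call site it returns True when the tensor contains inf/nan. A plausible sketch under that assumption:

    import torch

    def check_tensor(t: torch.Tensor, name: str, epoch: int, step: int) -> bool:
        # Returns True when t contains inf/nan (the repo's actual helper may differ).
        bad = not torch.isfinite(t).all().item()
        if bad:
            print(f"[epoch {epoch} step {step}] {name} contains inf/nan")
        return bad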
				
@@ -229,45 +237,45 @@ class Trainer:
                     )
                 else:
                     loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                     loss_details = {}  # make sure the variable is initialized

                 if self.debug_mode:
                     if check_tensor(loss, "Calculated Loss", epoch, step):
                         logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                         if loss_details:
                             logger.error(f"Loss Details: {loss_details}")
                         return float('inf')  # the loss is invalid, so stop this epoch
             except Exception as loss_e:
                 logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
                 return float('inf')  # the computation failed, so stop this epoch

             # --- backward pass and optimization ---
             try:
-                loss.backward()
-                # --- (recommended) gradient clipping ---
-                # guards against exploding gradients, one possible cause of inf/nan
-                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # norm clipping
-            except Exception as backward_e:
-                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
-                return float('inf')  # backward or the optimizer step failed, so stop this epoch
-
-            # --- record and accumulate the loss ---
-            current_loss = loss.item()
-            if not np.isfinite(current_loss):  # double-check that the loss is a finite number
-                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
-                return float('inf')
-                # logging stays the same ...
-
-            total_loss += current_loss
+                # changed: accumulate the loss instead of calling backward() right away
+                accumulated_loss += loss / self.config.train.accumulation_steps  # assumes accumulation_steps is set in the config
+                current_loss = loss.item()
+                total_loss += current_loss
+                for key in total_loss_details:
+                    if key in loss_details:
+                        total_loss_details[key] += loss_details[key].item()
+
+                # new: run the backward pass once the accumulation window is full
+                if (step + 1) % self.config.train.accumulation_steps == 0:
+                    accumulated_loss.backward()
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    accumulated_loss = 0.0  # reset the accumulated loss
+            except Exception as loss_e:
+                logger.error(f"Error in step {step}: {loss_e}")
+                continue

             # --- memory management ---
             del loss
             torch.cuda.empty_cache()

-                # log training progress (only valid losses)
-        logger.info(f'Train Epoch: {epoch:4d}]\t'
-                    f'Loss: {current_loss:.6f}')
-        if loss_details: logger.info(f"Loss Details: {loss_details}")
-        return total_loss
+        # new: flush any leftover loss that never filled a full accumulation window
+        if accumulated_loss != 0:
+            accumulated_loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            self.optimizer.step()
+
+        # compute and log the epoch loss
+        logger.info(f'Train Epoch: {epoch:4d}]\t'
+                    f'Loss: {total_loss:.6f}')
+        logger.info(f"Loss Details: {total_loss_details}")
+        return total_loss / max(1, len(self.data['surf_ncs']))  # return the average loss rather than the cumulative sum

     def train_epoch(self, epoch: int) -> float:
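Note: taken together, the changes in this hunk converge on a standard accumulate-and-flush loop: scale each step's loss by the window size, step the optimizer every `accumulation_steps` batches, and flush whatever partial window remains at the end of the epoch. An illustrative skeleton of that control flow (not the repository's Trainer; the names and the stand-in loss are assumptions):

    import torch

    def run_epoch(model, optimizer, batches, accumulation_steps=4, max_norm=1.0):
        total_loss = 0.0
        accumulated = 0.0
        optimizer.zero_grad()
        for step, x in enumerate(batches):
            loss = model(x).pow(2).mean()  # stand-in for the SDF loss
            accumulated = accumulated + loss / accumulation_steps
            total_loss += loss.item()
            if (step + 1) % accumulation_steps == 0:
                accumulated.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
                optimizer.step()
                optimizer.zero_grad()
                accumulated = 0.0
        if torch.is_tensor(accumulated):  # flush a partial trailing window
            accumulated.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            optimizer.step()
            optimizer.zero_grad()
        return total_loss / max(1, len(batches))  # average per-step loss

One caveat: flushing a partial window applies an update from fewer batches at the same 1/`accumulation_steps` scale, so the final step is slightly under-weighted; some codebases drop the remainder or rescale by the actual count instead.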