@@ -443,37 +443,63 @@ class Trainer:
                 # Accumulate the loss
                 losses.append(loss)
-                if epoch % 1 == 0:
+                if epoch % 100 == 0:
                     loss_detailss.append(loss_details)
 
-            # Average the accumulated losses, then backpropagate
-            loss_tensor = torch.stack(losses)
-            mean_loss = (loss_tensor * weights).sum()
-            mean_loss.backward()
-
-            # Update the parameters
-            self.scheduler.optimizer.step()
-            self.scheduler.step(mean_loss, epoch)
-
-            # Zero the gradients
-            self.scheduler.optimizer.zero_grad()
-
-            # Free cached memory
-            torch.cuda.empty_cache()
-
-            # Add logging here if you want detailed loss information
-            if epoch % 1 == 0:
-                logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
-                            f'Loss: {loss:.6f}')
-                loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
-
-                # Weighted average of each sub-loss term (if weights are needed)
-                weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()
-
-                subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
-                logger.info("  ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
+                if len(losses) % 30 == 0:
+                    # Average the accumulated losses, then backpropagate
+                    loss_tensor = torch.stack(losses)
+                    mean_loss = (loss_tensor * weights).sum()
+                    mean_loss.backward()
+
+                    # Update the parameters
+                    self.scheduler.optimizer.step()
+                    self.scheduler.step(mean_loss, epoch)
+
+                    # Zero the gradients
+                    self.scheduler.optimizer.zero_grad()
+                    torch.cuda.empty_cache()
+
+                    # Add logging here if you want detailed loss information
+                    if epoch % 100 == 0:
+                        logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
+                                    f'Loss: {loss:.6f}')
+                        loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
+
+                        # Weighted average of each sub-loss term (if weights are needed)
+                        weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()
+
+                        subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
+                        logger.info("  ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
+                    losses = []
+                    loss_detailss = []
+
+            if len(losses) > 0:
+                # Average the accumulated losses, then backpropagate
+                loss_tensor = torch.stack(losses)
+                mean_loss = (loss_tensor * weights).sum()
+                mean_loss.backward()
+
+                # Update the parameters
+                self.scheduler.optimizer.step()
+                self.scheduler.step(mean_loss, epoch)
+
+                # Zero the gradients
+                self.scheduler.optimizer.zero_grad()
+                torch.cuda.empty_cache()
+
+                # Add logging here if you want detailed loss information
+                if epoch % 100 == 0:
+                    logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
+                                f'Loss: {loss:.6f}')
+                    loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
+
+                    # Weighted average of each sub-loss term (if weights are needed)
+                    weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()
+
+                    subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
+                    logger.info("  ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
 
         avg_loss = sum(losses) / len(losses)
         logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}")
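The logging branch added on both paths reduces the collected `loss_detailss` to one weighted average per sub-loss before printing. A small self-contained sketch of that reduction follows; the random `loss_detailss` entries and the uniform `weights` are stand-ins (assumptions), and only the tensor shapes and the reduction itself mirror the diff.

```python
# Sketch of the weighted per-column reduction used in the logging branch.
# Stand-ins: one [5]-vector of sub-losses per patch, plus uniform patch weights.
import torch

subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]

num_patches = 30
loss_detailss = [torch.rand(5) for _ in range(num_patches)]
weights = torch.full((num_patches,), 1.0 / num_patches)

loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]

# Weight each patch row, sum over patches, normalise by the total weight.
weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()

print("  ".join(f"{name}: {weighted_avg[i].item():.6f}"
                for i, name in enumerate(subloss_names)))
```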