| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -13,6 +13,7 @@ from brep2sdf.networks.octree import OctreeNode | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.networks.loss import LossManager | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.networks.patch_graph import PatchGraph | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.networks.sample import NormalPerPoint | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.networks.learning_rate import LearningRateScheduler | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -140,6 +141,8 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            weight_decay=config.train.weight_decay | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #self.scheduler = LearningRateScheduler(self.conf.get_list('train.learning_rate_schedule'), self.conf.get_float('train.weight_decay'), self.model.parameters()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.loss_manager = LossManager(ablation="none") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.gpu_memory_stats("训练器初始化后") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -222,6 +225,7 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            normals = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if args.use_normal: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.debug(normals) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 准备模型输入,启用梯度 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_points.requires_grad_(True)  # 在检查之后启用梯度 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -232,6 +236,9 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.print_tensor_stats("mnfld_pred",mnfld_pred) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if self.debug_mode: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # --- 检查前向传播的输出 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.gpu_memory_stats("前向传播后") | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -302,7 +309,7 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model.train() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        total_loss = 0.0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        step = 0 # 如果你的训练是分批次的,这里应该用批次索引 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_size = 10240  # 设置合适的batch大小 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        batch_size = 8192  # 设置合适的batch大小 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 将数据分成多个batch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_points = self.sdf_data.shape[0] | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -311,7 +318,8 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for batch_idx in range(num_batches): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            start_idx = batch_idx * batch_size | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            end_idx = min((batch_idx + 1) * batch_size, num_points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0)  # 生成非流形点 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach()  # 取最后一列作为SDF真值 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            normals = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if args.use_normal: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -322,19 +330,28 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 执行检查 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if  self.debug_mode: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if check_tensor(points, "Input Points", epoch, step): return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if args.use_normal: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    # 只有在请求法线时才检查 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    if check_tensor(normals, "Input Normals", epoch, step): return float('inf') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.debug(normals) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.print_tensor_stats("normals-x",normals[0]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.print_tensor_stats("normals-y",normals[1]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.print_tensor_stats("normals-z",normals[2]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 准备模型输入,启用梯度 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # --- 前向传播 --- | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.optimizer.zero_grad() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf = self.model(points) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mnfld_pred = self.model(mnfld_pnts) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nonmnfld_pred = self.model(nonmnfld_pnts) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.print_tensor_stats("mnfld_pred",mnfld_pred) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if self.debug_mode: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                # --- 检查前向传播的输出 --- | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -356,10 +373,12 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    logger.gpu_memory_stats("计算损失前") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pnts, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        normals, # 传递检查过的 normals | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        gt_sdf, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        pred_sdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        mnfld_pred, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        nonmnfld_pred | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -448,8 +467,8 @@ class Trainer: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for epoch in range(start_epoch, self.config.train.num_epochs + 1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 训练一个epoch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            train_loss = self.train_epoch_stage1(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #train_loss = self.train_epoch(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #train_loss = self.train_epoch_stage1(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            train_loss = self.train_epoch(epoch) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 验证 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ''' | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |