| 
						
						
							
								
							
						
						
					 | 
					@ -12,6 +12,7 @@ from brep2sdf.networks.network import Net | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.networks.octree import OctreeNode | 
					 | 
					 | 
					from brep2sdf.networks.octree import OctreeNode | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.networks.loss import LossManager | 
					 | 
					 | 
					from brep2sdf.networks.loss import LossManager | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.networks.patch_graph import PatchGraph | 
					 | 
					 | 
					from brep2sdf.networks.patch_graph import PatchGraph | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					from brep2sdf.networks.sample import NormalPerPoint | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -142,6 +143,11 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.loss_manager = LossManager(ablation="none") | 
					 | 
					 | 
					        self.loss_manager = LossManager(ablation="none") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.gpu_memory_stats("训练器初始化后") | 
					 | 
					 | 
					        logger.gpu_memory_stats("训练器初始化后") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.sampler = NormalPerPoint( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            global_sigma=0.1,  # 全局采样标准差 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            local_sigma=0.01  # 局部采样标准差 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f"初始化完成,正在处理模型 {self.model_name}") | 
					 | 
					 | 
					        logger.info(f"初始化完成,正在处理模型 {self.model_name}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -200,7 +206,9 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        total_loss = 0.0 | 
					 | 
					 | 
					        total_loss = 0.0 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        total_loss_details = { | 
					 | 
					 | 
					        total_loss_details = { | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            "manifold": 0.0, | 
					 | 
					 | 
					            "manifold": 0.0, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            "normals": 0.0 | 
					 | 
					 | 
					            "normals": 0.0, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            "eikonal": 0.0, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            "offsurface": 0.0 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        } | 
					 | 
					 | 
					        } | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        accumulated_loss = 0.0  # 新增:用于累积多个step的loss | 
					 | 
					 | 
					        accumulated_loss = 0.0  # 新增:用于累积多个step的loss | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -208,19 +216,21 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.optimizer.zero_grad() | 
					 | 
					 | 
					        self.optimizer.zero_grad() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        for step, surf_points in enumerate(self.data['surf_ncs']): | 
					 | 
					 | 
					        for step, surf_points in enumerate(self.data['surf_ncs']): | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            points = torch.tensor(surf_points, device=self.device) | 
					 | 
					 | 
					            mnfld_points = torch.tensor(surf_points, device=self.device) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            gt_sdf = torch.zeros(points.shape[0], device=self.device) | 
					 | 
					 | 
					            nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)  # 生成非流形点 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            normals = None | 
					 | 
					 | 
					            normals = None | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if args.use_normal: | 
					 | 
					 | 
					            if args.use_normal: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) | 
					 | 
					 | 
					                normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # --- 准备模型输入,启用梯度 --- | 
					 | 
					 | 
					            # --- 准备模型输入,启用梯度 --- | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            points.requires_grad_(True)  # 在检查之后启用梯度 | 
					 | 
					 | 
					            mnfld_points.requires_grad_(True)  # 在检查之后启用梯度 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            nonmnfld_pnts.requires_grad_(True)  # 在检查之后启用梯度 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # --- 前向传播 --- | 
					 | 
					 | 
					            # --- 前向传播 --- | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.optimizer.zero_grad() | 
					 | 
					 | 
					            self.optimizer.zero_grad() | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            pred_sdf = self.model.forward_training_volumes(points, step) | 
					 | 
					 | 
					            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            logger.debug(f"pred_sdf:{pred_sdf}") | 
					 | 
					 | 
					            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if self.debug_mode: | 
					 | 
					 | 
					            if self.debug_mode: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                # --- 检查前向传播的输出 --- | 
					 | 
					 | 
					                # --- 检查前向传播的输出 --- | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -230,10 +240,12 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            try: | 
					 | 
					 | 
					            try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                if args.use_normal: | 
					 | 
					 | 
					                if args.use_normal: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    loss, loss_details = self.loss_manager.compute_loss( | 
					 | 
					 | 
					                    loss, loss_details = self.loss_manager.compute_loss( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        points, | 
					 | 
					 | 
					                        mnfld_points, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        nonmnfld_pnts, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        normals, | 
					 | 
					 | 
					                        normals, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        gt_sdf, | 
					 | 
					 | 
					                        gt_sdf, | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                        pred_sdf | 
					 | 
					 | 
					                        mnfld_pred, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                        nonmnfld_pred | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    ) | 
					 | 
					 | 
					                    ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                else: | 
					 | 
					 | 
					                else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
					 | 
					 | 
					                    loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -278,6 +290,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return total_loss  # 返回平均损失而非累计值 | 
					 | 
					 | 
					        return total_loss  # 返回平均损失而非累计值 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def train_epoch(self, epoch: int) -> float: | 
					 | 
					 | 
					    def train_epoch(self, epoch: int) -> float: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # --- 1. 检查输入数据 --- | 
					 | 
					 | 
					        # --- 1. 检查输入数据 --- | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) | 
					 | 
					 | 
					        # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |