@@ -102,8 +102,10 @@ class Trainer:
             self.train_surf_ncs = sampled_sdf_data
         else:
             self.sdf_data = surface_sdf_data
-        print_data_distribution(self.sdf_data)
+        logger.print_tensor_stats("sdf_data", self.sdf_data)
         logger.debug(self.sdf_data.shape)
+        logger.print_tensor_stats("train_surf_ncs", self.train_surf_ncs[:, 0:3])
+        logger.debug(self.train_surf_ncs.shape)
         logger.gpu_memory_stats("after preparing the SDF data")
         # initialize the dataset
         #self.brep_data = load_brep_file(self.config.data.pkl_path)
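Note: logger.print_tensor_stats is a project-specific helper whose implementation is not part of this diff. The sketch below only illustrates, under that assumption, the kind of summary such a call typically logs for the SDF samples added above.

    import torch

    def print_tensor_stats(name: str, t: torch.Tensor) -> None:
        # Hypothetical stand-in for logger.print_tensor_stats: summarize the
        # tensor so bad SDF samples (NaN/inf, wrong range) are easy to spot.
        finite = t[torch.isfinite(t)]
        print(f"{name}: shape={tuple(t.shape)} "
              f"min={finite.min().item():.4f} max={finite.max().item():.4f} "
              f"mean={finite.mean().item():.4f} "
              f"non-finite={t.numel() - finite.numel()}")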
@@ -229,22 +231,9 @@ class Trainer:
         _gt_sdf = self.train_surf_ncs[:, -1].clone().detach()  # take the last column as the SDF ground truth
 
         # check whether the cache needs to be recomputed
-        if epoch % 10 == 1 or self.cached_train_data is None:
-            # compute masks and operators for the manifold points
-            # generate the non-manifold points
-            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
-
-            # update the cache
-            self.cached_train_data = {
-                "nonmnfld_pnts": _nonmnfld_pnts,
-                "psdf": _psdf,
-            }
-        else:
-            # read the data from the cache
-            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
-            _psdf = self.cached_train_data["psdf"]
+
 
-        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
+        logger.debug(_mnfld_pnts)
 
 
 
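Note: the block removed above periodically refreshed off-surface samples via self.sampler.get_norm_points and kept them in self.cached_train_data. For readers of this hunk, here is a minimal sketch of what such a normal-perturbation sampler usually does; the real sampler's signature, scale, and labeling are assumptions, not the repository's implementation.

    import torch

    def get_norm_points(mnfld_pnts: torch.Tensor, normals: torch.Tensor, sigma: float = 0.05):
        # Illustrative stand-in for sampler.get_norm_points (details assumed):
        # jitter every surface point along its unit normal and reuse the signed
        # offset as a pseudo-SDF label for the off-surface sample.
        offsets = sigma * torch.randn(mnfld_pnts.shape[0], 1, device=mnfld_pnts.device)
        nonmnfld_pnts = mnfld_pnts + offsets * normals
        psdf = offsets.squeeze(-1)  # pseudo signed distance
        return nonmnfld_pnts, psdf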
@@ -260,22 +249,14 @@ class Trainer:
             mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold points
             gt_sdf = _gt_sdf[start_idx:end_idx]  # SDF ground truth
             normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals
-
-            # the non-manifold points use the cached data (shared by the whole batch)
-            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
-            psdf = _psdf[start_idx:end_idx]
 
             # --- prepare the model inputs and enable gradients ---
             mnfld_pnts.requires_grad_(True) # enable gradients after the checks
-            nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks
 
             # --- forward pass ---
             mnfld_pred = self.model.forward_background(
                 mnfld_pnts
             )
-            nonmnfld_pred = self.model.forward_background(
-                nonmnfld_pnts
-            )
 
 
 
@@ -292,14 +273,11 @@ class Trainer:
                     #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
                     #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
                     logger.gpu_memory_stats("before computing the loss")
-                    loss, loss_details = self.loss_manager.compute_loss(
+                    loss, loss_details = self.loss_manager.compute_loss_stage1(
                         mnfld_pnts,
-                        nonmnfld_pnts,
                         normals, # pass the checked normals
                         gt_sdf,
                         mnfld_pred,
-                        nonmnfld_pred,
-                        psdf
                     )
                 else:
                     loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
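Note: with the non-manifold arguments dropped, compute_loss_stage1 now receives only the surface batch (mnfld_pnts, normals, gt_sdf, mnfld_pred). The sketch below is one plausible shape for such a stage-1 loss, not the repository's actual implementation; the L1 term, the normal-alignment term, and the weighting are assumptions.

    import torch
    import torch.nn.functional as F

    def compute_loss_stage1(mnfld_pnts, normals, gt_sdf, mnfld_pred, normal_weight=0.1):
        # Sketch of a surface-only loss matching the call site above (form assumed):
        # fit the predicted SDF to the ground truth and, when normals are given,
        # align the gradient of the prediction with them.
        sdf_loss = F.l1_loss(mnfld_pred.squeeze(-1), gt_sdf)
        loss, details = sdf_loss, {"sdf": sdf_loss.item()}
        if normals is not None:
            grad = torch.autograd.grad(mnfld_pred.sum(), mnfld_pnts, create_graph=True)[0]
            normal_loss = (1.0 - F.cosine_similarity(grad, normals, dim=-1)).mean()
            details["normal"] = normal_loss.item()
            loss = loss + normal_weight * normal_loss
        return loss, details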
@@ -307,7 +285,6 @@ class Trainer:
                 # --- 4. check the result of the loss computation ---
                 if self.debug_mode:
                     logger.print_tensor_stats("psdf",psdf)
-                    logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
                     if check_tensor(loss, "Calculated Loss", epoch, step):
                         logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
                         if loss_details: logger.error(f"Loss Details: {loss_details}")
@@ -917,7 +894,7 @@ class Trainer:
             logger.info(f"Loaded model from {args.resume_checkpoint_path}")
 
         # stage1
-        self.model.encoder.freeze_stage1()
+        self.model.freeze_stage2()
         for epoch in range(start_epoch, self.config.train.num_epochs1 + 1):
             # train one epoch
             train_loss = self.train_epoch_stage1(epoch)
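Note: freeze_stage2 is now called on the model itself rather than on its encoder; its body is not part of this diff. Below is a minimal sketch of the freeze/unfreeze pattern the staged schedule appears to rely on, with the submodule names used purely as placeholders.

    import torch.nn as nn

    class StagedModel(nn.Module):
        # Placeholder modules; only the freeze/unfreeze pattern is the point here.
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(3, 64)   # trained in stage 1 (placeholder)
            self.decoder = nn.Linear(64, 1)   # trained in stage 2 (placeholder)

        def freeze_stage2(self):
            # Keep the stage-2 parameters fixed while stage 1 trains.
            for p in self.decoder.parameters():
                p.requires_grad_(False)

        def unfreeze(self):
            # Re-enable every parameter, e.g. before stage 3.
            for p in self.parameters():
                p.requires_grad_(True)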
@@ -929,6 +906,7 @@ class Trainer:
                 self._save_checkpoint(epoch, train_loss)
                 logger.info(f'Checkpoint saved at epoch {epoch}')
 
+        start_epoch = max(start_epoch, self.config.train.num_epochs1)
         # stage2 freeze_stage2
         max_stage2_epoch = self.config.train.num_epochs1+self.config.train.num_epochs2
         if start_epoch < max_stage2_epoch:
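Note: the added max() guards the resume path: a checkpoint saved partway through stage 1 would otherwise let stage 2 start below the stage-1 epoch budget. A small worked example (the epoch count is illustrative):

    num_epochs1 = 100  # illustrative value for self.config.train.num_epochs1
    for resumed in (37, 100, 140):
        start_epoch = max(resumed, num_epochs1)
        print(resumed, "->", start_epoch)  # 37 -> 100, 100 -> 100, 140 -> 140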
@@ -942,9 +920,10 @@ class Trainer:
             cur_epoch = start_epoch
 
         #stage 3
-        self.model.encoder.unfreeze()
         self.scheduler.reset()
-        for epoch in range(cur_epoch, max_stage2_epoch + self.config.train.num_epochs3 + 1):
+        #self.model.freeze_stage2()
+        self.model.unfreeze()
+        for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1):
            # train one epoch
             train_loss = self.train_epoch_stage3(epoch)
             #train_loss = self.train_epoch_stage2(epoch)
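Note: for orientation, the loop bounds visible in this diff partition the epoch counter roughly as sketched below when training runs from scratch and stage 2 ends with cur_epoch == max_stage2_epoch. The stage-2 loop itself is outside this diff, and the counts are illustrative, not the project's configuration.

    # Illustrative epoch counts; the real values come from self.config.train.
    num_epochs1, num_epochs2, num_epochs3 = 100, 50, 50
    max_stage2_epoch = num_epochs1 + num_epochs2

    stage1 = range(1, num_epochs1 + 1)                      # train_epoch_stage1
    stage2 = range(num_epochs1 + 1, max_stage2_epoch + 1)   # stage-2 loop (not shown here)
    stage3 = range(max_stage2_epoch + 1,                    # train_epoch_stage3, after
                   max_stage2_epoch + num_epochs3 + 1)      # self.model.unfreeze()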