| 
						
						
							
								
							
						
						
					 | 
					@ -147,6 +147,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters()) | 
					 | 
					 | 
					        self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters()) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.loss_manager = LossManager(ablation="none") | 
					 | 
					 | 
					        self.loss_manager = LossManager(ablation="none") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.best_loss = float('inf') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.gpu_memory_stats("训练器初始化后") | 
					 | 
					 | 
					        logger.gpu_memory_stats("训练器初始化后") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.sampler = NormalPerPoint( | 
					 | 
					 | 
					        self.sampler = NormalPerPoint( | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -324,8 +325,8 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.info(f'Train Epoch: {epoch:4d}]\t' | 
					 | 
					 | 
					            logger.info(f'Train Epoch: {epoch:4d}]\t' | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                        f'Loss: {current_loss:.6f}') | 
					 | 
					 | 
					                        f'Loss: {current_loss:.6f}') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if loss_details: logger.info(f"Loss Details: {loss_details}") | 
					 | 
					 | 
					            if loss_details: logger.info(f"Loss Details: {loss_details}") | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) | 
					 | 
					 | 
					            #dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            dot.render("forward_graph1", format="png")  # 这会保存计算图为png格式 | 
					 | 
					 | 
					            #dot.render("forward_graph1", format="png")  # 这会保存计算图为png格式 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            return total_loss # 对于单批次训练,直接返回当前损失 | 
					 | 
					 | 
					            return total_loss # 对于单批次训练,直接返回当前损失 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -358,7 +359,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                logger.warning(f"Patch {patch_id} has no valid points.") | 
					 | 
					 | 
					                logger.warning(f"Patch {patch_id} has no valid points.") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                continue | 
					 | 
					 | 
					                continue | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6])  # 生成非流形点 | 
					 | 
					 | 
					            nonmnfld_pnts, psdf = self.sampler.get_norm_points(points[:,0:3], points[:,3:6],0.1)  # 生成非流形点 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            all_points.append(points) | 
					 | 
					 | 
					            all_points.append(points) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            valid_patch_ids.append(patch_id) | 
					 | 
					 | 
					            valid_patch_ids.append(patch_id) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            nonmnfld_pnts_list.append(nonmnfld_pnts) | 
					 | 
					 | 
					            nonmnfld_pnts_list.append(nonmnfld_pnts) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -480,8 +481,8 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] | 
					 | 
					 | 
					                    subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                    logger.info("  ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) | 
					 | 
					 | 
					                    logger.info("  ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) | 
					 | 
					 | 
					                    #dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)])) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    dot.render("forward_graph2", format="png")  # 这会保存计算图为png格式 | 
					 | 
					 | 
					                    #dot.render("forward_graph2", format="png")  # 这会保存计算图为png格式 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        avg_loss = sum(losses) / len(losses) | 
					 | 
					 | 
					        avg_loss = sum(losses) / len(losses) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -607,7 +608,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) | 
					 | 
					 | 
					            _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 生成非流形点 | 
					 | 
					 | 
					            # 生成非流形点 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) | 
					 | 
					 | 
					            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals, 0.1) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) | 
					 | 
					 | 
					            _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 更新缓存 | 
					 | 
					 | 
					            # 更新缓存 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -664,8 +665,8 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                _nonmnfld_face_indices_mask[start_idx:end_idx], | 
					 | 
					 | 
					                _nonmnfld_face_indices_mask[start_idx:end_idx], | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                _nonmnfld_operator[start_idx:end_idx] | 
					 | 
					 | 
					                _nonmnfld_operator[start_idx:end_idx] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            ) | 
					 | 
					 | 
					            ) | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) | 
					 | 
					 | 
					            #dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)])) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            dot.render("forward_graph3", format="png")  # 这会保存计算图为png格式 | 
					 | 
					 | 
					            #dot.render("forward_graph3", format="png")  # 这会保存计算图为png格式 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            #logger.print_tensor_stats("psdf",psdf) | 
					 | 
					 | 
					            #logger.print_tensor_stats("psdf",psdf) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            #logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts) | 
					 | 
					 | 
					            #logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -735,9 +736,9 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 记录训练进度 (只记录有效的损失) | 
					 | 
					 | 
					        # 记录训练进度 (只记录有效的损失) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f'Train Epoch: {epoch:4d}]\t' | 
					 | 
					 | 
					        logger.info(f'Train Epoch: {epoch:4d}]\t' | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					                    f'Loss: {current_loss:.6f}') | 
					 | 
					 | 
					                    f'Loss: {total_loss:.6f}') | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        if loss_details: logger.info(f"Loss Details: {loss_details}") | 
					 | 
					 | 
					        if loss_details: logger.info(f"Loss Details: {loss_details}") | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					
 | 
					 | 
					 | 
					        self.validate(epoch,total_loss) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return total_loss # 对于单批次训练,直接返回当前损失 | 
					 | 
					 | 
					        return total_loss # 对于单批次训练,直接返回当前损失 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -873,25 +874,15 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return total_loss # 对于单批次训练,直接返回当前损失 | 
					 | 
					 | 
					        return total_loss # 对于单批次训练,直接返回当前损失 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def validate(self, epoch: int) -> float: | 
					 | 
					 | 
					    def validate(self, epoch, loss): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.model.eval() | 
					 | 
					 | 
					        if loss < self.best_loss: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        total_loss = 0.0 | 
					 | 
					 | 
					            self.best_loss = loss | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					            self._save_checkpoint(-1, loss) # 存 best | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        with torch.no_grad(): | 
					 | 
					 | 
					            logger.info(f'Best Epoch: {epoch}\tAverage Loss: {loss:.6f}') | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            for batch in self.val_loader: | 
					 | 
					 | 
					        return  | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					                points = batch['points'].to(self.device) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                gt_sdf = batch['sdf'].to(self.device) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                pred_sdf = self.model(points) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                total_loss += loss.item() | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        avg_loss = total_loss / len(self.val_loader) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}') | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return avg_loss | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def train(self): | 
					 | 
					 | 
					    def train(self): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        best_val_loss = float('inf') | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info("Starting training...") | 
					 | 
					 | 
					        logger.info("Starting training...") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        start_time = time.time() | 
					 | 
					 | 
					        start_time = time.time() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.cached_train_data=None | 
					 | 
					 | 
					        self.cached_train_data=None | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -948,7 +939,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self._tracing_model_by_script() | 
					 | 
					 | 
					        self._tracing_model_by_script() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        #self._tracing_model() | 
					 | 
					 | 
					        #self._tracing_model() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') | 
					 | 
					 | 
					        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        logger.info(f'Best validation loss: {best_val_loss:.6f}') | 
					 | 
					 | 
					        logger.info(f'Best validation loss: {self.best_loss:.6f}') | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        #self.test_load() | 
					 | 
					 | 
					        #self.test_load() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def test_load(self): | 
					 | 
					 | 
					    def test_load(self): | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -999,7 +990,10 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self.model_name  | 
					 | 
					 | 
					            self.model_name  | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        os.makedirs(checkpoint_dir, exist_ok=True) | 
					 | 
					 | 
					        os.makedirs(checkpoint_dir, exist_ok=True) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if epoch >= 0: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") | 
					 | 
					 | 
					            checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            checkpoint_path = os.path.join(checkpoint_dir, f"best.pth") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 只保存状态字典 | 
					 | 
					 | 
					        # 只保存状态字典 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        torch.save({ | 
					 | 
					 | 
					        torch.save({ | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |