| 
						
						
							
								
							
						
						
					 | 
					@ -579,10 +579,10 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def train_epoch_stage3(self, epoch: int) -> float: | 
					 | 
					 | 
					    def train_epoch_stage3(self, epoch: int) -> float: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # --- 1. 检查输入数据 --- | 
					 | 
					 | 
					        # --- 1. 检查输入数据 --- | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) | 
					 | 
					 | 
					        # 注意:假设 self.train_surf_ncs 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        # 并且 SDF 值总是在最后一列 | 
					 | 
					 | 
					        # 并且 SDF 值总是在最后一列 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        if self.sdf_data is None: | 
					 | 
					 | 
					        if self.train_surf_ncs is None: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					             logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") | 
					 | 
					 | 
					             logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					             return float('inf') | 
					 | 
					 | 
					             return float('inf') | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.model.train() | 
					 | 
					 | 
					        self.model.train() | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -592,9 +592,9 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 数据处理 | 
					 | 
					 | 
					        # 数据处理 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # manfld | 
					 | 
					 | 
					        # manfld | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        _mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点 | 
					 | 
					 | 
					        _mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # 取前3列作为点 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        _normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线 | 
					 | 
					 | 
					        _normals = self.train_surf_ncs[:, 3:6].clone().detach() # 取中间3列作为法线 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        _gt_sdf = self.sdf_data[:, -1].clone().detach()  # 取最后一列作为SDF真值 | 
					 | 
					 | 
					        _gt_sdf = self.train_surf_ncs[:, -1].clone().detach()  # 取最后一列作为SDF真值 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 检查是否需要重新计算缓存 | 
					 | 
					 | 
					        # 检查是否需要重新计算缓存 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if epoch % 10 == 1 or self.cached_train_data is None: | 
					 | 
					 | 
					        if epoch % 10 == 1 or self.cached_train_data is None: | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -628,7 +628,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 将数据分成多个batch | 
					 | 
					 | 
					        # 将数据分成多个batch | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        num_points = self.sdf_data.shape[0] | 
					 | 
					 | 
					        num_points = self.train_surf_ncs.shape[0] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        num_batches = (num_points + batch_size - 1) // batch_size | 
					 | 
					 | 
					        num_batches = (num_points + batch_size - 1) // batch_size | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        for batch_idx in range(num_batches): | 
					 | 
					 | 
					        for batch_idx in range(num_batches): | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -894,7 +894,7 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.info(f"Loaded model from {args.resume_checkpoint_path}") | 
					 | 
					 | 
					            logger.info(f"Loaded model from {args.resume_checkpoint_path}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # stage1 | 
					 | 
					 | 
					        # stage1 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.model.freeze_stage2() | 
					 | 
					 | 
					        self.model.freeze_stage1() | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        for epoch in range(start_epoch, self.config.train.num_epochs1 + 1): | 
					 | 
					 | 
					        for epoch in range(start_epoch, self.config.train.num_epochs1 + 1): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 训练一个epoch | 
					 | 
					 | 
					            # 训练一个epoch | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            train_loss = self.train_epoch_stage1(epoch) | 
					 | 
					 | 
					            train_loss = self.train_epoch_stage1(epoch) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -921,8 +921,8 @@ class Trainer: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        #stage 3 | 
					 | 
					 | 
					        #stage 3 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.scheduler.reset() | 
					 | 
					 | 
					        self.scheduler.reset() | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        #self.model.freeze_stage2() | 
					 | 
					 | 
					        self.model.freeze_stage2() | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.model.unfreeze() | 
					 | 
					 | 
					        #self.model.unfreeze() | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1): | 
					 | 
					 | 
					        for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            # 训练一个epoch | 
					 | 
					 | 
					            # 训练一个epoch | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            train_loss = self.train_epoch_stage3(epoch) | 
					 | 
					 | 
					            train_loss = self.train_epoch_stage3(epoch) | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |