|  |  | @ -45,6 +45,8 @@ class Trainer: | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, config): | 
			
		
	
		
			
				
					|  |  |  |         self.config = config | 
			
		
	
		
			
				
					|  |  |  |         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 
			
		
	
		
			
				
					|  |  |  |         use_pin_memory = self.device.type == 'cuda'  # 根据设备类型决定是否使用pin_memory | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         logger.info(f"Using device: {self.device}") | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         # 初始化数据集 | 
			
		
	
	
		
			
				
					|  |  | @ -69,13 +71,15 @@ class Trainer: | 
			
		
	
		
			
				
					|  |  |  |             self.train_dataset, | 
			
		
	
		
			
				
					|  |  |  |             batch_size=config.train.batch_size, | 
			
		
	
		
			
				
					|  |  |  |             shuffle=True, | 
			
		
	
		
			
				
					|  |  |  |             num_workers=config.train.num_workers | 
			
		
	
		
			
				
					|  |  |  |             num_workers=config.train.num_workers, | 
			
		
	
		
			
				
					|  |  |  |             pin_memory=use_pin_memory  # 根据设备类型设置,是否将数据固定在内存中 | 
			
		
	
		
			
				
					|  |  |  |         ) | 
			
		
	
		
			
				
					|  |  |  |         self.val_loader = DataLoader( | 
			
		
	
		
			
				
					|  |  |  |             self.val_dataset, | 
			
		
	
		
			
				
					|  |  |  |             batch_size=config.train.batch_size, | 
			
		
	
		
			
				
					|  |  |  |             shuffle=False, | 
			
		
	
		
			
				
					|  |  |  |             num_workers=config.train.num_workers | 
			
		
	
		
			
				
					|  |  |  |             num_workers=config.train.num_workers, | 
			
		
	
		
			
				
					|  |  |  |             pin_memory=use_pin_memory  # 根据设备类型设置,是否将数据固定在内存中 | 
			
		
	
		
			
				
					|  |  |  |         ) | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         # 初始化模型 | 
			
		
	
	
		
			
				
					|  |  | @ -105,28 +109,34 @@ class Trainer: | 
			
		
	
		
			
				
					|  |  |  |         total_loss = 0 | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         for batch_idx, batch in enumerate(self.train_loader): | 
			
		
	
		
			
				
					|  |  |  |             # 获取数据 | 
			
		
	
		
			
				
					|  |  |  |             surf_z = batch['surf_z'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             edge_z = batch['edge_z'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             surf_p = batch['surf_p'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             edge_p = batch['edge_p'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             vert_p = batch['vert_p'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             query_points = batch['points'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             gt_sdf = batch['sdf'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             # 清空梯度 | 
			
		
	
		
			
				
					|  |  |  |             self.optimizer.zero_grad() | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 获取数据并移动到设备 | 
			
		
	
		
			
				
					|  |  |  |             surf_ncs = batch['surf_ncs'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             edge_ncs = batch['edge_ncs'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             surf_pos = batch['surf_pos'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             edge_pos = batch['edge_pos'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             vertex_pos = batch['vertex_pos'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |             sdf = batch['sdf'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 前向传播 | 
			
		
	
		
			
				
					|  |  |  |             pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) | 
			
		
	
		
			
				
					|  |  |  |             pred_sdf = self.model( | 
			
		
	
		
			
				
					|  |  |  |                 surf_ncs, edge_ncs, | 
			
		
	
		
			
				
					|  |  |  |                 surf_pos, edge_pos, | 
			
		
	
		
			
				
					|  |  |  |                 vertex_pos, | 
			
		
	
		
			
				
					|  |  |  |                 sdf[:, :3]  # 只使用点坐标,不包括SDF值 | 
			
		
	
		
			
				
					|  |  |  |             ) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 计算损失 | 
			
		
	
		
			
				
					|  |  |  |             loss = sdf_loss( | 
			
		
	
		
			
				
					|  |  |  |                 pred_sdf, | 
			
		
	
		
			
				
					|  |  |  |                 gt_sdf, | 
			
		
	
		
			
				
					|  |  |  |                 query_points, | 
			
		
	
		
			
				
					|  |  |  |                 sdf[:, 3],  # 使用SDF值 | 
			
		
	
		
			
				
					|  |  |  |                 sdf[:, :3],  # 使用点坐标 | 
			
		
	
		
			
				
					|  |  |  |                 grad_weight=self.config.train.grad_weight | 
			
		
	
		
			
				
					|  |  |  |             ) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 反向传播 | 
			
		
	
		
			
				
					|  |  |  |             self.optimizer.zero_grad() | 
			
		
	
		
			
				
					|  |  |  |             # 反向传播和优化 | 
			
		
	
		
			
				
					|  |  |  |             loss.backward() | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 梯度裁剪 | 
			
		
	
	
		
			
				
					|  |  | @ -134,22 +144,21 @@ class Trainer: | 
			
		
	
		
			
				
					|  |  |  |                 self.model.parameters(),  | 
			
		
	
		
			
				
					|  |  |  |                 self.config.train.max_grad_norm | 
			
		
	
		
			
				
					|  |  |  |             ) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             self.optimizer.step() | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             total_loss += loss.item() | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 打印训练进度 | 
			
		
	
		
			
				
					|  |  |  |             # 记录训练进度 | 
			
		
	
		
			
				
					|  |  |  |             if (batch_idx + 1) % self.config.log.log_interval == 0: | 
			
		
	
		
			
				
					|  |  |  |                 logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t' | 
			
		
	
		
			
				
					|  |  |  |                           f'Loss: {loss.item():.6f}') | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 记录到wandb | 
			
		
	
		
			
				
					|  |  |  |             if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0: | 
			
		
	
		
			
				
					|  |  |  |                 wandb.log({ | 
			
		
	
		
			
				
					|  |  |  |                     'batch_loss': loss.item(), | 
			
		
	
		
			
				
					|  |  |  |                     'batch': batch_idx, | 
			
		
	
		
			
				
					|  |  |  |                     'epoch': epoch | 
			
		
	
		
			
				
					|  |  |  |                 }) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
		
			
				
					|  |  |  |                 if self.config.log.use_wandb: | 
			
		
	
		
			
				
					|  |  |  |                     wandb.log({ | 
			
		
	
		
			
				
					|  |  |  |                         'batch_loss': loss.item(), | 
			
		
	
		
			
				
					|  |  |  |                         'batch': batch_idx, | 
			
		
	
		
			
				
					|  |  |  |                         'epoch': epoch | 
			
		
	
		
			
				
					|  |  |  |                     }) | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         avg_loss = total_loss / len(self.train_loader) | 
			
		
	
		
			
				
					|  |  |  |         return avg_loss | 
			
		
	
	
		
			
				
					|  |  | @ -160,23 +169,27 @@ class Trainer: | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         with torch.no_grad(): | 
			
		
	
		
			
				
					|  |  |  |             for batch in self.val_loader: | 
			
		
	
		
			
				
					|  |  |  |                 # 获取数据 | 
			
		
	
		
			
				
					|  |  |  |                 surf_z = batch['surf_z'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 edge_z = batch['edge_z'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 surf_p = batch['surf_p'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 edge_p = batch['edge_p'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 vert_p = batch['vert_p'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 query_points = batch['points'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 gt_sdf = batch['sdf'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 # 获取数据并移动到设备 | 
			
		
	
		
			
				
					|  |  |  |                 surf_ncs = batch['surf_ncs'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 edge_ncs = batch['edge_ncs'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 surf_pos = batch['surf_pos'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 edge_pos = batch['edge_pos'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 vertex_pos = batch['vertex_pos'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                 sdf = batch['sdf'].to(self.device) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
		
			
				
					|  |  |  |                 # 前向传播 | 
			
		
	
		
			
				
					|  |  |  |                 pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) | 
			
		
	
		
			
				
					|  |  |  |                 pred_sdf = self.model( | 
			
		
	
		
			
				
					|  |  |  |                     surf_ncs, edge_ncs, | 
			
		
	
		
			
				
					|  |  |  |                     surf_pos, edge_pos, | 
			
		
	
		
			
				
					|  |  |  |                     vertex_pos, | 
			
		
	
		
			
				
					|  |  |  |                     sdf[:, :3] | 
			
		
	
		
			
				
					|  |  |  |                 ) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
		
			
				
					|  |  |  |                 # 计算损失 | 
			
		
	
		
			
				
					|  |  |  |                 loss = sdf_loss( | 
			
		
	
		
			
				
					|  |  |  |                     pred_sdf, | 
			
		
	
		
			
				
					|  |  |  |                     gt_sdf, | 
			
		
	
		
			
				
					|  |  |  |                     query_points, | 
			
		
	
		
			
				
					|  |  |  |                     sdf[:, 3], | 
			
		
	
		
			
				
					|  |  |  |                     sdf[:, :3], | 
			
		
	
		
			
				
					|  |  |  |                     grad_weight=self.config.train.grad_weight | 
			
		
	
		
			
				
					|  |  |  |                 ) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
	
		
			
				
					|  |  | @ -195,11 +208,12 @@ class Trainer: | 
			
		
	
		
			
				
					|  |  |  |     def train(self): | 
			
		
	
		
			
				
					|  |  |  |         best_val_loss = float('inf') | 
			
		
	
		
			
				
					|  |  |  |         logger.info("Starting training...") | 
			
		
	
		
			
				
					|  |  |  |         val_loss = float('inf')  # 初始化val_loss | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         for epoch in range(1, self.config.train.num_epochs + 1): | 
			
		
	
		
			
				
					|  |  |  |             train_loss = self.train_epoch(epoch) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             # 定期验证 | 
			
		
	
		
			
				
					|  |  |  |             # 定期验证和保存 | 
			
		
	
		
			
				
					|  |  |  |             if epoch % self.config.train.val_freq == 0: | 
			
		
	
		
			
				
					|  |  |  |                 val_loss = self.validate(epoch) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
	
		
			
				
					|  |  | 
 |