From f3e3886f0e092c372172ca853e19fedc15dafa84 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E7=90=9B=E6=B6=B5?=
Date: Tue, 19 Nov 2024 02:24:28 +0800
Subject: [PATCH] fix: train script

---
 brep2sdf/train.py | 88 +++++++++++++++++++++++++++--------------------
 1 file changed, 51 insertions(+), 37 deletions(-)

diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 09854c7..ffcb00f 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -45,6 +45,8 @@ class Trainer:
     def __init__(self, config):
         self.config = config
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        use_pin_memory = self.device.type == 'cuda'  # decide whether to use pin_memory based on the device type
+        
         logger.info(f"Using device: {self.device}")
         
         # Initialize the datasets
@@ -69,13 +71,15 @@ class Trainer:
             self.train_dataset,
             batch_size=config.train.batch_size,
             shuffle=True,
-            num_workers=config.train.num_workers
+            num_workers=config.train.num_workers,
+            pin_memory=use_pin_memory  # set by device type: whether to pin batches in page-locked memory
         )
         self.val_loader = DataLoader(
             self.val_dataset,
             batch_size=config.train.batch_size,
             shuffle=False,
-            num_workers=config.train.num_workers
+            num_workers=config.train.num_workers,
+            pin_memory=use_pin_memory  # set by device type: whether to pin batches in page-locked memory
         )
         
         # Initialize the model
@@ -105,28 +109,34 @@ class Trainer:
         total_loss = 0
         
         for batch_idx, batch in enumerate(self.train_loader):
-            # Get the data
-            surf_z = batch['surf_z'].to(self.device)
-            edge_z = batch['edge_z'].to(self.device)
-            surf_p = batch['surf_p'].to(self.device)
-            edge_p = batch['edge_p'].to(self.device)
-            vert_p = batch['vert_p'].to(self.device)
-            query_points = batch['points'].to(self.device)
-            gt_sdf = batch['sdf'].to(self.device)
+            # Zero the gradients
+            self.optimizer.zero_grad()
+            
+            # Get the data and move it to the device
+            surf_ncs = batch['surf_ncs'].to(self.device)
+            edge_ncs = batch['edge_ncs'].to(self.device)
+            surf_pos = batch['surf_pos'].to(self.device)
+            edge_pos = batch['edge_pos'].to(self.device)
+            vertex_pos = batch['vertex_pos'].to(self.device)
+            sdf = batch['sdf'].to(self.device)
             
             # Forward pass
-            pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
+            pred_sdf = self.model(
+                surf_ncs, edge_ncs,
+                surf_pos, edge_pos,
+                vertex_pos,
+                sdf[:, :3]  # use only the point coordinates, not the SDF values
+            )
             
             # Compute the loss
             loss = sdf_loss(
                 pred_sdf,
-                gt_sdf,
-                query_points,
+                sdf[:, 3],   # ground-truth SDF values
+                sdf[:, :3],  # query point coordinates
                 grad_weight=self.config.train.grad_weight
             )
             
-            # Backward pass
-            self.optimizer.zero_grad()
+            # Backward pass and optimization
             loss.backward()
             
             # Gradient clipping
@@ -134,22 +144,21 @@ class Trainer:
                 self.model.parameters(),
                 self.config.train.max_grad_norm
             )
-            
             self.optimizer.step()
+            
             total_loss += loss.item()
             
-            # Print training progress
+            # Log training progress
             if (batch_idx + 1) % self.config.log.log_interval == 0:
                 logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
                           f'Loss: {loss.item():.6f}')
-            
-            # Log to wandb
-            if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0:
-                wandb.log({
-                    'batch_loss': loss.item(),
-                    'batch': batch_idx,
-                    'epoch': epoch
-                })
+                
+                if self.config.log.use_wandb:
+                    wandb.log({
+                        'batch_loss': loss.item(),
+                        'batch': batch_idx,
+                        'epoch': epoch
+                    })
         
         avg_loss = total_loss / len(self.train_loader)
         return avg_loss
@@ -160,23 +169,27 @@ class Trainer:
         
         with torch.no_grad():
             for batch in self.val_loader:
-                # Get the data
-                surf_z = batch['surf_z'].to(self.device)
-                edge_z = batch['edge_z'].to(self.device)
-                surf_p = batch['surf_p'].to(self.device)
-                edge_p = batch['edge_p'].to(self.device)
-                vert_p = batch['vert_p'].to(self.device)
-                query_points = batch['points'].to(self.device)
-                gt_sdf = batch['sdf'].to(self.device)
+                # Get the data and move it to the device
+                surf_ncs = batch['surf_ncs'].to(self.device)
+                edge_ncs = batch['edge_ncs'].to(self.device)
+                surf_pos = batch['surf_pos'].to(self.device)
+                edge_pos = batch['edge_pos'].to(self.device)
+                vertex_pos = batch['vertex_pos'].to(self.device)
+                sdf = batch['sdf'].to(self.device)
                 
                 # Forward pass
-                pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
+                pred_sdf = self.model(
+                    surf_ncs, edge_ncs,
+                    surf_pos, edge_pos,
+                    vertex_pos,
+                    sdf[:, :3]
+                )
                 
                 # Compute the loss
                 loss = sdf_loss(
                     pred_sdf,
-                    gt_sdf,
-                    query_points,
+                    sdf[:, 3],
+                    sdf[:, :3],
                     grad_weight=self.config.train.grad_weight
                 )
                 
@@ -195,11 +208,12 @@ class Trainer:
     def train(self):
         best_val_loss = float('inf')
         logger.info("Starting training...")
+        val_loss = float('inf')  # initialize val_loss
        
         for epoch in range(1, self.config.train.num_epochs + 1):
             train_loss = self.train_epoch(epoch)
             
-            # Periodic validation
+            # Periodic validation and checkpoint saving
             if epoch % self.config.train.val_freq == 0:
                 val_loss = self.validate(epoch)
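
Note on the data layout this patch assumes: the old code read separate batch['points'] and batch['sdf'] entries, while the new code reads a single batch['sdf'] tensor and slices it, so each sample row must pack the query point and its ground-truth distance together. A minimal illustration of that inferred (N, 4) layout (the shape is only implied by the sdf[:, :3] / sdf[:, 3] slicing above, not stated explicitly in the patch):

    import torch

    # Inferred layout: one row per query sample, (x, y, z, sdf_value).
    sdf = torch.tensor([[0.10, 0.20, 0.30,  0.05],
                        [0.40, 0.50, 0.60, -0.02]])

    query_points = sdf[:, :3]  # shape (N, 3): passed to the model as its last argument
    gt_sdf = sdf[:, 3]         # shape (N,):   passed to sdf_loss as the target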
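
The patch also fixes the call sites so sdf_loss receives (pred_sdf, gt_sdf, points, grad_weight=...), but the loss implementation itself is not part of this diff. As a rough sketch of what such a signature could wrap, assuming grad_weight scales an Eikonal-style gradient penalty (the name sdf_loss_sketch and its internals are illustrative, not the repository's code):

    import torch
    import torch.nn.functional as F

    def sdf_loss_sketch(pred_sdf, gt_sdf, points, grad_weight=0.1):
        # Value term: regress predicted distances onto the ground truth (MSE assumed here).
        value_loss = F.mse_loss(pred_sdf.view(-1), gt_sdf.view(-1))

        # Assumed gradient term: push |grad f(x)| toward 1. Requires `points` to be the
        # exact tensor the model was evaluated on, with requires_grad=True.
        grad = torch.autograd.grad(
            outputs=pred_sdf,
            inputs=points,
            grad_outputs=torch.ones_like(pred_sdf),
            create_graph=True,
        )[0]
        grad_loss = ((grad.norm(dim=-1) - 1.0) ** 2).mean()

        return value_loss + grad_weight * grad_loss

Under that assumption the call site would need the query points to carry gradients, e.g. points = sdf[:, :3].detach().requires_grad_(True) before the forward pass; the real sdf_loss in brep2sdf may compute its gradient term differently (for example by finite differences), in which case no such step is needed.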