Browse Source

fix: train script

main
王琛涵 7 months ago
parent
commit
f3e3886f0e
  1. 76
      brep2sdf/train.py

76
brep2sdf/train.py

@ -45,6 +45,8 @@ class Trainer:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_pin_memory = self.device.type == 'cuda' # 根据设备类型决定是否使用pin_memory
logger.info(f"Using device: {self.device}") logger.info(f"Using device: {self.device}")
# 初始化数据集 # 初始化数据集
@ -69,13 +71,15 @@ class Trainer:
self.train_dataset, self.train_dataset,
batch_size=config.train.batch_size, batch_size=config.train.batch_size,
shuffle=True, shuffle=True,
num_workers=config.train.num_workers num_workers=config.train.num_workers,
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中
) )
self.val_loader = DataLoader( self.val_loader = DataLoader(
self.val_dataset, self.val_dataset,
batch_size=config.train.batch_size, batch_size=config.train.batch_size,
shuffle=False, shuffle=False,
num_workers=config.train.num_workers num_workers=config.train.num_workers,
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中
) )
# 初始化模型 # 初始化模型
@ -105,28 +109,34 @@ class Trainer:
total_loss = 0 total_loss = 0
for batch_idx, batch in enumerate(self.train_loader): for batch_idx, batch in enumerate(self.train_loader):
# 获取数据 # 清空梯度
surf_z = batch['surf_z'].to(self.device) self.optimizer.zero_grad()
edge_z = batch['edge_z'].to(self.device)
surf_p = batch['surf_p'].to(self.device) # 获取数据并移动到设备
edge_p = batch['edge_p'].to(self.device) surf_ncs = batch['surf_ncs'].to(self.device)
vert_p = batch['vert_p'].to(self.device) edge_ncs = batch['edge_ncs'].to(self.device)
query_points = batch['points'].to(self.device) surf_pos = batch['surf_pos'].to(self.device)
gt_sdf = batch['sdf'].to(self.device) edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
sdf = batch['sdf'].to(self.device)
# 前向传播 # 前向传播
pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) pred_sdf = self.model(
surf_ncs, edge_ncs,
surf_pos, edge_pos,
vertex_pos,
sdf[:, :3] # 只使用点坐标,不包括SDF值
)
# 计算损失 # 计算损失
loss = sdf_loss( loss = sdf_loss(
pred_sdf, pred_sdf,
gt_sdf, sdf[:, 3], # 使用SDF值
query_points, sdf[:, :3], # 使用点坐标
grad_weight=self.config.train.grad_weight grad_weight=self.config.train.grad_weight
) )
# 反向传播 # 反向传播和优化
self.optimizer.zero_grad()
loss.backward() loss.backward()
# 梯度裁剪 # 梯度裁剪
@ -134,17 +144,16 @@ class Trainer:
self.model.parameters(), self.model.parameters(),
self.config.train.max_grad_norm self.config.train.max_grad_norm
) )
self.optimizer.step() self.optimizer.step()
total_loss += loss.item() total_loss += loss.item()
# 打印训练进度 # 记录训练进度
if (batch_idx + 1) % self.config.log.log_interval == 0: if (batch_idx + 1) % self.config.log.log_interval == 0:
logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t' logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
f'Loss: {loss.item():.6f}') f'Loss: {loss.item():.6f}')
# 记录到wandb if self.config.log.use_wandb:
if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0:
wandb.log({ wandb.log({
'batch_loss': loss.item(), 'batch_loss': loss.item(),
'batch': batch_idx, 'batch': batch_idx,
@ -160,23 +169,27 @@ class Trainer:
with torch.no_grad(): with torch.no_grad():
for batch in self.val_loader: for batch in self.val_loader:
# 获取数据 # 获取数据并移动到设备
surf_z = batch['surf_z'].to(self.device) surf_ncs = batch['surf_ncs'].to(self.device)
edge_z = batch['edge_z'].to(self.device) edge_ncs = batch['edge_ncs'].to(self.device)
surf_p = batch['surf_p'].to(self.device) surf_pos = batch['surf_pos'].to(self.device)
edge_p = batch['edge_p'].to(self.device) edge_pos = batch['edge_pos'].to(self.device)
vert_p = batch['vert_p'].to(self.device) vertex_pos = batch['vertex_pos'].to(self.device)
query_points = batch['points'].to(self.device) sdf = batch['sdf'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播 # 前向传播
pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) pred_sdf = self.model(
surf_ncs, edge_ncs,
surf_pos, edge_pos,
vertex_pos,
sdf[:, :3]
)
# 计算损失 # 计算损失
loss = sdf_loss( loss = sdf_loss(
pred_sdf, pred_sdf,
gt_sdf, sdf[:, 3],
query_points, sdf[:, :3],
grad_weight=self.config.train.grad_weight grad_weight=self.config.train.grad_weight
) )
@ -195,11 +208,12 @@ class Trainer:
def train(self): def train(self):
best_val_loss = float('inf') best_val_loss = float('inf')
logger.info("Starting training...") logger.info("Starting training...")
val_loss = float('inf') # 初始化val_loss
for epoch in range(1, self.config.train.num_epochs + 1): for epoch in range(1, self.config.train.num_epochs + 1):
train_loss = self.train_epoch(epoch) train_loss = self.train_epoch(epoch)
# 定期验证 # 定期验证和保存
if epoch % self.config.train.val_freq == 0: if epoch % self.config.train.val_freq == 0:
val_loss = self.validate(epoch) val_loss = self.validate(epoch)

Loading…
Cancel
Save