Browse Source

fix :train script

main
王琛涵 4 months ago
parent
commit
f3e3886f0e
  1. 88
      brep2sdf/train.py

88
brep2sdf/train.py

@ -45,6 +45,8 @@ class Trainer:
def __init__(self, config):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_pin_memory = self.device.type == 'cuda' # 根据设备类型决定是否使用pin_memory
logger.info(f"Using device: {self.device}")
# 初始化数据集
@ -69,13 +71,15 @@ class Trainer:
self.train_dataset,
batch_size=config.train.batch_size,
shuffle=True,
num_workers=config.train.num_workers
num_workers=config.train.num_workers,
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size=config.train.batch_size,
shuffle=False,
num_workers=config.train.num_workers
num_workers=config.train.num_workers,
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中
)
# 初始化模型
@ -105,28 +109,34 @@ class Trainer:
total_loss = 0
for batch_idx, batch in enumerate(self.train_loader):
# 获取数据
surf_z = batch['surf_z'].to(self.device)
edge_z = batch['edge_z'].to(self.device)
surf_p = batch['surf_p'].to(self.device)
edge_p = batch['edge_p'].to(self.device)
vert_p = batch['vert_p'].to(self.device)
query_points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 清空梯度
self.optimizer.zero_grad()
# 获取数据并移动到设备
surf_ncs = batch['surf_ncs'].to(self.device)
edge_ncs = batch['edge_ncs'].to(self.device)
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
pred_sdf = self.model(
surf_ncs, edge_ncs,
surf_pos, edge_pos,
vertex_pos,
sdf[:, :3] # 只使用点坐标,不包括SDF值
)
# 计算损失
loss = sdf_loss(
pred_sdf,
gt_sdf,
query_points,
sdf[:, 3], # 使用SDF值
sdf[:, :3], # 使用点坐标
grad_weight=self.config.train.grad_weight
)
# 反向传播
self.optimizer.zero_grad()
# 反向传播和优化
loss.backward()
# 梯度裁剪
@ -134,22 +144,21 @@ class Trainer:
self.model.parameters(),
self.config.train.max_grad_norm
)
self.optimizer.step()
total_loss += loss.item()
# 打印训练进度
# 记录训练进度
if (batch_idx + 1) % self.config.log.log_interval == 0:
logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t'
f'Loss: {loss.item():.6f}')
# 记录到wandb
if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0:
wandb.log({
'batch_loss': loss.item(),
'batch': batch_idx,
'epoch': epoch
})
if self.config.log.use_wandb:
wandb.log({
'batch_loss': loss.item(),
'batch': batch_idx,
'epoch': epoch
})
avg_loss = total_loss / len(self.train_loader)
return avg_loss
@ -160,23 +169,27 @@ class Trainer:
with torch.no_grad():
for batch in self.val_loader:
# 获取数据
surf_z = batch['surf_z'].to(self.device)
edge_z = batch['edge_z'].to(self.device)
surf_p = batch['surf_p'].to(self.device)
edge_p = batch['edge_p'].to(self.device)
vert_p = batch['vert_p'].to(self.device)
query_points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 获取数据并移动到设备
surf_ncs = batch['surf_ncs'].to(self.device)
edge_ncs = batch['edge_ncs'].to(self.device)
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
pred_sdf = self.model(
surf_ncs, edge_ncs,
surf_pos, edge_pos,
vertex_pos,
sdf[:, :3]
)
# 计算损失
loss = sdf_loss(
pred_sdf,
gt_sdf,
query_points,
sdf[:, 3],
sdf[:, :3],
grad_weight=self.config.train.grad_weight
)
@ -195,11 +208,12 @@ class Trainer:
def train(self):
best_val_loss = float('inf')
logger.info("Starting training...")
val_loss = float('inf') # 初始化val_loss
for epoch in range(1, self.config.train.num_epochs + 1):
train_loss = self.train_epoch(epoch)
# 定期验证
# 定期验证和保存
if epoch % self.config.train.val_freq == 0:
val_loss = self.validate(epoch)

Loading…
Cancel
Save