|
|
@ -45,6 +45,8 @@ class Trainer: |
|
|
|
def __init__(self, config): |
|
|
|
self.config = config |
|
|
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
use_pin_memory = self.device.type == 'cuda' # 根据设备类型决定是否使用pin_memory |
|
|
|
|
|
|
|
logger.info(f"Using device: {self.device}") |
|
|
|
|
|
|
|
# 初始化数据集 |
|
|
@ -69,13 +71,15 @@ class Trainer: |
|
|
|
self.train_dataset, |
|
|
|
batch_size=config.train.batch_size, |
|
|
|
shuffle=True, |
|
|
|
num_workers=config.train.num_workers |
|
|
|
num_workers=config.train.num_workers, |
|
|
|
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 |
|
|
|
) |
|
|
|
self.val_loader = DataLoader( |
|
|
|
self.val_dataset, |
|
|
|
batch_size=config.train.batch_size, |
|
|
|
shuffle=False, |
|
|
|
num_workers=config.train.num_workers |
|
|
|
num_workers=config.train.num_workers, |
|
|
|
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 |
|
|
|
) |
|
|
|
|
|
|
|
# 初始化模型 |
|
|
@ -105,28 +109,34 @@ class Trainer: |
|
|
|
total_loss = 0 |
|
|
|
|
|
|
|
for batch_idx, batch in enumerate(self.train_loader): |
|
|
|
# 获取数据 |
|
|
|
surf_z = batch['surf_z'].to(self.device) |
|
|
|
edge_z = batch['edge_z'].to(self.device) |
|
|
|
surf_p = batch['surf_p'].to(self.device) |
|
|
|
edge_p = batch['edge_p'].to(self.device) |
|
|
|
vert_p = batch['vert_p'].to(self.device) |
|
|
|
query_points = batch['points'].to(self.device) |
|
|
|
gt_sdf = batch['sdf'].to(self.device) |
|
|
|
# 清空梯度 |
|
|
|
self.optimizer.zero_grad() |
|
|
|
|
|
|
|
# 获取数据并移动到设备 |
|
|
|
surf_ncs = batch['surf_ncs'].to(self.device) |
|
|
|
edge_ncs = batch['edge_ncs'].to(self.device) |
|
|
|
surf_pos = batch['surf_pos'].to(self.device) |
|
|
|
edge_pos = batch['edge_pos'].to(self.device) |
|
|
|
vertex_pos = batch['vertex_pos'].to(self.device) |
|
|
|
sdf = batch['sdf'].to(self.device) |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) |
|
|
|
pred_sdf = self.model( |
|
|
|
surf_ncs, edge_ncs, |
|
|
|
surf_pos, edge_pos, |
|
|
|
vertex_pos, |
|
|
|
sdf[:, :3] # 只使用点坐标,不包括SDF值 |
|
|
|
) |
|
|
|
|
|
|
|
# 计算损失 |
|
|
|
loss = sdf_loss( |
|
|
|
pred_sdf, |
|
|
|
gt_sdf, |
|
|
|
query_points, |
|
|
|
sdf[:, 3], # 使用SDF值 |
|
|
|
sdf[:, :3], # 使用点坐标 |
|
|
|
grad_weight=self.config.train.grad_weight |
|
|
|
) |
|
|
|
|
|
|
|
# 反向传播 |
|
|
|
self.optimizer.zero_grad() |
|
|
|
# 反向传播和优化 |
|
|
|
loss.backward() |
|
|
|
|
|
|
|
# 梯度裁剪 |
|
|
@ -134,22 +144,21 @@ class Trainer: |
|
|
|
self.model.parameters(), |
|
|
|
self.config.train.max_grad_norm |
|
|
|
) |
|
|
|
|
|
|
|
self.optimizer.step() |
|
|
|
|
|
|
|
total_loss += loss.item() |
|
|
|
|
|
|
|
# 打印训练进度 |
|
|
|
# 记录训练进度 |
|
|
|
if (batch_idx + 1) % self.config.log.log_interval == 0: |
|
|
|
logger.info(f'Train Epoch: {epoch} [{batch_idx+1}/{len(self.train_loader)}]\t' |
|
|
|
f'Loss: {loss.item():.6f}') |
|
|
|
|
|
|
|
# 记录到wandb |
|
|
|
if self.config.log.use_wandb and (batch_idx + 1) % self.config.log.log_interval == 0: |
|
|
|
wandb.log({ |
|
|
|
'batch_loss': loss.item(), |
|
|
|
'batch': batch_idx, |
|
|
|
'epoch': epoch |
|
|
|
}) |
|
|
|
|
|
|
|
if self.config.log.use_wandb: |
|
|
|
wandb.log({ |
|
|
|
'batch_loss': loss.item(), |
|
|
|
'batch': batch_idx, |
|
|
|
'epoch': epoch |
|
|
|
}) |
|
|
|
|
|
|
|
avg_loss = total_loss / len(self.train_loader) |
|
|
|
return avg_loss |
|
|
@ -160,23 +169,27 @@ class Trainer: |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
for batch in self.val_loader: |
|
|
|
# 获取数据 |
|
|
|
surf_z = batch['surf_z'].to(self.device) |
|
|
|
edge_z = batch['edge_z'].to(self.device) |
|
|
|
surf_p = batch['surf_p'].to(self.device) |
|
|
|
edge_p = batch['edge_p'].to(self.device) |
|
|
|
vert_p = batch['vert_p'].to(self.device) |
|
|
|
query_points = batch['points'].to(self.device) |
|
|
|
gt_sdf = batch['sdf'].to(self.device) |
|
|
|
# 获取数据并移动到设备 |
|
|
|
surf_ncs = batch['surf_ncs'].to(self.device) |
|
|
|
edge_ncs = batch['edge_ncs'].to(self.device) |
|
|
|
surf_pos = batch['surf_pos'].to(self.device) |
|
|
|
edge_pos = batch['edge_pos'].to(self.device) |
|
|
|
vertex_pos = batch['vertex_pos'].to(self.device) |
|
|
|
sdf = batch['sdf'].to(self.device) |
|
|
|
|
|
|
|
# 前向传播 |
|
|
|
pred_sdf = self.model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) |
|
|
|
pred_sdf = self.model( |
|
|
|
surf_ncs, edge_ncs, |
|
|
|
surf_pos, edge_pos, |
|
|
|
vertex_pos, |
|
|
|
sdf[:, :3] |
|
|
|
) |
|
|
|
|
|
|
|
# 计算损失 |
|
|
|
loss = sdf_loss( |
|
|
|
pred_sdf, |
|
|
|
gt_sdf, |
|
|
|
query_points, |
|
|
|
sdf[:, 3], |
|
|
|
sdf[:, :3], |
|
|
|
grad_weight=self.config.train.grad_weight |
|
|
|
) |
|
|
|
|
|
|
@ -195,11 +208,12 @@ class Trainer: |
|
|
|
def train(self): |
|
|
|
best_val_loss = float('inf') |
|
|
|
logger.info("Starting training...") |
|
|
|
val_loss = float('inf') # 初始化val_loss |
|
|
|
|
|
|
|
for epoch in range(1, self.config.train.num_epochs + 1): |
|
|
|
train_loss = self.train_epoch(epoch) |
|
|
|
|
|
|
|
# 定期验证 |
|
|
|
# 定期验证和保存 |
|
|
|
if epoch % self.config.train.val_freq == 0: |
|
|
|
val_loss = self.validate(epoch) |
|
|
|
|
|
|
|