
stage2 training: switch to batched updates

final
mckay 1 month ago
commit 8ae855f411
1 changed file: brep2sdf/train.py (74 lines changed)

brep2sdf/train.py
@@ -443,37 +443,63 @@ class Trainer:
             # accumulate the per-patch losses
             losses.append(loss)
-            if epoch % 1 == 0:
+            if epoch % 100 == 0:
                 loss_detailss.append(loss_details)
-        # average the accumulated losses, then backpropagate
-        loss_tensor = torch.stack(losses)
-        mean_loss = (loss_tensor * weights).sum()
-        mean_loss.backward()
-        # update parameters
-        self.scheduler.optimizer.step()
-        self.scheduler.step(mean_loss, epoch)
-        # zero the gradients
-        self.scheduler.optimizer.zero_grad()
-        # clear the cache
-        torch.cuda.empty_cache()
-        # add logging here if you want detailed loss information
-        if epoch % 1 == 0:
-            logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
-                        f'Loss: {loss:.6f}')
-            loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
-            # weighted average of each sub-loss (if weights are needed)
-            weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()
-            subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
-            logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
+            if len(losses) % 30 == 0:
+                # average the accumulated losses, then backpropagate
+                loss_tensor = torch.stack(losses)
+                mean_loss = (loss_tensor * weights).sum()
+                mean_loss.backward()
+                # update parameters
+                self.scheduler.optimizer.step()
+                self.scheduler.step(mean_loss, epoch)
+                # zero the gradients
+                self.scheduler.optimizer.zero_grad()
+                torch.cuda.empty_cache()
+                # add logging here if you want detailed loss information
+                if epoch % 100 == 0:
+                    logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
+                                f'Loss: {loss:.6f}')
+                    loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
+                    # weighted average of each sub-loss (if weights are needed)
+                    weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()
+                    subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
+                    logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
+                losses = []
+                loss_detailss = []
+        if len(losses) > 0:
+            # average the accumulated losses, then backpropagate
+            loss_tensor = torch.stack(losses)
+            mean_loss = (loss_tensor * weights).sum()
+            mean_loss.backward()
+            # update parameters
+            self.scheduler.optimizer.step()
+            self.scheduler.step(mean_loss, epoch)
+            # zero the gradients
+            self.scheduler.optimizer.zero_grad()
+            torch.cuda.empty_cache()
+            # add logging here if you want detailed loss information
+            if epoch % 100 == 0:
+                logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t'
+                            f'Loss: {loss:.6f}')
+                loss_details_tensor = torch.stack(loss_detailss)  # shape: [num_patches, 5]
+                # weighted average of each sub-loss (if weights are needed)
+                weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum()
+                subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
+                logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
         avg_loss = sum(losses) / len(losses)
         logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}")
