|
@ -443,11 +443,38 @@ class Trainer: |
|
|
|
|
|
|
|
|
# 累积损失 |
|
|
# 累积损失 |
|
|
losses.append(loss) |
|
|
losses.append(loss) |
|
|
if epoch % 1 == 0: |
|
|
if epoch % 100 == 0: |
|
|
loss_detailss.append(loss_details) |
|
|
loss_detailss.append(loss_details) |
|
|
|
|
|
|
|
|
|
|
|
if len(losses) % 30 == 0: |
|
|
|
|
|
# 多个损失平均后反向传播 |
|
|
|
|
|
loss_tensor = torch.stack(losses) |
|
|
|
|
|
mean_loss = (loss_tensor * weights).sum() |
|
|
|
|
|
mean_loss.backward() |
|
|
|
|
|
|
|
|
|
|
|
# 更新参数 |
|
|
|
|
|
self.scheduler.optimizer.step() |
|
|
|
|
|
self.scheduler.step(mean_loss, epoch) |
|
|
|
|
|
|
|
|
|
|
|
# 清空梯度 |
|
|
|
|
|
self.scheduler.optimizer.zero_grad() |
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
|
# 如果你想查看详细的损失信息,可以在这里添加日志记录 |
|
|
|
|
|
if epoch % 100 == 0: |
|
|
|
|
|
logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t' |
|
|
|
|
|
f'Loss: {loss:.6f}') |
|
|
|
|
|
loss_details_tensor = torch.stack(loss_detailss) # shape: [num_patches, 5] |
|
|
|
|
|
|
|
|
|
|
|
# 对每个子项取加权平均(如果需要 weights) |
|
|
|
|
|
weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum() |
|
|
|
|
|
|
|
|
|
|
|
subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] |
|
|
|
|
|
logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) |
|
|
|
|
|
losses = [] |
|
|
|
|
|
loss_detailss = [] |
|
|
|
|
|
|
|
|
|
|
|
if len(losses) > 0: |
|
|
# 多个损失平均后反向传播 |
|
|
# 多个损失平均后反向传播 |
|
|
loss_tensor = torch.stack(losses) |
|
|
loss_tensor = torch.stack(losses) |
|
|
mean_loss = (loss_tensor * weights).sum() |
|
|
mean_loss = (loss_tensor * weights).sum() |
|
@ -459,12 +486,10 @@ class Trainer: |
|
|
|
|
|
|
|
|
# 清空梯度 |
|
|
# 清空梯度 |
|
|
self.scheduler.optimizer.zero_grad() |
|
|
self.scheduler.optimizer.zero_grad() |
|
|
|
|
|
|
|
|
# 清理缓存 |
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
# 如果你想查看详细的损失信息,可以在这里添加日志记录 |
|
|
# 如果你想查看详细的损失信息,可以在这里添加日志记录 |
|
|
if epoch % 1 == 0: |
|
|
if epoch % 100 == 0: |
|
|
logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t' |
|
|
logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t' |
|
|
f'Loss: {loss:.6f}') |
|
|
f'Loss: {loss:.6f}') |
|
|
loss_details_tensor = torch.stack(loss_detailss) # shape: [num_patches, 5] |
|
|
loss_details_tensor = torch.stack(loss_detailss) # shape: [num_patches, 5] |
|
@ -475,6 +500,7 @@ class Trainer: |
|
|
subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] |
|
|
subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] |
|
|
logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) |
|
|
logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
avg_loss = sum(losses) / len(losses) |
|
|
avg_loss = sum(losses) / len(losses) |
|
|
logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}") |
|
|
logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}") |
|
|
|
|
|
|
|
|