diff --git a/brep2sdf/train.py b/brep2sdf/train.py index a8c6177..9a9c267 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -443,37 +443,63 @@ class Trainer: # 累积损失 losses.append(loss) - if epoch % 1 == 0: + if epoch % 100 == 0: loss_detailss.append(loss_details) - - - # 多个损失平均后反向传播 - loss_tensor = torch.stack(losses) - mean_loss = (loss_tensor * weights).sum() - mean_loss.backward() + if len(losses) % 30 == 0: + # 多个损失平均后反向传播 + loss_tensor = torch.stack(losses) + mean_loss = (loss_tensor * weights).sum() + mean_loss.backward() - # 更新参数 - self.scheduler.optimizer.step() - self.scheduler.step(mean_loss, epoch) + # 更新参数 + self.scheduler.optimizer.step() + self.scheduler.step(mean_loss, epoch) - # 清空梯度 - self.scheduler.optimizer.zero_grad() + # 清空梯度 + self.scheduler.optimizer.zero_grad() + torch.cuda.empty_cache() + + # 如果你想查看详细的损失信息,可以在这里添加日志记录 + if epoch % 100 == 0: + logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t' + f'Loss: {loss:.6f}') + loss_details_tensor = torch.stack(loss_detailss) # shape: [num_patches, 5] + + # 对每个子项取加权平均(如果需要 weights) + weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum() + + subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] + logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) + losses = [] + loss_detailss = [] + + if len(losses) > 0: + # 多个损失平均后反向传播 + loss_tensor = torch.stack(losses) + mean_loss = (loss_tensor * weights).sum() + mean_loss.backward() + + # 更新参数 + self.scheduler.optimizer.step() + self.scheduler.step(mean_loss, epoch) + + # 清空梯度 + self.scheduler.optimizer.zero_grad() + torch.cuda.empty_cache() + + # 如果你想查看详细的损失信息,可以在这里添加日志记录 + if epoch % 100 == 0: + logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t' + f'Loss: {loss:.6f}') + loss_details_tensor = torch.stack(loss_detailss) # shape: [num_patches, 5] - # 清理缓存 - torch.cuda.empty_cache() - - # 如果你想查看详细的损失信息,可以在这里添加日志记录 - if epoch % 1 == 0: - logger.info(f'Train [Stage 2] Epoch: {epoch:4d}\t' - f'Loss: {loss:.6f}') - loss_details_tensor = torch.stack(loss_detailss) # shape: [num_patches, 5] + # 对每个子项取加权平均(如果需要 weights) + weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum() - # 对每个子项取加权平均(如果需要 weights) - weighted_avg = (loss_details_tensor * weights.view(-1, 1)).sum(dim=0) / weights.sum() + subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] + logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) - subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"] - logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)])) avg_loss = sum(losses) / len(losses) logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}")