From 51db266834379d0206409138066a1329438f5738 Mon Sep 17 00:00:00 2001
From: mckay
Date: Sat, 30 Nov 2024 16:49:56 +0800
Subject: [PATCH] fix: network test script; tofix: some norm = 0
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/networks/network.py | 61 ++++++++++++++++++++++-------------
 1 file changed, 38 insertions(+), 23 deletions(-)

diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index b18dd42..254ed87 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -173,19 +173,24 @@ def train_step(model, batch, optimizer, criterion):
     # Ensure the model is in training mode
     model.train()
 
-    # Check and set requires_grad for all parameters
+    # 1. Deep-copy the input data so every step starts from fresh tensors
+    batch = {
+        k: v.clone().detach().requires_grad_(True) if isinstance(v, torch.Tensor) and v.dtype != torch.bool
+        else v
+        for k, v in batch.items()
+    }
+
+    # 2. Make sure all model parameters require gradients
     for name, param in model.named_parameters():
         if not param.requires_grad:
             logger.warning(f"Parameter {name} has requires_grad=False; setting it to True")
             param.requires_grad = True
 
-    # Set requires_grad=True on all inputs
-    batch['query_points'].requires_grad_(True)
-
-    # Zero the gradients
-    optimizer.zero_grad()
+    # 3. Zero the gradients (clear both the model's and the optimizer's gradients)
+    model.zero_grad(set_to_none=True)  # set_to_none=True clears gradients more thoroughly
+    optimizer.zero_grad(set_to_none=True)
 
-    # Forward pass
+    # 4. Forward pass
     pred_sdf = model(
         edge_ncs=batch['edge_ncs'],
         edge_pos=batch['edge_pos'],
@@ -196,33 +201,43 @@
         query_points=batch['query_points']
     )
 
-    # Compute the loss
+    # 5. Check whether the prediction requires gradients
+    if not pred_sdf.requires_grad:
+        logger.warning("Prediction output does not require gradients!")
+        pred_sdf.requires_grad_(True)
+
+    # 6. Compute the loss
     loss = criterion(pred_sdf, batch['gt_sdf'])
 
-    # Check that the loss is valid
+    # 7. Check the loss
+    logger.info(f"Loss value: {loss.item()}")
+    logger.info(f"Loss requires grad: {loss.requires_grad}")
+
     if not torch.isfinite(loss):
         logger.error(f"Invalid loss value: {loss.item()}")
         raise ValueError("Invalid loss value")
 
-    # Backward pass
-    loss.backward()
+    # 8. Backward pass
+    loss.backward(retain_graph=False)  # retain_graph=False makes sure the graph is freed
 
-    # Check gradients
-    total_norm = 0
+    # 9. Detailed gradient check
+    logger.info("\nGradient info:")
     for name, param in model.named_parameters():
         if param.grad is not None:
-            param_norm = param.grad.data.norm(2)
-            total_norm += param_norm.item() ** 2
-            logger.info(f"  {name}: grad_norm = {param_norm.item()}")
+            grad_norm = param.grad.norm().item()
+            grad_mean = param.grad.mean().item()
+            grad_std = param.grad.std().item()
+            logger.info(f"  {name}:")
+            logger.info(f"    norm: {grad_norm:.6f}")
+            logger.info(f"    mean: {grad_mean:.6f}")
+            logger.info(f"    std: {grad_std:.6f}")
         else:
-            logger.warning(f"  {name}: No gradient!")
-    total_norm = total_norm ** 0.5
-    logger.info(f"Total gradient norm: {total_norm}")
+            logger.warning(f"  {name}: no gradient!")
 
-    # Gradient clipping (optional)
-    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+    # 10. Gradient clipping
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
 
-    # Update parameters
+    # 11. Update parameters
     optimizer.step()
 
     return loss.item()
@@ -260,7 +275,7 @@ def train(model, config, num_epochs=10):
     }
 
     # Training loop
-    for epoch in range(2):
+    for epoch in range(num_epochs):
        try:
            loss = train_step(model, batch, optimizer, criterion)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.6f}")
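
The "tofix: some norm = 0" note in the subject means that some parameters come out of backward() with a zero gradient norm. Below is a minimal, self-contained sketch (not part of the commit) of how to list exactly which parameters end up with a missing or near-zero gradient; the toy model, shapes, and loss are placeholders rather than the actual network in brep2sdf/networks/network.py.

# Reviewer sketch, not part of the commit: list parameters whose gradient is
# missing or has (near-)zero norm after backward(). The toy model below is a
# placeholder used only to demonstrate the symptom.
import torch
import torch.nn as nn

def report_zero_grad_params(model, atol=1e-12):
    """Return names of parameters with a missing or near-zero gradient norm."""
    suspicious = []
    for name, param in model.named_parameters():
        if param.grad is None:
            suspicious.append(f"{name} (no grad)")
        elif param.grad.norm().item() <= atol:
            suspicious.append(f"{name} (norm ~ 0)")
    return suspicious

if __name__ == "__main__":
    # Toy model: the 'unused' branch never participates in forward(),
    # so its parameters receive no gradient -- the same symptom as "some norm = 0".
    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.used = nn.Linear(8, 1)
            self.unused = nn.Linear(8, 1)

        def forward(self, x):
            return self.used(x)

    model = Toy()
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    print(report_zero_grad_params(model))  # expect the 'unused.*' parameters listed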
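
Separately, step 9 in the patch drops the old total_norm bookkeeping. If that number is still wanted, clip_grad_norm_ (already called in step 10) returns the total gradient norm measured before clipping, so it can be logged directly. A small self-contained sketch, assuming only the public PyTorch API already used in the patch:

# Sketch: clip_grad_norm_ returns the pre-clipping total norm of all gradients,
# which can stand in for the removed total_norm computation.
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
model(torch.randn(2, 4)).sum().backward()
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
print(f"Total gradient norm (before clipping): {total_norm.item():.6f}")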