
fix: network test script; tofix: some

norm = 0
main
mckay 4 months ago
commit 51db266834
  1. brep2sdf/networks/network.py (61)

brep2sdf/networks/network.py (61)

@@ -173,19 +173,24 @@ def train_step(model, batch, optimizer, criterion):
# Ensure the model is in training mode
model.train()
# Check and set requires_grad for all parameters
# 1. Deep-copy the input data so every step works on fresh tensors
batch = {
k: v.clone().detach().requires_grad_(True) if isinstance(v, torch.Tensor) and v.dtype != torch.bool
else v
for k, v in batch.items()
}
# 2. Make sure all model parameters require gradients
for name, param in model.named_parameters():
if not param.requires_grad:
logger.warning(f"参数 {name} 的requires_grad为False,现在设置为True")
param.requires_grad = True
# Set requires_grad=True on all inputs
batch['query_points'].requires_grad_(True)
# Zero the gradients
optimizer.zero_grad()
# 3. Zero gradients (clear both the model's and the optimizer's gradients)
model.zero_grad(set_to_none=True)  # set_to_none=True clears gradients more thoroughly
optimizer.zero_grad(set_to_none=True)
# Forward pass
# 4. Forward pass
pred_sdf = model(
edge_ncs=batch['edge_ncs'],
edge_pos=batch['edge_pos'],
@@ -196,33 +201,43 @@ def train_step(model, batch, optimizer, criterion):
query_points=batch['query_points']
)
# Compute the loss
# 5. Check whether the prediction requires gradients
if not pred_sdf.requires_grad:
logger.warning("预测输出不需要梯度!")
pred_sdf.requires_grad_(True)
# 6. Compute the loss
loss = criterion(pred_sdf, batch['gt_sdf'])
# Check whether the loss is valid
# 7. Check the loss
logger.info(f"Loss value: {loss.item()}")
logger.info(f"Loss requires grad: {loss.requires_grad}")
if not torch.isfinite(loss):
logger.error(f"损失值无效: {loss.item()}")
raise ValueError("损失值无效")
# Backward pass
loss.backward()
# 8. Backward pass
loss.backward(retain_graph=False)  # retain_graph=False ensures the computation graph is freed
# Check gradients
total_norm = 0
# 9. Detailed gradient check
logger.info("\nGradient info:")
for name, param in model.named_parameters():
if param.grad is not None:
param_norm = param.grad.data.norm(2)
total_norm += param_norm.item() ** 2
logger.info(f" {name}: grad_norm = {param_norm.item()}")
grad_norm = param.grad.norm().item()
grad_mean = param.grad.mean().item()
grad_std = param.grad.std().item()
logger.info(f" {name}:")
logger.info(f" norm: {grad_norm:.6f}")
logger.info(f" mean: {grad_mean:.6f}")
logger.info(f" std: {grad_std:.6f}")
else:
logger.warning(f" {name}: No gradient!")
total_norm = total_norm ** 0.5
logger.info(f"梯度总范数: {total_norm}")
logger.warning(f" {name}: 没有梯度!")
# 梯度裁剪(可选)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 10. Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
# Update parameters
# 11. Update parameters
optimizer.step()
return loss.item()
@@ -260,7 +275,7 @@ def train(model, config, num_epochs=10):
}
# Training loop
for epoch in range(2):
for epoch in range(num_epochs):
try:
loss = train_step(model, batch, optimizer, criterion)
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.6f}")
