@@ -173,19 +173,24 @@ def train_step(model, batch, optimizer, criterion):
     # Ensure the model is in training mode
     model.train()
 
-    # Check and set requires_grad on all parameters
+    # 1. Deep-copy the input data so each step works on fresh tensors
+    batch = {
+        k: v.clone().detach().requires_grad_(True) if isinstance(v, torch.Tensor) and v.dtype != torch.bool
+        else v
+        for k, v in batch.items()
+    }
+
+    # 2. Make sure all model parameters require gradients
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            logger.warning(f"Parameter {name} has requires_grad=False, setting it to True")
+            param.requires_grad = True
 
-    # Set requires_grad=True on all inputs
-    batch['query_points'].requires_grad_(True)
-
-    # Zero the gradients
-    optimizer.zero_grad()
+    # 3. Zero the gradients (clear both the model's and the optimizer's gradients)
+    model.zero_grad(set_to_none=True)  # set_to_none=True clears gradients more thoroughly
+    optimizer.zero_grad(set_to_none=True)
 
-    # Forward pass
+    # 4. Forward pass
     pred_sdf = model(
         edge_ncs=batch['edge_ncs'],
         edge_pos=batch['edge_pos'],
@@ -196,33 +201,43 @@ def train_step(model, batch, optimizer, criterion):
         query_points=batch['query_points']
     )
 
-    # Compute the loss
+    # 5. Check whether the prediction requires gradients
+    if not pred_sdf.requires_grad:
+        logger.warning("Prediction output does not require gradients!")
+        pred_sdf.requires_grad_(True)
+
+    # 6. Compute the loss
     loss = criterion(pred_sdf, batch['gt_sdf'])
 
-    # Check that the loss is valid
+    # 7. Check the loss
+    logger.info(f"Loss value: {loss.item()}")
+    logger.info(f"Loss requires grad: {loss.requires_grad}")
+
     if not torch.isfinite(loss):
         logger.error(f"Invalid loss value: {loss.item()}")
         raise ValueError("Invalid loss value")
 
-    # Backward pass
-    loss.backward()
+    # 8. Backward pass
+    loss.backward(retain_graph=False)  # retain_graph=False lets the computation graph be freed
 
-    # Check gradients
-    total_norm = 0
+    # 9. Detailed gradient check
+    logger.info("\nGradient info:")
     for name, param in model.named_parameters():
         if param.grad is not None:
-            param_norm = param.grad.data.norm(2)
-            total_norm += param_norm.item() ** 2
-            logger.info(f"  {name}: grad_norm = {param_norm.item()}")
+            grad_norm = param.grad.norm().item()
+            grad_mean = param.grad.mean().item()
+            grad_std = param.grad.std().item()
+            logger.info(f"  {name}:")
+            logger.info(f"    norm: {grad_norm:.6f}")
+            logger.info(f"    mean: {grad_mean:.6f}")
+            logger.info(f"    std: {grad_std:.6f}")
         else:
-            logger.warning(f"  {name}: No gradient!")
-    total_norm = total_norm ** 0.5
-    logger.info(f"Total gradient norm: {total_norm}")
+            logger.warning(f"  {name}: no gradient!")
 
-    # Gradient clipping (optional)
-    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+    # 10. Gradient clipping
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
 
-    # Update parameters
+    # 11. Update parameters
     optimizer.step()
 
     return loss.item()
@@ -260,7 +275,7 @@ def train(model, config, num_epochs=10):
     }
 
     # Training loop
-    for epoch in range(2):
+    for epoch in range(num_epochs):
        try:
            loss = train_step(model, batch, optimizer, criterion)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.6f}")