|
|
@ -223,9 +223,9 @@ def gradient(inputs, outputs): |
|
|
|
# 修正维度切片方式 |
|
|
|
if points_grad is None: |
|
|
|
return torch.zeros_like(inputs[:, -3:]) # 处理空梯度情况 |
|
|
|
|
|
|
|
#logger.debug(f"points_grad:{points_grad},shape:{points_grad.shape}") |
|
|
|
# 添加安全截取和归一化 |
|
|
|
coord_grad = points_grad[:, -3:] if points_grad.shape[1] >=3 else points_grad |
|
|
|
coord_grad = coord_grad / (coord_grad.norm(dim=-1, keepdim=True) + 1e-6) # 安全归一化 |
|
|
|
coord_grad = coord_grad / (coord_grad.norm(dim=-1, keepdim=True) + 1e-16) # 安全归一化 |
|
|
|
|
|
|
|
return coord_grad |