| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -119,38 +119,7 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"    query_points: {query_points.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """SDF损失函数""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # 确保points需要梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if not points.requires_grad: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        points = points.detach().requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # L1损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    l1_loss = F.l1_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 梯度约束损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        grad = torch.autograd.grad( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf.sum(),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            create_graph=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            retain_graph=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            allow_unused=True | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        )[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if grad is not None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            grad_constraint = F.mse_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.norm(grad, dim=-1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.ones_like(pred_sdf.squeeze(-1)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.warning(f"Gradient computation failed: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return l1_loss + grad_weight * grad_constraint | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # 获取配置 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |