| 
						
						
							
								
							
						
						
					 | 
					@ -119,38 +119,7 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.error(f"    query_points: {query_points.shape}") | 
					 | 
					 | 
					            logger.error(f"    query_points: {query_points.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            raise | 
					 | 
					 | 
					            raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
					 | 
					 | 
					
 | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					    """SDF损失函数""" | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 确保points需要梯度 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    if not points.requires_grad: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        points = points.detach().requires_grad_(True) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # L1损失 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    l1_loss = F.l1_loss(pred_sdf, gt_sdf) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    try: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 梯度约束损失 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        grad = torch.autograd.grad( | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            pred_sdf.sum(),  | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            points, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            create_graph=True, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            retain_graph=True, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            allow_unused=True | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        )[0] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if grad is not None: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            grad_constraint = F.mse_loss( | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                torch.norm(grad, dim=-1), | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                torch.ones_like(pred_sdf.squeeze(-1)) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            ) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        else: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    except Exception as e: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.warning(f"Gradient computation failed: {str(e)}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    return l1_loss + grad_weight * grad_constraint | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					def main(): | 
					 | 
					 | 
					def main(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 获取配置 | 
					 | 
					 | 
					    # 获取配置 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |