| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -197,14 +197,14 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.surfz_embed = Encoder1D( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            in_channels=3, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            out_channels=self.embed_dim, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            block_out_channels=(64, 128, 256), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            block_out_channels=(64, 128, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            layers_per_block=2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.edgez_embed = Encoder1D( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            in_channels=3, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            out_channels=self.embed_dim, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            block_out_channels=(64, 128, 256), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            block_out_channels=(64, 128, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            layers_per_block=2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -410,6 +410,11 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        B, Q = query_points.shape[:2]  # B: batch_size, Q: num_queries | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             # 确保query_points需要梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if not query_points.requires_grad: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                query_points = query_points.detach().requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 1. B-rep特征编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            brep_features = self.brep_embedder( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_ncs=edge_ncs,         # [B, max_face, max_edge, num_edge_points, 3] | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -439,6 +444,10 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 6. SDF预测 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sdf = self.sdf_head(combined_features)  # [B, Q, 1] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if not sdf.requires_grad: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                logger.warning("SDF output does not require grad!") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					           | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return sdf | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        except Exception as e: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -456,19 +465,34 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """SDF损失函数""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # 确保points需要梯度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if not points.requires_grad: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        points = points.detach().requires_grad_(True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # L1损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    l1_loss = F.l1_loss(pred_sdf, gt_sdf) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # 梯度约束损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    grad = torch.autograd.grad( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        pred_sdf.sum(),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        create_graph=True | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    )[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    grad_constraint = F.mse_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        torch.norm(grad, dim=-1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        torch.ones_like(pred_sdf.squeeze(-1)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 梯度约束损失 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        grad = torch.autograd.grad( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            pred_sdf.sum(),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            points, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            create_graph=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            retain_graph=True, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            allow_unused=True | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        )[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if grad is not None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            grad_constraint = F.mse_loss( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.norm(grad, dim=-1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.ones_like(pred_sdf.squeeze(-1)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    except Exception as e: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.warning(f"Gradient computation failed: {str(e)}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return l1_loss + grad_weight * grad_constraint | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |