| 
						
						
							
								
							
						
						
					 | 
					@ -251,7 +251,6 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        vertex_embed = vertex_embed.mean(dim=3)  # [B, F, E, embed_dim] | 
					 | 
					 | 
					        vertex_embed = vertex_embed.mean(dim=3)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 2. 处理边特征 | 
					 | 
					 | 
					        # 2. 处理边特征 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        logger.info(f"edge_ncs shape: {edge_ncs.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        edge_embeds = self.edgez_embed(edge_ncs)  # [B, F, E, embed_dim] | 
					 | 
					 | 
					        edge_embeds = self.edgez_embed(edge_ncs)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        edge_p_embeds = self.edgep_embed(edge_pos)  # [B, F, E, embed_dim] | 
					 | 
					 | 
					        edge_p_embeds = self.edgep_embed(edge_pos)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -441,6 +440,13 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        B, Q = query_points.shape[:2]  # B: batch_size, Q: num_queries | 
					 | 
					 | 
					        B, Q = query_points.shape[:2]  # B: batch_size, Q: num_queries | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        #logger.info(f"query_points requires_grad: {query_points.requires_grad}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        #logger.info(f"edge_ncs requires_grad: {edge_ncs.requires_grad}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        #logger.info(f"edge_pos requires_grad: {edge_pos.requires_grad}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        #logger.info(f"edge_mask requires_grad: {edge_mask.requires_grad}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        #logger.info(f"surf_ncs requires_grad: {surf_ncs.requires_grad}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        #logger.info(f"surf_pos requires_grad: {surf_pos.requires_grad}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        try: | 
					 | 
					 | 
					        try: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             # 确保query_points需要梯度 | 
					 | 
					 | 
					             # 确保query_points需要梯度 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            if not query_points.requires_grad: | 
					 | 
					 | 
					            if not query_points.requires_grad: | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -544,13 +550,13 @@ def main(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 生成测试数据 | 
					 | 
					 | 
					    # 生成测试数据 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    test_data = { | 
					 | 
					 | 
					    test_data = { | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), | 
					 | 
					 | 
					        'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3, requires_grad=True), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), | 
					 | 
					 | 
					        'edge_pos': torch.randn(batch_size, max_face, max_edge, 6, requires_grad=True), | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), | 
					 | 
					 | 
					        'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), | 
					 | 
					 | 
					        'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3, requires_grad=True), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        'surf_pos': torch.randn(batch_size, max_face, 6), | 
					 | 
					 | 
					        'surf_pos': torch.randn(batch_size, max_face, 6, requires_grad=True), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), | 
					 | 
					 | 
					        'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3, requires_grad=True), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        'query_points': torch.randn(batch_size, 1000, 3)  # 1000个查询点 | 
					 | 
					 | 
					        'query_points': torch.randn(batch_size, 1000, 3, requires_grad=True)  # 1000个查询点 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					    } | 
					 | 
					 | 
					    } | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 打印输入数据形状 | 
					 | 
					 | 
					    # 打印输入数据形状 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |