| 
						
						
							
								
							
						
						
					 | 
					@ -315,27 +315,29 @@ class SDFHead(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return self.mlp(x) | 
					 | 
					 | 
					        return self.mlp(x) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					class BRepToSDF(nn.Module): | 
					 | 
					 | 
					class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def __init__( | 
					 | 
					 | 
					    def __init__(self, config=None): | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        self, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        brep_feature_dim: int = 48, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        use_cf: bool = True, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        embed_dim: int = 768, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        latent_dim: int = 256 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    ): | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        super().__init__() | 
					 | 
					 | 
					        super().__init__() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 获取配置 | 
					 | 
					 | 
					        # 获取配置 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.config = get_default_config() | 
					 | 
					 | 
					        if config is None: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.embed_dim = embed_dim | 
					 | 
					 | 
					            self.config = get_default_config() | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            self.config = config | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        # 从配置中读取参数 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.embed_dim = self.config.model.embed_dim | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.brep_feature_dim = self.config.model.brep_feature_dim | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.latent_dim = self.config.model.latent_dim | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.use_cf = self.config.model.use_cf | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 1. 查询点编码器 | 
					 | 
					 | 
					        # 1. 查询点编码器 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.query_encoder = nn.Sequential( | 
					 | 
					 | 
					        self.query_encoder = nn.Sequential( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            nn.Linear(3, embed_dim//4), | 
					 | 
					 | 
					            nn.Linear(3, self.embed_dim//4), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            nn.LayerNorm(embed_dim//4), | 
					 | 
					 | 
					            nn.LayerNorm(self.embed_dim//4), | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            nn.ReLU(), | 
					 | 
					 | 
					            nn.ReLU(), | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            nn.Linear(embed_dim//4, embed_dim//2), | 
					 | 
					 | 
					            nn.Linear(self.embed_dim//4, self.embed_dim//2), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            nn.LayerNorm(embed_dim//2), | 
					 | 
					 | 
					            nn.LayerNorm(self.embed_dim//2), | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            nn.ReLU(), | 
					 | 
					 | 
					            nn.ReLU(), | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            nn.Linear(embed_dim//2, embed_dim) | 
					 | 
					 | 
					            nn.Linear(self.embed_dim//2, self.embed_dim) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 2. B-rep特征编码器 | 
					 | 
					 | 
					        # 2. B-rep特征编码器 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -343,12 +345,12 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 3. 特征融合Transformer | 
					 | 
					 | 
					        # 3. 特征融合Transformer | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.transformer = SDFTransformer( | 
					 | 
					 | 
					        self.transformer = SDFTransformer( | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            embed_dim=embed_dim, | 
					 | 
					 | 
					            embed_dim=self.embed_dim, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            num_layers=6 | 
					 | 
					 | 
					            num_layers=6  # 这个参数也可以移到配置文件中 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 4. SDF预测头 | 
					 | 
					 | 
					        # 4. SDF预测头 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.sdf_head = SDFHead(embed_dim=embed_dim*2) | 
					 | 
					 | 
					        self.sdf_head = SDFHead(embed_dim=self.embed_dim*2) | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): | 
					 | 
					 | 
					    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """B-rep到SDF的前向传播 | 
					 | 
					 | 
					        """B-rep到SDF的前向传播 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -435,88 +437,47 @@ def main(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 获取配置 | 
					 | 
					 | 
					    # 获取配置 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    config = get_default_config() | 
					 | 
					 | 
					    config = get_default_config() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 从配置初始化模型 | 
					 | 
					 | 
					    # 初始化模型 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    model = BRepToSDF( | 
					 | 
					 | 
					    model = BRepToSDF(config=config) | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        brep_feature_dim=config.model.brep_feature_dim,  # 48 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        use_cf=config.model.use_cf,                      # True | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        embed_dim=config.model.embed_dim,                # 768 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        latent_dim=config.model.latent_dim               # 256 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    ) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 从配置获取数据参数 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    batch_size = config.train.batch_size  # 32 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    num_surfs = config.data.max_face      # 64 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    num_edges = config.data.max_edge      # 64 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    num_verts = 8                         # 顶点数保持固定 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    num_queries = 1000                    # 查询点数保持固定 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 更新测试数据维度 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    edge_ncs = torch.randn( | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        batch_size, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_surfs,      # max_face | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_edges,      # max_edge | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        config.model.num_edge_points, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        3 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    )  # [B, max_face, max_edge, num_edge_points, 3] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    edge_pos = torch.randn( | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        batch_size, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_surfs, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_edges, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        6 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    )  # [B, max_face, max_edge, 6] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    edge_mask = torch.ones( | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        batch_size, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_surfs, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_edges, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        dtype=torch.bool | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    )  # [B, max_face, max_edge] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    surf_ncs = torch.randn( | 
					 | 
					 | 
					    # 从配置获取参数 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        batch_size, | 
					 | 
					 | 
					    batch_size = config.train.batch_size | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        num_surfs, | 
					 | 
					 | 
					    max_face = config.data.max_face | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        config.model.num_surf_points, | 
					 | 
					 | 
					    max_edge = config.data.max_edge | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        3 | 
					 | 
					 | 
					    num_surf_points = config.model.num_surf_points | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    )  # [B, max_face, num_surf_points, 3] | 
					 | 
					 | 
					    num_edge_points = config.model.num_edge_points | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    surf_pos = torch.randn( | 
					 | 
					 | 
					    # 生成测试数据 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        batch_size, | 
					 | 
					 | 
					    test_data = { | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        num_surfs, | 
					 | 
					 | 
					        'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        6 | 
					 | 
					 | 
					        'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    )  # [B, max_face, 6] | 
					 | 
					 | 
					        'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        'surf_pos': torch.randn(batch_size, max_face, 6), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        'query_points': torch.randn(batch_size, 1000, 3)  # 1000个查询点 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    } | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    vertex_pos = torch.randn( | 
					 | 
					 | 
					    # 打印输入数据形状 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        batch_size, | 
					 | 
					 | 
					    logger.info("Input shapes:") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        num_surfs, | 
					 | 
					 | 
					    for name, tensor in test_data.items(): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        num_edges, | 
					 | 
					 | 
					        logger.info(f"  {name}: {tensor.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        2, | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        3 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    )  # [B, max_face, max_edge, 2, 3] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    query_points = torch.randn(batch_size, num_queries, 3) | 
					 | 
					 | 
					    # 前向传播 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					     | 
					 | 
					 | 
					    try: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 更新前向传播调用 | 
					 | 
					 | 
					        sdf = model(**test_data) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    sdf = model( | 
					 | 
					 | 
					        logger.info(f"\nOutput SDF shape: {sdf.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        edge_ncs=edge_ncs, | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        edge_pos=edge_pos, | 
					 | 
					 | 
					        # 计算模型参数量 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        edge_mask=edge_mask, | 
					 | 
					 | 
					        total_params = sum(p.numel() for p in model.parameters()) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        surf_ncs=surf_ncs, | 
					 | 
					 | 
					        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        surf_pos=surf_pos, | 
					 | 
					 | 
					        logger.info(f"\nModel statistics:") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        vertex_pos=vertex_pos, | 
					 | 
					 | 
					        logger.info(f"  Total parameters: {total_params:,}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        query_points=query_points | 
					 | 
					 | 
					        logger.info(f"  Trainable parameters: {trainable_params:,}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    ) | 
					 | 
					 | 
					         | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					     | 
					 | 
					 | 
					    except Exception as e: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 更新打印信息 | 
					 | 
					 | 
					        logger.error(f"Error during forward pass: {str(e)}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print("\nInput shapes:") | 
					 | 
					 | 
					        raise | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"edge_ncs: {edge_ncs.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"edge_pos: {edge_pos.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"edge_mask: {edge_mask.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"surf_ncs: {surf_ncs.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"surf_pos: {surf_pos.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"vertex_pos: {vertex_pos.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"query_points: {query_points.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"\nOutput SDF shape: {sdf.shape}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					if __name__ == "__main__": | 
					 | 
					 | 
					if __name__ == "__main__": | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    main() | 
					 | 
					 | 
					    main() |