| 
						
						
							
								
							
						
						
					 | 
					@ -5,6 +5,7 @@ import torch.nn.functional as F | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from dataclasses import dataclass | 
					 | 
					 | 
					from dataclasses import dataclass | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from typing import Dict, Optional, Tuple, Union | 
					 | 
					 | 
					from typing import Dict, Optional, Tuple, Union | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from brep2sdf.config.default_config import get_default_config | 
					 | 
					 | 
					from brep2sdf.config.default_config import get_default_config | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					from brep2sdf.utils.logger import logger | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					class ResConvBlock(nn.Module): | 
					 | 
					 | 
					class ResConvBlock(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    """残差卷积块""" | 
					 | 
					 | 
					    """残差卷积块""" | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -119,16 +120,24 @@ class Encoder1D(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					class BRepFeatureEmbedder(nn.Module): | 
					 | 
					 | 
					class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    """B-rep特征嵌入器""" | 
					 | 
					 | 
					    """B-rep特征嵌入器""" | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def __init__(self, use_cf: bool = True): | 
					 | 
					 | 
					    def __init__(self, config=None): | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        super().__init__() | 
					 | 
					 | 
					        super().__init__() | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 获取配置 | 
					 | 
					 | 
					        if config is None: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.config = get_default_config() | 
					 | 
					 | 
					            self.config = get_default_config() | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.embed_dim = 768 | 
					 | 
					 | 
					        else: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.use_cf = use_cf | 
					 | 
					 | 
					            self.config = config | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.num_surf_points = self.config.model.num_surf_points | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.num_edge_points = self.config.model.num_edge_points | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.embed_dim = self.config.model.embed_dim | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.use_cf = self.config.model.use_cf | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 使用配置中的采样点数 | 
					 | 
					 | 
					        # 打印初始化信息 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.num_surf_points = self.config.model.num_surf_points  # 16 | 
					 | 
					 | 
					        logger.info(f"BRepFeatureEmbedder config:") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.num_edge_points = self.config.model.num_edge_points  # 4 | 
					 | 
					 | 
					        logger.info(f"  num_surf_points: {self.num_surf_points}") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info(f"  num_edge_points: {self.num_edge_points}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info(f"  embed_dim: {self.embed_dim}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        logger.info(f"  use_cf: {self.use_cf}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # Transformer编码器层 | 
					 | 
					 | 
					        # Transformer编码器层 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        layer = nn.TransformerEncoderLayer( | 
					 | 
					 | 
					        layer = nn.TransformerEncoderLayer( | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -182,59 +191,93 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            nn.Linear(self.embed_dim, self.embed_dim), | 
					 | 
					 | 
					            nn.Linear(self.embed_dim, self.embed_dim), | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None): | 
					 | 
					 | 
					    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """B-rep特征嵌入器的前向传播 | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					        Args: | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            surf_z: 表面点云 [B, N, num_surf_points, 3] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            edge_z: 边点云 [B, M, num_edge_points, 3] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            surf_p: 表面点 [B, N, 6] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            edge_p: 边点 [B, M, 6] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            vert_p: 顶点点 [B, K, 6] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            mask: 注意力掩码 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 获取批次大小和其他维度 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        B = surf_z.size(0) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        N = surf_z.size(1) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        M = edge_z.size(1) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        K = vert_p.size(1) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 重塑点云数据用于1D编码器 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        surf_z = surf_z.reshape(B*N, self.num_surf_points, 3).transpose(1, 2)  # [B*N, 3, num_surf_points] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        edge_z = edge_z.reshape(B*M, self.num_edge_points, 3).transpose(1, 2)  # [B*M, 3, num_edge_points] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 特征嵌入 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        surf_embeds = self.surfz_embed(surf_z)  # [B*N, embed_dim, num_points] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        edge_embeds = self.edgez_embed(edge_z)  # [B*M, embed_dim, num_points] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 全局池化得到每个面/边的特征 | 
					 | 
					 | 
					        Args: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        surf_embeds = surf_embeds.mean(dim=-1)  # [B*N, embed_dim] | 
					 | 
					 | 
					            edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        edge_embeds = edge_embeds.mean(dim=-1)  # [B*M, embed_dim] | 
					 | 
					 | 
					            edge_pos: 边位置 [B, max_face, max_edge, 6] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					            edge_mask: 边掩码 [B, max_face, max_edge] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 重塑回批次维度 | 
					 | 
					 | 
					            surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        surf_embeds = surf_embeds.reshape(B, N, -1)  # [B, N, embed_dim] | 
					 | 
					 | 
					            surf_pos: 面位置 [B, max_face, 6] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        edge_embeds = edge_embeds.reshape(B, M, -1)  # [B, M, embed_dim] | 
					 | 
					 | 
					            vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 点嵌入 | 
					 | 
					 | 
					        Returns: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        surf_p_embeds = self.surfp_embed(surf_p)  # [B, N, embed_dim] | 
					 | 
					 | 
					            embeds: [B, max_face*(max_edge+1), embed_dim] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        edge_p_embeds = self.edgep_embed(edge_p)  # [B, M, embed_dim] | 
					 | 
					 | 
					        """ | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        vert_p_embeds = self.vertp_embed(vert_p)  # [B, K, embed_dim] | 
					 | 
					 | 
					        B = self.config.train.batch_size | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        max_face = self.config.data.max_face | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        max_edge = self.config.data.max_edge | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 组合所有嵌入 | 
					 | 
					 | 
					        try: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        if self.use_cf: | 
					 | 
					 | 
					            # 1. 处理边特征 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            embeds = torch.cat([ | 
					 | 
					 | 
					            # 重塑边点云以适应1D编码器 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                surf_embeds + surf_p_embeds, | 
					 | 
					 | 
					            edge_ncs = edge_ncs.reshape(B*max_face*max_edge, -1, 3).transpose(1, 2)  # [B*max_face*max_edge, 3, num_edge_points] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                edge_embeds + edge_p_embeds, | 
					 | 
					 | 
					            edge_embeds = self.edgez_embed(edge_ncs)  # [B*max_face*max_edge, embed_dim, num_edge_points] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                vert_p_embeds | 
					 | 
					 | 
					            edge_embeds = edge_embeds.mean(dim=-1)    # [B*max_face*max_edge, embed_dim] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            ], dim=1)  # [B, N+M+K, embed_dim] | 
					 | 
					 | 
					            edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1)  # [B, max_face, max_edge, embed_dim] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        else: | 
					 | 
					 | 
					             | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            embeds = torch.cat([ | 
					 | 
					 | 
					            # 2. 处理面特征 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                surf_p_embeds, | 
					 | 
					 | 
					            surf_ncs = surf_ncs.reshape(B*max_face, -1, 3).transpose(1, 2)  # [B*max_face, 3, num_surf_points] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                edge_p_embeds, | 
					 | 
					 | 
					            surf_embeds = self.surfz_embed(surf_ncs)  # [B*max_face, embed_dim, num_surf_points] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					                vert_p_embeds | 
					 | 
					 | 
					            surf_embeds = surf_embeds.mean(dim=-1)    # [B*max_face, embed_dim] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            ], dim=1)  # [B, N+M+K, embed_dim] | 
					 | 
					 | 
					            surf_embeds = surf_embeds.reshape(B, max_face, -1)  # [B, max_face, embed_dim] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 3. 处理位置编码 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 边位置编码 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            edge_pos = edge_pos.reshape(B*max_face*max_edge, -1)  # [B*max_face*max_edge, 6] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            edge_p_embeds = self.edgep_embed(edge_pos)  # [B*max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            edge_p_embeds = edge_p_embeds.reshape(B, max_face, max_edge, -1)  # [B, max_face, max_edge, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 面位置编码 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            surf_p_embeds = self.surfp_embed(surf_pos)  # [B, max_face, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					             | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        output = self.transformer(embeds, src_key_padding_mask=mask) | 
					 | 
					 | 
					            # 4. 组合特征 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return output | 
					 | 
					 | 
					            if self.use_cf: | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 边特征 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                edge_features = edge_embeds + edge_p_embeds  # [B, max_face, max_edge, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                edge_features = edge_features.reshape(B, max_face*max_edge, -1)  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 面特征 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                surf_features = surf_embeds + surf_p_embeds  # [B, max_face, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 组合所有特征 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                embeds = torch.cat([ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    edge_features,  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    surf_features   # [B, max_face, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                ], dim=1)  # [B, max_face*(max_edge+1), embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 只使用位置编码 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                edge_features = edge_p_embeds.reshape(B, max_face*max_edge, -1)  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                embeds = torch.cat([ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    edge_features,  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                    surf_p_embeds   # [B, max_face, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                ], dim=1)  # [B, max_face*(max_edge+1), embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 5. 处理掩码 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if edge_mask is not None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                # 扩展掩码以匹配特征维度 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                edge_mask = edge_mask.reshape(B, -1)  # [B, max_face*max_edge] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                surf_mask = torch.ones(B, max_face, device=edge_mask.device, dtype=torch.bool)  # [B, max_face] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, max_face*(max_edge+1)] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                mask = None | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 6. Transformer处理 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            return output.transpose(0, 1)  # 确保输出维度为 [B, seq_len, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"Error in BRepFeatureEmbedder forward pass:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"  Error message: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"  Input shapes:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    edge_ncs: {edge_ncs.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    edge_pos: {edge_pos.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    edge_mask: {edge_mask.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    surf_ncs: {surf_ncs.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    surf_pos: {surf_pos.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    vertex_pos: {vertex_pos.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					class SDFTransformer(nn.Module): | 
					 | 
					 | 
					class SDFTransformer(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    """SDF Transformer编码器""" | 
					 | 
					 | 
					    """SDF Transformer编码器""" | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -296,7 +339,7 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        ) | 
					 | 
					 | 
					        ) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 2. B-rep特征编码器 | 
					 | 
					 | 
					        # 2. B-rep特征编码器 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf) | 
					 | 
					 | 
					        self.brep_embedder = BRepFeatureEmbedder() | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 3. 特征融合Transformer | 
					 | 
					 | 
					        # 3. 特征融合Transformer | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.transformer = SDFTransformer( | 
					 | 
					 | 
					        self.transformer = SDFTransformer( | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -307,45 +350,68 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 4. SDF预测头 | 
					 | 
					 | 
					        # 4. SDF预测头 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.sdf_head = SDFHead(embed_dim=embed_dim*2) | 
					 | 
					 | 
					        self.sdf_head = SDFHead(embed_dim=embed_dim*2) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None): | 
					 | 
					 | 
					    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """B-rep到SDF的前向传播 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        Args: | 
					 | 
					 | 
					        Args: | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            surf_z: 表面特征 [B, N, 48] | 
					 | 
					 | 
					            edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            edge_z: 边特征 [B, M, 12] | 
					 | 
					 | 
					            edge_pos: 边位置 [B, max_face, max_edge, 6] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            surf_p: 表面点 [B, N, 6] | 
					 | 
					 | 
					            edge_mask: 边掩码 [B, max_face, max_edge] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            edge_p: 边点 [B, M, 6] | 
					 | 
					 | 
					            surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            vert_p: 顶点点 [B, K, 6] | 
					 | 
					 | 
					            surf_pos: 面位置 [B, max_face, 6] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            query_points: 查询点 [B, Q, 3] | 
					 | 
					 | 
					            vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            mask: 注意力掩码 | 
					 | 
					 | 
					            query_points: 查询点 [B, num_queries, 3] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            data_class: (可选) 类别标签 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        Returns: | 
					 | 
					 | 
					        Returns: | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            sdf: [B, Q, 1] | 
					 | 
					 | 
					            sdf: 预测的SDF值 [B, num_queries, 1] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        B, Q, _ = query_points.shape | 
					 | 
					 | 
					        B, Q = query_points.shape[:2]  # B: batch_size, Q: num_queries | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 1. B-rep特征嵌入 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        brep_features = self.feature_embedder( | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            surf_z, edge_z, surf_p, edge_p, vert_p, mask | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        )  # [B, N+M+K, embed_dim] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 2. 查询点编码 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        query_features = self.query_encoder(query_points)  # [B, Q, embed_dim] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 3. 提取全局特征 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        global_features = brep_features.mean(dim=1)  # [B, embed_dim] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 4. 为每个查询点准备特征 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1)  # [B, Q, embed_dim] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        # 5. 连接查询点特征和全局特征 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        combined_features = torch.cat([ | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            expanded_features,  # [B, Q, embed_dim] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            query_features     # [B, Q, embed_dim] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        ], dim=-1)  # [B, Q, embed_dim*2] | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					         | 
					 | 
					 | 
					         | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					        # 6. SDF预测 | 
					 | 
					 | 
					        try: | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        sdf = self.sdf_head(combined_features)  # [B, Q, 1] | 
					 | 
					 | 
					            # 1. B-rep特征编码 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					         | 
					 | 
					 | 
					            brep_features = self.brep_embedder( | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        return sdf | 
					 | 
					 | 
					                edge_ncs=edge_ncs,         # [B, max_face, max_edge, num_edge_points, 3] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                edge_pos=edge_pos,         # [B, max_face, max_edge, 6] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                edge_mask=edge_mask,       # [B, max_face, max_edge] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                surf_ncs=surf_ncs,         # [B, max_face, num_surf_points, 3] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                surf_pos=surf_pos,         # [B, max_face, 6] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                vertex_pos=vertex_pos,     # [B, max_face, max_edge, 2, 3] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                data_class=data_class | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            )  # [B, max_face*(max_edge+1), embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 2. 查询点编码 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            query_features = self.query_encoder(query_points)  # [B, Q, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 3. 提取全局特征 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            global_features = brep_features.mean(dim=1)  # [B, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 4. 为每个查询点准备特征 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1)  # [B, Q, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 5. 连接查询点特征和全局特征 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            combined_features = torch.cat([ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                expanded_features,  # [B, Q, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                query_features     # [B, Q, embed_dim] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            ], dim=-1)  # [B, Q, embed_dim*2] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            # 6. SDF预测 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            sdf = self.sdf_head(combined_features)  # [B, Q, 1] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            return sdf | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					             | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        except Exception as e: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"Error in BRepToSDF forward pass:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"  Error message: {str(e)}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"  Input shapes:") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    edge_ncs: {edge_ncs.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    edge_pos: {edge_pos.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    edge_mask: {edge_mask.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    surf_ncs: {surf_ncs.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    surf_pos: {surf_pos.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    vertex_pos: {vertex_pos.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            logger.error(f"    query_points: {query_points.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            raise | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
					 | 
					 | 
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    """SDF损失函数""" | 
					 | 
					 | 
					    """SDF损失函数""" | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -384,47 +450,73 @@ def main(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    num_verts = 8                         # 顶点数保持固定 | 
					 | 
					 | 
					    num_verts = 8                         # 顶点数保持固定 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    num_queries = 1000                    # 查询点数保持固定 | 
					 | 
					 | 
					    num_queries = 1000                    # 查询点数保持固定 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 生成示例数据 | 
					 | 
					 | 
					    # 更新测试数据维度 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    surf_z = torch.randn( | 
					 | 
					 | 
					    edge_ncs = torch.randn( | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        batch_size,  | 
					 | 
					 | 
					        batch_size, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        num_surfs,  | 
					 | 
					 | 
					        num_surfs,      # max_face | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        config.model.num_surf_points,  # 16 | 
					 | 
					 | 
					        num_edges,      # max_edge | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        config.model.num_edge_points, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        3 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    )  # [B, max_face, max_edge, num_edge_points, 3] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    edge_pos = torch.randn( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        batch_size, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        num_surfs, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        num_edges, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        6 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    )  # [B, max_face, max_edge, 6] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    edge_mask = torch.ones( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        batch_size, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        num_surfs, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        num_edges, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        dtype=torch.bool | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    )  # [B, max_face, max_edge] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    surf_ncs = torch.randn( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        batch_size, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        num_surfs, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        config.model.num_surf_points, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        3 | 
					 | 
					 | 
					        3 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    )  # [B, N, num_surf_points, 3] | 
					 | 
					 | 
					    )  # [B, max_face, num_surf_points, 3] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    surf_pos = torch.randn( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        batch_size, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        num_surfs, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        6 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    )  # [B, max_face, 6] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    edge_z = torch.randn( | 
					 | 
					 | 
					    vertex_pos = torch.randn( | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        batch_size,  | 
					 | 
					 | 
					        batch_size, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        num_edges,  | 
					 | 
					 | 
					        num_surfs, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					        config.model.num_edge_points,  # 4 | 
					 | 
					 | 
					        num_edges, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        2, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        3 | 
					 | 
					 | 
					        3 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    )  # [B, M, num_edge_points, 3] | 
					 | 
					 | 
					    )  # [B, max_face, max_edge, 2, 3] | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    # 其他输入 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    surf_p = torch.randn(batch_size, num_surfs, 6) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    edge_p = torch.randn(batch_size, num_edges, 6) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    vert_p = torch.randn(batch_size, num_verts, 6) | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    query_points = torch.randn(batch_size, num_queries, 3) | 
					 | 
					 | 
					    query_points = torch.randn(batch_size, num_queries, 3) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 前向传播 | 
					 | 
					 | 
					    # 更新前向传播调用 | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) | 
					 | 
					 | 
					    sdf = model( | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					     | 
					 | 
					 | 
					        edge_ncs=edge_ncs, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    # 打印形状信息和配置信息 | 
					 | 
					 | 
					        edge_pos=edge_pos, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print("\nConfiguration:") | 
					 | 
					 | 
					        edge_mask=edge_mask, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"Batch Size: {batch_size}") | 
					 | 
					 | 
					        surf_ncs=surf_ncs, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"Embed Dimension: {config.model.embed_dim}") | 
					 | 
					 | 
					        surf_pos=surf_pos, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"Surface Points: {config.model.num_surf_points}") | 
					 | 
					 | 
					        vertex_pos=vertex_pos, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"Edge Points: {config.model.num_edge_points}") | 
					 | 
					 | 
					        query_points=query_points | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"Max Faces: {config.data.max_face}") | 
					 | 
					 | 
					    ) | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					    print(f"Max Edges: {config.data.max_edge}") | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					     | 
					 | 
					 | 
					     | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    # 更新打印信息 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    print("\nInput shapes:") | 
					 | 
					 | 
					    print("\nInput shapes:") | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"surf_z: {surf_z.shape}")      # [32, 64, 16, 3] | 
					 | 
					 | 
					    print(f"edge_ncs: {edge_ncs.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"edge_z: {edge_z.shape}")      # [32, 64, 4, 3] | 
					 | 
					 | 
					    print(f"edge_pos: {edge_pos.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"surf_p: {surf_p.shape}")      # [32, 64, 6] | 
					 | 
					 | 
					    print(f"edge_mask: {edge_mask.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"edge_p: {edge_p.shape}")      # [32, 64, 6] | 
					 | 
					 | 
					    print(f"surf_ncs: {surf_ncs.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"vert_p: {vert_p.shape}")      # [32, 8, 6] | 
					 | 
					 | 
					    print(f"surf_pos: {surf_pos.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"query_points: {query_points.shape}")  # [32, 1000, 3] | 
					 | 
					 | 
					    print(f"vertex_pos: {vertex_pos.shape}") | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					    print(f"\nOutput SDF shape: {sdf.shape}")     # [32, 1000, 1] | 
					 | 
					 | 
					    print(f"query_points: {query_points.shape}") | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    print(f"\nOutput SDF shape: {sdf.shape}") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					if __name__ == "__main__": | 
					 | 
					 | 
					if __name__ == "__main__": | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    main() | 
					 | 
					 | 
					    main() |