| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -76,7 +76,7 @@ class UNetMidBlock1D(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            x = attn(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return x | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class Encoder1D(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class _Encoder1D(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        in_channels: int = 3, | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -115,46 +115,39 @@ class Encoder1D(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = self.conv_out(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return x | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class Encoder1D_(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """一维编码器""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class Encoder1D(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        in_channels: int = 3, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        out_channels: int = 256, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        block_out_channels: Tuple[int] = (64, 128, 256), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        layers_per_block: int = 2, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        num_points: int = 2,          # 输入点数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        in_channels: int = 3,          # xyz坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        out_dim: int = 6,         # 展平后维度 (2*3) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.num_points = num_points | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.in_channels = in_channels | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.out_dim = out_dim | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.down_blocks = nn.ModuleList([]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        in_ch = block_out_channels[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for out_ch in block_out_channels: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            block = [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for _ in range(layers_per_block): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                block.append(ResConvBlock(in_ch, out_ch, out_ch)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                in_ch = out_ch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if out_ch != block_out_channels[-1]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                block.append(nn.AvgPool1d(2)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.down_blocks.append(nn.Sequential(*block)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        assert num_points * in_channels == out_dim, \ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            f"Flatten dimension {out_dim} must equal num_points({num_points}) * in_channels({in_channels})" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.mid_block = UNetMidBlock1D( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            in_channels=block_out_channels[-1], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mid_channels=block_out_channels[-1], | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, x): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Args: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            x: [B, F, E, num_points, channels]  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Returns: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            x: [B, F, E, out_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        orig_shape = x.shape[:-2]  # [B, F, E] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.conv_out = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.GroupNorm(32, block_out_channels[-1]), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.SiLU(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 重塑以处理所有批次 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = x.reshape(-1, self.num_points, self.in_channels)  # [B*F*E, num_points, channels] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 展平点和通道维度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = x.reshape(x.size(0), -1)  # [B*F*E, num_points*channels] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 恢复批次维度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = x.reshape(*orig_shape, self.out_dim)  # [B, F, E, out_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, x): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = self.conv_in(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for block in self.down_blocks: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            x = block(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = self.mid_block(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = self.conv_out(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return x | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -178,55 +171,20 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"  embed_dim: {self.embed_dim}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"  use_cf: {self.use_cf}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # Transformer编码器层 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        layer = nn.TransformerEncoderLayer( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            d_model=self.embed_dim, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nhead=8,  # 从12减少到8,使每个head的维度更大 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            norm_first=True,  # 改为True,先进行归一化 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dim_feedforward=self.embed_dim * 4,  # 增大FFN维度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            dropout=0.1, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            activation=F.gelu  # 使用GELU激活函数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 添加初始化方法 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        def _init_weights(module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if isinstance(module, nn.Linear): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.nn.init.xavier_uniform_(module.weight, gain=1/math.sqrt(2)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if module.bias is not None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    torch.nn.init.zeros_(module.bias) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            elif isinstance(module, nn.LayerNorm): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.nn.init.ones_(module.weight) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                torch.nn.init.zeros_(module.bias) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.transformer = nn.TransformerEncoder( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            layer, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_layers=6,  # 从12减少到6层 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            norm=nn.LayerNorm(self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            enable_nested_tensor=False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 应用初始化 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.transformer.apply(_init_weights) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 添加位置编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02)  # 最大序列长度设为1000 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_out_dim = 3*self.num_surf_points # 面特征展平后维度,8*3=24 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_out_dim = 3*self.num_edge_points # 边特征展平后维度,2*3=6 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.surfz_encoder = Encoder1D(num_points=self.num_surf_points,in_channels=3,out_dim=surf_out_dim) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.edgez_encoder = Encoder1D(num_points=self.num_edge_points,in_channels=3,out_dim=edge_out_dim) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 修改为处理[num_points, 3]形状的输入 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.surfz_embed = Encoder1D( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            in_channels=3, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            out_channels=self.embed_dim, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            block_out_channels=(64, 128, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            layers_per_block=2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.edgez_embed = Encoder1D( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            in_channels=3, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            out_channels=self.embed_dim, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            block_out_channels=(64, 128, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            layers_per_block=2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.surfz_embed = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.surfz_encoder,              # [B, 16, 3] -> [B, 48] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(surf_out_dim, self.embed_dim), # [B, 48] -> [B, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.LayerNorm(self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.SiLU(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(self.embed_dim, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 其他嵌入层保持不变 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.surfp_embed = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(6, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.LayerNorm(self.embed_dim), | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -234,6 +192,16 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(self.embed_dim, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.edgez_embed = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.edgez_encoder,              | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(edge_out_dim, self.embed_dim),  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.LayerNorm(self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.SiLU(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(self.embed_dim, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        )  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.edgep_embed = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(6, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.LayerNorm(self.embed_dim), | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -243,86 +211,87 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 修改vertp_embed的结构 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.vertp_embed = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(3, self.embed_dim // 2), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.LayerNorm(self.embed_dim // 2), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(self.embed_dim // 2, self.embed_dim) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(3, self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.LayerNorm(self.embed_dim), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.SiLU(), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(self.embed_dim, self.embed_dim) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 添加一个额外的投影层 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 添加 transformer 初始化 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.transformer = nn.TransformerEncoder( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            encoder_layer=nn.TransformerEncoderLayer( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                d_model=self.embed_dim, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                nhead=8,  # 注意力头数,通常是embed_dim的因子 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                dim_feedforward=4*self.embed_dim,  # 前馈网络维度,通常是embed_dim的4倍 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                dropout=0.1, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                activation='gelu', | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                batch_first=False  # 因为我们用了transpose(0,1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            num_layers=6  # transformer层数 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        B, F, E, _, _ = edge_ncs.shape | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 处理顶点特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_features = vertex_pos.view(B*F*E*2, -1)  # 展平顶点坐标 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = self.vertp_embed(vertex_features) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = self.vertex_proj(vertex_embed)  # 添加投影 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = vertex_embed.view(B, F, E, 2, -1)  # 恢复形状 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 确保顶点特征参与后续计算 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_features = torch.cat([ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.edgep_embed(edge_pos.view(B*F*E, -1)).view(B, F, E, -1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            vertex_embed.mean(dim=3)  # 将顶点特征平均池化 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ], dim=-1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 1. 处理边特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 重塑边点云以适应1D编码器 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_ncs = edge_ncs.reshape(B*F*E, -1, 3).transpose(1, 2)  # [B*max_face*max_edge, 3, num_edge_points] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_embeds = self.edgez_embed(edge_ncs)  # [B*max_face*max_edge, embed_dim, num_edge_points] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_embeds = edge_embeds.mean(dim=-1)    # [B*max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_embeds = edge_embeds.reshape(B, F, E, -1)  # [B, max_face, max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 2. 处理面特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_ncs = surf_ncs.reshape(B*F, -1, 3).transpose(1, 2)  # [B*max_face, 3, num_surf_points] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_embeds = self.surfz_embed(surf_ncs)  # [B*max_face, embed_dim, num_surf_points] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_embeds = surf_embeds.mean(dim=-1)    # [B*max_face, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_embeds = surf_embeds.reshape(B, F, -1)  # [B, max_face, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 3. 处理位置编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 边位置编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_pos = edge_pos.reshape(B*F*E, -1)  # [B*max_face*max_edge, 6] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_p_embeds = self.edgep_embed(edge_pos)  # [B*max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_p_embeds = edge_p_embeds.reshape(B, F, E, -1)  # [B, max_face, max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 面位置编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_p_embeds = self.surfp_embed(surf_pos)  # [B, max_face, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Args: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_ncs: [B, F, E, num_edge_points, 3] - 边点云 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_pos: [B, F, E, 6] - 边位置编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_mask: [B, F, E] - 边掩码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_ncs: [B, F, num_surf_points, 3] - 面点云 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_pos: [B, F, 6] - 面位置编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            vertex_pos: [B, F, E, 2, 3] - 顶点位置 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        B, F, E = edge_pos.shape[:3] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 1. 处理顶点特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = self.vertp_embed(vertex_pos[..., :3])  # [B, F, E, 2, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = self.vertex_proj(vertex_embed)  # [B, F, E, 2, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = vertex_embed.mean(dim=3)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 2. 处理边特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info(f"edge_ncs shape: {edge_ncs.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_embeds = self.edgez_embed(edge_ncs)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_p_embeds = self.edgep_embed(edge_pos)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 3. 处理面特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_embeds = self.surfz_embed(surf_ncs)  # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_p_embeds = self.surfp_embed(surf_pos)  # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 4. 组合特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if self.use_cf: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 边特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_embeds + edge_p_embeds  # [B, max_face, max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_features.reshape(B, F*E, -1)  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 组合边特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_embeds + edge_p_embeds + vertex_embed  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_features.reshape(B, F*E, -1)  # [B, F*E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 面特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_features = surf_embeds + surf_p_embeds  # [B, max_face, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 组合面特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_features = surf_embeds + surf_p_embeds  # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 组合所有特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 拼接所有特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            embeds = torch.cat([ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_features,  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_features   # [B, max_face, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ], dim=1)  # [B, max_face*(max_edge+1), embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_features,  # [B, F*E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_features   # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ], dim=1)  # [B, F*E+F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 只使用位置编码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_p_embeds.reshape(B, F*E, -1)  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_p_embeds.reshape(B, F*E, -1)  # [B, F*E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            embeds = torch.cat([ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_features,  # [B, max_face*max_edge, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_p_embeds   # [B, max_face, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ], dim=1)  # [B, max_face*(max_edge+1), embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                edge_features,  # [B, F*E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                surf_p_embeds   # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            ], dim=1)  # [B, F*E+F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 5. 处理掩码 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if edge_mask is not None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 扩展掩码以匹配特征维度 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_mask = edge_mask.reshape(B, -1)  # [B, max_face*max_edge] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool)  # [B, max_face] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, max_face*(max_edge+1)] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_mask = edge_mask.reshape(B, -1)  # [B, F*E] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool)  # [B, F] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, F*E+F] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            mask = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 6. Transformer处理 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return output.transpose(0, 1)  # 确保输出维度为 [B, seq_len, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return output.transpose(0, 1)  # [B, F*E+F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class SDFTransformer(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """SDF Transformer编码器""" | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |