@@ -180,19 +180,37 @@ class BRepFeatureEmbedder(nn.Module):

        # Transformer encoder layer
        layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=12,
            norm_first=False,
            dim_feedforward=1024,
            dropout=0.1
            d_model=self.embed_dim,
            nhead=8,  # reduced from 12 to 8 so each head gets a larger per-head dimension
            norm_first=True,  # changed to True: normalize before the sublayers (pre-LN)
            dim_feedforward=self.embed_dim * 4,  # enlarged FFN dimension
            dropout=0.1,
            activation=F.gelu  # use the GELU activation function
        )

        # Add an initialization helper
        def _init_weights(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight, gain=1/math.sqrt(2))
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                torch.nn.init.ones_(module.weight)
                torch.nn.init.zeros_(module.bias)

        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers=12,
            layer,
            num_layers=6,  # reduced from 12 to 6 layers
            norm=nn.LayerNorm(self.embed_dim),
            enable_nested_tensor=False
        )

        # Apply the initialization
        self.transformer.apply(_init_weights)

        # Add a learned positional encoding
        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02)  # maximum sequence length set to 1000

        # Modified to handle inputs of shape [num_points, 3]
        self.surfz_embed = Encoder1D(
            in_channels=3,
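
As a quick sanity check on the new layer configuration, here is a minimal, self-contained sketch of the pre-LN encoder this hunk builds (GELU activation, 6 layers, scaled-down Xavier init, learned positional table). The concrete embed_dim of 192 is an assumption borrowed from the SDFTransformer default further down; the only hard requirement is that it stays divisible by nhead=8.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim = 192  # assumed; per-head dimension is then 192 / 8 = 24

layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=8,
    dim_feedforward=embed_dim * 4,
    dropout=0.1,
    activation=F.gelu,
    norm_first=True,              # pre-LN: LayerNorm runs before attention/FFN
)
encoder = nn.TransformerEncoder(
    layer,
    num_layers=6,
    norm=nn.LayerNorm(embed_dim),
    enable_nested_tensor=False,
)

def _init_weights(m):
    # scaled-down Xavier init, matching the patch above
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
        if m.bias is not None:
            nn.init.zeros_(m.bias)

encoder.apply(_init_weights)

x = torch.randn(200, 2, embed_dim)   # [seq_len, batch, embed_dim]; batch_first is left at its default
out = encoder(x)                     # -> [200, 2, embed_dim]

With norm_first=True every sublayer sees LayerNorm'd inputs, which is generally easier to train from scratch than the post-LN default.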
				
			
			
		
	
	
		
			
				
@@ -223,117 +241,162 @@ class BRepFeatureEmbedder(nn.Module):
            nn.Linear(self.embed_dim, self.embed_dim),
        )

        # Modified structure of vertp_embed
        self.vertp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.Linear(3, self.embed_dim // 2),
            nn.LayerNorm(self.embed_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embed_dim // 2, self.embed_dim)
        )

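
To make the intended data flow of the 3-input variant of this MLP concrete, here is a small sketch of how per-vertex coordinates would be embedded and then pooled over the two endpoints of each edge. The [B, max_face, max_edge, 2, 3] vertex layout and embed_dim are taken from the docstring and code below; the toy sizes and everything else are assumptions for illustration only.

import torch
import torch.nn as nn

embed_dim = 192                        # assumed
B, num_faces, num_edges = 2, 4, 3      # toy sizes

vertp_embed = nn.Sequential(           # 3 coordinates -> embed_dim, as in the patch
    nn.Linear(3, embed_dim // 2),
    nn.LayerNorm(embed_dim // 2),
    nn.ReLU(),
    nn.Linear(embed_dim // 2, embed_dim),
)

vertex_pos = torch.randn(B, num_faces, num_edges, 2, 3)           # two endpoints per edge
flat = vertex_pos.view(B * num_faces * num_edges * 2, 3)          # one row per endpoint
emb = vertp_embed(flat).view(B, num_faces, num_edges, 2, embed_dim)
edge_vertex_feat = emb.mean(dim=3)                                 # average the endpoints -> [B, num_faces, num_edges, embed_dim]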
				
    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None):
        """Forward pass of the B-rep feature embedder

        Args:
            edge_ncs: normalized edge features [B, max_face, max_edge, num_edge_points, 3]
            edge_pos: edge positions [B, max_face, max_edge, 6]
            edge_mask: edge mask [B, max_face, max_edge]
            surf_ncs: normalized face features [B, max_face, num_surf_points, 3]
            surf_pos: face positions [B, max_face, 6]
            vertex_pos: vertex positions [B, max_face, max_edge, 2, 3]
        # Add an extra projection layer
        self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)

        Returns:
            embeds: [B, max_face*(max_edge+1), embed_dim]
        """
        B = self.config.train.batch_size
        max_face = self.config.data.max_face
        max_edge = self.config.data.max_edge
    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
        B, F, E, _, _ = edge_ncs.shape

        try:
            # 1. Process edge features
            # Reshape the edge point clouds for the 1D encoder
            edge_ncs = edge_ncs.reshape(B*max_face*max_edge, -1, 3).transpose(1, 2)  # [B*max_face*max_edge, 3, num_edge_points]
            edge_embeds = self.edgez_embed(edge_ncs)  # [B*max_face*max_edge, embed_dim, num_edge_points]
            edge_embeds = edge_embeds.mean(dim=-1)    # [B*max_face*max_edge, embed_dim]
            edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1)  # [B, max_face, max_edge, embed_dim]

            # 2. Process face features
            surf_ncs = surf_ncs.reshape(B*max_face, -1, 3).transpose(1, 2)  # [B*max_face, 3, num_surf_points]
            surf_embeds = self.surfz_embed(surf_ncs)  # [B*max_face, embed_dim, num_surf_points]
            surf_embeds = surf_embeds.mean(dim=-1)    # [B*max_face, embed_dim]
            surf_embeds = surf_embeds.reshape(B, max_face, -1)  # [B, max_face, embed_dim]

            # 3. Process positional encodings
            # Edge positional encoding
            edge_pos = edge_pos.reshape(B*max_face*max_edge, -1)  # [B*max_face*max_edge, 6]
            edge_p_embeds = self.edgep_embed(edge_pos)  # [B*max_face*max_edge, embed_dim]
            edge_p_embeds = edge_p_embeds.reshape(B, max_face, max_edge, -1)  # [B, max_face, max_edge, embed_dim]

            # Face positional encoding
            surf_p_embeds = self.surfp_embed(surf_pos)  # [B, max_face, embed_dim]

            # 4. Combine features
            if self.use_cf:
                # Edge features
                edge_features = edge_embeds + edge_p_embeds  # [B, max_face, max_edge, embed_dim]
                edge_features = edge_features.reshape(B, max_face*max_edge, -1)  # [B, max_face*max_edge, embed_dim]

                # Face features
                surf_features = surf_embeds + surf_p_embeds  # [B, max_face, embed_dim]

                # Combine all features
                embeds = torch.cat([
                    edge_features,  # [B, max_face*max_edge, embed_dim]
                    surf_features   # [B, max_face, embed_dim]
                ], dim=1)  # [B, max_face*(max_edge+1), embed_dim]
            else:
                # Use only the positional encodings
                edge_features = edge_p_embeds.reshape(B, max_face*max_edge, -1)  # [B, max_face*max_edge, embed_dim]
                embeds = torch.cat([
                    edge_features,  # [B, max_face*max_edge, embed_dim]
                    surf_p_embeds   # [B, max_face, embed_dim]
                ], dim=1)  # [B, max_face*(max_edge+1), embed_dim]

            # 5. Handle the mask
            if edge_mask is not None:
                # Expand the mask to match the feature layout
                edge_mask = edge_mask.reshape(B, -1)  # [B, max_face*max_edge]
                surf_mask = torch.ones(B, max_face, device=edge_mask.device, dtype=torch.bool)  # [B, max_face]
                mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, max_face*(max_edge+1)]
            else:
                mask = None
        # Process vertex features
        vertex_features = vertex_pos.view(B*F*E*2, -1)  # flatten the vertex coordinates
        vertex_embed = self.vertp_embed(vertex_features)
        vertex_embed = self.vertex_proj(vertex_embed)  # apply the extra projection
        vertex_embed = vertex_embed.view(B, F, E, 2, -1)  # restore the shape

        # Make sure the vertex features take part in the downstream computation
        edge_features = torch.cat([
            self.edgep_embed(edge_pos.view(B*F*E, -1)).view(B, F, E, -1),
            vertex_embed.mean(dim=3)  # average-pool the vertex features
        ], dim=-1)

        # 1. Process edge features
        # Reshape the edge point clouds for the 1D encoder
        edge_ncs = edge_ncs.reshape(B*F*E, -1, 3).transpose(1, 2)  # [B*max_face*max_edge, 3, num_edge_points]
        edge_embeds = self.edgez_embed(edge_ncs)  # [B*max_face*max_edge, embed_dim, num_edge_points]
        edge_embeds = edge_embeds.mean(dim=-1)    # [B*max_face*max_edge, embed_dim]
        edge_embeds = edge_embeds.reshape(B, F, E, -1)  # [B, max_face, max_edge, embed_dim]

        # 2. Process face features
        surf_ncs = surf_ncs.reshape(B*F, -1, 3).transpose(1, 2)  # [B*max_face, 3, num_surf_points]
        surf_embeds = self.surfz_embed(surf_ncs)  # [B*max_face, embed_dim, num_surf_points]
        surf_embeds = surf_embeds.mean(dim=-1)    # [B*max_face, embed_dim]
        surf_embeds = surf_embeds.reshape(B, F, -1)  # [B, max_face, embed_dim]

        # 3. Process positional encodings
        # Edge positional encoding
        edge_pos = edge_pos.reshape(B*F*E, -1)  # [B*max_face*max_edge, 6]
        edge_p_embeds = self.edgep_embed(edge_pos)  # [B*max_face*max_edge, embed_dim]
        edge_p_embeds = edge_p_embeds.reshape(B, F, E, -1)  # [B, max_face, max_edge, embed_dim]

        # Face positional encoding
        surf_p_embeds = self.surfp_embed(surf_pos)  # [B, max_face, embed_dim]

        # 4. Combine features
        if self.use_cf:
            # Edge features
            edge_features = edge_embeds + edge_p_embeds  # [B, max_face, max_edge, embed_dim]
            edge_features = edge_features.reshape(B, F*E, -1)  # [B, max_face*max_edge, embed_dim]

            # 6. Transformer processing
            output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
            return output.transpose(0, 1)  # make sure the output shape is [B, seq_len, embed_dim]
            # Face features
            surf_features = surf_embeds + surf_p_embeds  # [B, max_face, embed_dim]

        except Exception as e:
            logger.error(f"Error in BRepFeatureEmbedder forward pass:")
            logger.error(f"  Error message: {str(e)}")
            logger.error(f"  Input shapes:")
            logger.error(f"    edge_ncs: {edge_ncs.shape}")
            logger.error(f"    edge_pos: {edge_pos.shape}")
            logger.error(f"    edge_mask: {edge_mask.shape}")
            logger.error(f"    surf_ncs: {surf_ncs.shape}")
            logger.error(f"    surf_pos: {surf_pos.shape}")
            logger.error(f"    vertex_pos: {vertex_pos.shape}")
            raise
            # Combine all features
            embeds = torch.cat([
                edge_features,  # [B, max_face*max_edge, embed_dim]
                surf_features   # [B, max_face, embed_dim]
            ], dim=1)  # [B, max_face*(max_edge+1), embed_dim]
        else:
            # Use only the positional encodings
            edge_features = edge_p_embeds.reshape(B, F*E, -1)  # [B, max_face*max_edge, embed_dim]
            embeds = torch.cat([
                edge_features,  # [B, max_face*max_edge, embed_dim]
                surf_p_embeds   # [B, max_face, embed_dim]
            ], dim=1)  # [B, max_face*(max_edge+1), embed_dim]

        # 5. Handle the mask
        if edge_mask is not None:
            # Expand the mask to match the feature layout
            edge_mask = edge_mask.reshape(B, -1)  # [B, max_face*max_edge]
            surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool)  # [B, max_face]
            mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, max_face*(max_edge+1)]
        else:
            mask = None

        # 6. Transformer processing
        output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
        return output.transpose(0, 1)  # make sure the output shape is [B, seq_len, embed_dim]
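
One convention worth keeping in mind for the mask built in step 5: PyTorch's src_key_padding_mask treats True entries as positions to be ignored by attention, so padded tokens should carry True and valid tokens False. A minimal illustration of that convention (toy shapes, not taken from the patch):

import torch
import torch.nn as nn

d_model, nhead = 32, 4
enc = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead, batch_first=True),
    num_layers=1,
)

x = torch.randn(2, 5, d_model)                              # [batch, seq_len, d_model]
pad = torch.tensor([[False, False, False, True,  True],     # True = padded -> ignored as a key
                    [False, False, True,  True,  True]])
out = enc(x, src_key_padding_mask=pad)                      # padded positions do not contribute to attention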
				
			
			
		
	
		
			
				
class SDFTransformer(nn.Module):
    """SDF Transformer encoder"""
    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
    def __init__(self, embed_dim: int = 192, num_layers: int = 6):  # changed to 192 to match BRepFeatureEmbedder
        super().__init__()

        # 1. Add a positional encoding
        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.02)

        # 2. Updated Transformer layer configuration
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            nhead=4,  # fewer heads so each head gets a larger dimension
            dim_feedforward=embed_dim * 2,  # smaller FFN dimension
            dropout=0.1,
            batch_first=True,
            norm_first=False  # changed here: set to False
            norm_first=True,  # use the pre-LN structure
            activation=F.gelu
        )

        # 3. Add a gradient scaling factor
        self.attention_scale = math.sqrt(embed_dim)

        # 4. Custom initialization
        def _init_weights(module):
            if isinstance(module, nn.Linear):
                # use a smaller initialization range
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.MultiheadAttention):
                # handle attention layers specially
                in_proj_weight = module.in_proj_weight
                out_proj_weight = module.out_proj.weight

                nn.init.xavier_uniform_(in_proj_weight, gain=0.1)
                nn.init.xavier_uniform_(out_proj_weight, gain=0.1)

                if module.in_proj_bias is not None:
                    nn.init.zeros_(module.in_proj_bias)
                if module.out_proj.bias is not None:
                    nn.init.zeros_(module.out_proj.bias)

        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers,
            norm=nn.LayerNorm(embed_dim)
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

        # Apply the initialization
        self.transformer.apply(_init_weights)

        # 5. Add residual scaling
        self.residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x, mask=None):
        return self.transformer(x, src_key_padding_mask=mask)
        # Add the positional encoding
        seq_len = x.size(1)
        x = x + self.pos_embedding[:, :seq_len, :]

        # Scale the attention scores
        x = x * self.attention_scale

        # Forward pass
        for layer in self.transformer.layers:
            # manually add the residual connection and scaling
            identity = x
            x = layer(x, src_key_padding_mask=mask)
            x = identity + x * self.residual_scale

        return self.transformer.norm(x)
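
For reference, a self-contained sketch of the residual-scaling pattern used in the loop above: each stacked encoder layer is applied manually and its output is blended back onto the layer input with a small learnable scale. The values here are illustrative only; note that nn.TransformerEncoderLayer already applies its own internal residual connections, so this adds a second, outer residual path on top of them.

import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, num_layers = 192, 6
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=4, dim_feedforward=embed_dim * 2,
    dropout=0.1, batch_first=True, norm_first=True, activation=F.gelu,
)
encoder = nn.TransformerEncoder(layer, num_layers)
residual_scale = nn.Parameter(torch.ones(1) * 0.1)   # small learnable blend factor
final_norm = nn.LayerNorm(embed_dim)

x = torch.randn(2, 50, embed_dim)          # [batch, seq_len, embed_dim]
for blk in encoder.layers:                 # iterate the stacked layers manually
    identity = x
    x = blk(x)                             # pre-LN block (contains its own residuals)
    x = identity + x * residual_scale      # extra, learnable outer residual
out = final_norm(x)                        # -> [2, 50, embed_dim]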
				
			
			
		
	
		
			
				
class SDFHead(nn.Module):
    """SDF prediction head"""