@@ -221,17 +221,9 @@ class BRepFeatureEmbedder(nn.Module):
         self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
         # Add transformer initialization
-        self.transformer = nn.TransformerEncoder(
-            encoder_layer=nn.TransformerEncoderLayer(
-                d_model=self.embed_dim,
-                nhead=8,  # number of attention heads, usually a divisor of embed_dim
-                dim_feedforward=4*self.embed_dim,  # feed-forward dimension, usually 4x embed_dim
-                dropout=0.1,
-                activation='gelu',
-                batch_first=False  # because we use transpose(0,1)
-            ),
-            num_layers=6  # number of transformer layers
-        )
+        layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=8, norm_first=True,
+                                           dim_feedforward=1024, dropout=0.1)
+        self.net = nn.TransformerEncoder(layer, 6, nn.LayerNorm(self.embed_dim))
 
     def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
         """
@@ -245,27 +237,48 @@ class BRepFeatureEmbedder(nn.Module):
         """
         B, F, E = edge_pos.shape[:3]
 
+        # Check the input tensors
+        input_tensors = {
+            'edge_ncs': edge_ncs,
+            'edge_pos': edge_pos,
+            'surf_ncs': surf_ncs,
+            'surf_pos': surf_pos,
+            'vertex_pos': vertex_pos
+        }
+
+        logger.info("\n=== Input tensor check ===")
+        for name, tensor in input_tensors.items():
+            print_tensor_stats(name, tensor)
+
         # 1. Process vertex features
         vertex_embed = self.vertp_embed(vertex_pos[..., :3])  # [B, F, E, 2, embed_dim]
+        print_tensor_stats('vertex_embed', vertex_embed)
         vertex_embed = self.vertex_proj(vertex_embed)  # [B, F, E, 2, embed_dim]
+        print_tensor_stats('vertex_embed(after proj)', vertex_embed)
         vertex_embed = vertex_embed.mean(dim=3)  # [B, F, E, embed_dim]
 
         # 2. Process edge features
         edge_embeds = self.edgez_embed(edge_ncs)  # [B, F, E, embed_dim]
+        print_tensor_stats('edge_embeds', edge_embeds)
         edge_p_embeds = self.edgep_embed(edge_pos)  # [B, F, E, embed_dim]
+        print_tensor_stats('edge_p_embeds', edge_p_embeds)
 
         # 3. Process face features
         surf_embeds = self.surfz_embed(surf_ncs)  # [B, F, embed_dim]
+        print_tensor_stats('surf_embeds', surf_embeds)
         surf_p_embeds = self.surfp_embed(surf_pos)  # [B, F, embed_dim]
+        print_tensor_stats('surf_p_embeds', surf_p_embeds)
 
         # 4. Combine features
         if self.use_cf:
             # Combine edge features
             edge_features = edge_embeds + edge_p_embeds + vertex_embed  # [B, F, E, embed_dim]
             edge_features = edge_features.reshape(B, F*E, -1)  # [B, F*E, embed_dim]
+            print_tensor_stats('edge_features', edge_features)
 
             # Combine face features
             surf_features = surf_embeds + surf_p_embeds  # [B, F, embed_dim]
+            print_tensor_stats('surf_features', surf_features)
 
             # Concatenate all features
             embeds = torch.cat([
@@ -287,9 +300,11 @@ class BRepFeatureEmbedder(nn.Module):
             mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, F*E+F]
         else:
             mask = None
-
+        logger.debug(f"embeds shape: {embeds.shape}")
         # 6. Transformer processing
-        output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
+        output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask)
+        print_tensor_stats('output', output)
+        logger.debug(f"output shape: {output.shape}")
         return output.transpose(0, 1)  # [B, F*E+F, embed_dim]
 
 class SDFTransformer(nn.Module):
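A note on the masking in the hunk above, as a sketch of the assumed shapes under PyTorch's default conventions (nothing this hunk itself confirms): with batch_first=False the encoder consumes embeds.transpose(0, 1) of shape [F*E+F, B, embed_dim], while src_key_padding_mask keeps the shape [B, F*E+F], and positions where the mask is True are excluded from attention. If edge_mask and surf_mask instead mark valid entries with True, they would have to be inverted before being passed, for example:

    # Hypothetical: turn a validity mask (True = real entry) into a padding mask (True = ignore).
    padding_mask = ~torch.cat([edge_mask.reshape(B, -1), surf_mask], dim=1)  # [B, F*E+F]
    output = self.net(embeds.transpose(0, 1), src_key_padding_mask=padding_mask)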
@@ -501,6 +516,22 @@ class BRepToSDF(nn.Module):
             logger.error(f"    query_points: {query_points.shape}")
             raise
 
+def print_tensor_stats(name: str, tensor: torch.Tensor):
+    """Print statistics of a tensor."""
+    logger.info(f"\n=== {name} statistics ===")
+    logger.info(f"  shape: {tensor.shape}")
+    logger.info(f"  norm: {tensor.norm().item():.6f}")
+    logger.info(f"  mean: {tensor.mean().item():.6f}")
+    logger.info(f"  std: {tensor.std().item():.6f}")
+    logger.info(f"  min: {tensor.min().item():.6f}")
+    logger.info(f"  max: {tensor.max().item():.6f}")
+    logger.info(f"  requires_grad: {tensor.requires_grad}")
+    if tensor.requires_grad:
+        if not tensor.grad_fn:
+            logger.warning(f"⚠️ {name} has requires_grad=True but no grad_fn!")
+    else:
+        logger.warning(f"⚠️ {name} has requires_grad=False!")
+
 def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
     """SDF loss function."""
     # Make sure points require gradients
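For context on the trailing lines above: sdf_loss takes the query points themselves plus a grad_weight, and its first step makes sure the points require gradients, which suggests a regularizer computed via autograd on the predicted SDF. The actual implementation is not shown in this hunk; one common formulation, given purely as an illustrative sketch with a hypothetical name, is:

    def sdf_loss_sketch(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
        # Hypothetical gradient-regularized SDF loss (illustration only, not the repo's version).
        value_loss = (pred_sdf - gt_sdf).abs().mean()
        # Eikonal-style term: a true SDF has a unit-norm spatial gradient.
        grad = torch.autograd.grad(pred_sdf.sum(), points, create_graph=True)[0]
        eikonal = ((grad.norm(dim=-1) - 1.0) ** 2).mean()
        return value_loss + grad_weight * eikonal

For this to work, points must have requires_grad=True before pred_sdf is computed (e.g. points.requires_grad_(True)), which is what the comment above is about.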
@@ -580,5 +611,107 @@ def main():
         logger.error(f"Error during forward pass: {str(e)}")
         raise
 
+def test_brep_embedder():
+    """Test parameter initialization and gradient flow of BRepFeatureEmbedder."""
+    # 1. Initialize config and model
+    config = get_default_config()
+    embedder = BRepFeatureEmbedder(config)
+
+    # 2. Generate test data
+    B, F, E = 2, 8, 16  # batch_size, max_face, max_edge
+    test_data = {
+        'edge_ncs': torch.randn(B, F, E, config.model.num_edge_points, 3, requires_grad=True),
+        'edge_pos': torch.randn(B, F, E, 6, requires_grad=True),
+        'edge_mask': torch.ones(B, F, E, dtype=torch.bool),
+        'surf_ncs': torch.randn(B, F, config.model.num_surf_points, 3, requires_grad=True),
+        'surf_pos': torch.randn(B, F, 6, requires_grad=True),
+        'vertex_pos': torch.randn(B, F, E, 2, 3, requires_grad=True)
+    }
+
+    # 3. Check initial parameters
+    logger.info("\n=== Initial parameter check ===")
+    for name, param in embedder.named_parameters():
+        logger.info(f"\n{name}:")
+        logger.info(f"  shape: {param.shape}")
+        logger.info(f"  requires_grad: {param.requires_grad}")
+        logger.info(f"  norm: {param.norm().item():.6f}")
+        logger.info(f"  mean: {param.mean().item():.6f}")
+        logger.info(f"  std: {param.std().item():.6f}")
+
+    # 4. Forward pass
+    logger.info("\n=== Forward pass ===")
+    outputs = embedder(**test_data)
+    logger.info(f"Output shape: {outputs.shape}")
+
+    # 5. Check intermediate features
+    def check_tensor(tensor, name):
+        logger.info(f"\n{name}:")
+        logger.info(f"  shape: {tensor.shape}")
+        logger.info(f"  requires_grad: {tensor.requires_grad}")
+        logger.info(f"  has_grad_fn: {tensor.grad_fn is not None}")
+        if tensor.grad_fn:
+            logger.info(f"  grad_fn: {type(tensor.grad_fn).__name__}")
+        logger.info(f"  norm: {tensor.norm().item():.6f}")
+        logger.info(f"  mean: {tensor.mean().item():.6f}")
+        logger.info(f"  std: {tensor.std().item():.6f}")
+
+    # 6. Backward pass
+    logger.info("\n=== Backward pass ===")
+    loss = outputs.mean()  # simple loss function
+    loss.backward()
+
+    # 7. Check gradients
+    logger.info("\n=== Gradient check ===")
+
+    # 7.1 Check input gradients
+    for name, tensor in test_data.items():
+        if tensor.requires_grad:
+            logger.info(f"\n{name} gradient:")
+            if tensor.grad is not None:
+                logger.info(f"  grad norm: {tensor.grad.norm().item():.6f}")
+                logger.info(f"  grad mean: {tensor.grad.mean().item():.6f}")
+                logger.info(f"  grad std: {tensor.grad.std().item():.6f}")
+            else:
+                logger.info("  No gradient!")
+
+    # 7.2 Check model parameter gradients
+    logger.info("\n=== Model parameter gradients ===")
+    for name, param in embedder.named_parameters():
+        logger.info(f"\n{name}:")
+        if param.grad is not None:
+            logger.info(f"  grad norm: {param.grad.norm().item():.6f}")
+            logger.info(f"  grad mean: {param.grad.mean().item():.6f}")
+            logger.info(f"  grad std: {param.grad.std().item():.6f}")
+            # Check whether any gradients are NaN or Inf
+            if torch.isnan(param.grad).any():
+                logger.warning("  Contains NaN gradients!")
+            if torch.isinf(param.grad).any():
+                logger.warning("  Contains Inf gradients!")
+        else:
+            logger.warning("  No gradient!")
+
+    # 8. Specifically check the transformer layers
+    logger.info("\n=== Transformer layer check ===")
+    for i, layer in enumerate(embedder.net.layers):
+        logger.info(f"\nTransformer Layer {i}:")
+        # Check the self-attention layer
+        logger.info("  Self Attention:")
+        logger.info(f"    in_proj_weight norm: {layer.self_attn.in_proj_weight.norm().item():.6f}")
+        logger.info(f"    in_proj_bias norm: {layer.self_attn.in_proj_bias.norm().item():.6f}")
+        logger.info(f"    out_proj.weight norm: {layer.self_attn.out_proj.weight.norm().item():.6f}")
+        logger.info(f"    out_proj.bias norm: {layer.self_attn.out_proj.bias.norm().item():.6f}")
+
+        if layer.self_attn.in_proj_weight.grad is not None:
+            logger.info(f"    in_proj_weight grad norm: {layer.self_attn.in_proj_weight.grad.norm().item():.6f}")
+        else:
+            logger.warning("    in_proj_weight has no gradient!")
+
+        # Check the LayerNorm layers
+        logger.info("  LayerNorm:")
+        logger.info(f"    norm1.weight norm: {layer.norm1.weight.norm().item():.6f}")
+        logger.info(f"    norm1.bias norm: {layer.norm1.bias.norm().item():.6f}")
+        logger.info(f"    norm2.weight norm: {layer.norm2.weight.norm().item():.6f}")
+        logger.info(f"    norm2.bias norm: {layer.norm2.bias.norm().item():.6f}")
+
 if __name__ == "__main__":
-    main()
+    test_brep_embedder()
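The test above only logs its findings; if it is meant to gate anything in CI, the gradient checks could also assert. A minimal sketch (hypothetical helper, not part of the diff):

    def assert_gradients_ok(module: nn.Module) -> None:
        # Fail loudly if any parameter is missing a gradient or has NaN/Inf entries.
        for name, param in module.named_parameters():
            assert param.grad is not None, f"{name} received no gradient"
            assert torch.isfinite(param.grad).all(), f"{name} has non-finite gradients"

    assert_gradients_ok(embedder)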