| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -248,37 +248,37 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.info("\n=== 输入张量检查 ===") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for name, tensor in input_tensors.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            print_tensor_stats(name, tensor) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.print_tensor_stats(name, tensor) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 1. 处理顶点特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = self.vertp_embed(vertex_pos[..., :3])  # [B, F, E, 2, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print_tensor_stats('vertex_embed', vertex_embed) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.print_tensor_stats('vertex_embed', vertex_embed) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = self.vertex_proj(vertex_embed)  # [B, F, E, 2, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print_tensor_stats('vertex_embed(after proj)', vertex_embed) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        vertex_embed = vertex_embed.mean(dim=3)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 2. 处理边特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_embeds = self.edgez_embed(edge_ncs)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print_tensor_stats('edge_embeds', edge_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.print_tensor_stats('edge_embeds', edge_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        edge_p_embeds = self.edgep_embed(edge_pos)  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print_tensor_stats('edge_p_embeds', edge_p_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 3. 处理面特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_embeds = self.surfz_embed(surf_ncs)  # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print_tensor_stats('surf_embeds', surf_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.print_tensor_stats('surf_embeds', surf_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        surf_p_embeds = self.surfp_embed(surf_pos)  # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print_tensor_stats('surf_p_embeds', surf_p_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 4. 组合特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if self.use_cf: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 组合边特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_embeds + edge_p_embeds + vertex_embed  # [B, F, E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            edge_features = edge_features.reshape(B, F*E, -1)  # [B, F*E, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            print_tensor_stats('edge_features', edge_features) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.print_tensor_stats('edge_features', edge_features) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 组合面特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            surf_features = surf_embeds + surf_p_embeds  # [B, F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            print_tensor_stats('surf_features', surf_features) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.print_tensor_stats('surf_features', surf_features) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            # 拼接所有特征 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            embeds = torch.cat([ | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -303,7 +303,7 @@ class BRepFeatureEmbedder(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.debug(f"embeds shape: {embeds.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # 6. Transformer处理 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print_tensor_stats('output', output) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.print_tensor_stats('output', output) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.debug(f"output shape: {output.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return output.transpose(0, 1)  # [B, F*E+F, embed_dim] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -516,22 +516,6 @@ class BRepToSDF(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.error(f"    query_points: {query_points.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            raise | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def print_tensor_stats(name: str, tensor: torch.Tensor): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """打印张量的统计信息""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"\n=== {name} 统计信息 ===") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"  shape: {tensor.shape}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"  norm: {tensor.norm().item():.6f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"  mean: {tensor.mean().item():.6f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"  std: {tensor.std().item():.6f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"  min: {tensor.min().item():.6f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"  max: {tensor.max().item():.6f}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"  requires_grad: {tensor.requires_grad}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if tensor.requires_grad: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if not tensor.grad_fn: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        logger.warning(f"⚠️ {name} requires_grad=False!") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """SDF损失函数""" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # 确保points需要梯度 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |