|
|
@ -248,37 +248,37 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
|
|
|
|
logger.info("\n=== 输入张量检查 ===") |
|
|
|
for name, tensor in input_tensors.items(): |
|
|
|
print_tensor_stats(name, tensor) |
|
|
|
logger.print_tensor_stats(name, tensor) |
|
|
|
|
|
|
|
# 1. 处理顶点特征 |
|
|
|
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] |
|
|
|
print_tensor_stats('vertex_embed', vertex_embed) |
|
|
|
logger.print_tensor_stats('vertex_embed', vertex_embed) |
|
|
|
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] |
|
|
|
print_tensor_stats('vertex_embed(after proj)', vertex_embed) |
|
|
|
logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) |
|
|
|
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] |
|
|
|
|
|
|
|
# 2. 处理边特征 |
|
|
|
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] |
|
|
|
print_tensor_stats('edge_embeds', edge_embeds) |
|
|
|
logger.print_tensor_stats('edge_embeds', edge_embeds) |
|
|
|
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] |
|
|
|
print_tensor_stats('edge_p_embeds', edge_p_embeds) |
|
|
|
logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) |
|
|
|
|
|
|
|
# 3. 处理面特征 |
|
|
|
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] |
|
|
|
print_tensor_stats('surf_embeds', surf_embeds) |
|
|
|
logger.print_tensor_stats('surf_embeds', surf_embeds) |
|
|
|
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] |
|
|
|
print_tensor_stats('surf_p_embeds', surf_p_embeds) |
|
|
|
logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) |
|
|
|
|
|
|
|
# 4. 组合特征 |
|
|
|
if self.use_cf: |
|
|
|
# 组合边特征 |
|
|
|
edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim] |
|
|
|
edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim] |
|
|
|
print_tensor_stats('edge_features', edge_features) |
|
|
|
logger.print_tensor_stats('edge_features', edge_features) |
|
|
|
|
|
|
|
# 组合面特征 |
|
|
|
surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim] |
|
|
|
print_tensor_stats('surf_features', surf_features) |
|
|
|
logger.print_tensor_stats('surf_features', surf_features) |
|
|
|
|
|
|
|
# 拼接所有特征 |
|
|
|
embeds = torch.cat([ |
|
|
@ -303,7 +303,7 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
logger.debug(f"embeds shape: {embeds.shape}") |
|
|
|
# 6. Transformer处理 |
|
|
|
output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
print_tensor_stats('output', output) |
|
|
|
logger.print_tensor_stats('output', output) |
|
|
|
logger.debug(f"output shape: {output.shape}") |
|
|
|
return output.transpose(0, 1) # [B, F*E+F, embed_dim] |
|
|
|
|
|
|
@ -516,22 +516,6 @@ class BRepToSDF(nn.Module): |
|
|
|
logger.error(f" query_points: {query_points.shape}") |
|
|
|
raise |
|
|
|
|
|
|
|
def print_tensor_stats(name: str, tensor: torch.Tensor): |
|
|
|
"""打印张量的统计信息""" |
|
|
|
logger.info(f"\n=== {name} 统计信息 ===") |
|
|
|
logger.info(f" shape: {tensor.shape}") |
|
|
|
logger.info(f" norm: {tensor.norm().item():.6f}") |
|
|
|
logger.info(f" mean: {tensor.mean().item():.6f}") |
|
|
|
logger.info(f" std: {tensor.std().item():.6f}") |
|
|
|
logger.info(f" min: {tensor.min().item():.6f}") |
|
|
|
logger.info(f" max: {tensor.max().item():.6f}") |
|
|
|
logger.info(f" requires_grad: {tensor.requires_grad}") |
|
|
|
if tensor.requires_grad: |
|
|
|
if not tensor.grad_fn: |
|
|
|
logger.warning(f"⚠️ {name} requires_grad=True 但没有梯度函数!") |
|
|
|
else: |
|
|
|
logger.warning(f"⚠️ {name} requires_grad=False!") |
|
|
|
|
|
|
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): |
|
|
|
"""SDF损失函数""" |
|
|
|
# 确保points需要梯度 |
|
|
|