@@ -221,17 +221,9 @@ class BRepFeatureEmbedder(nn.Module):

        self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)

        # Transformer encoder over the concatenated B-rep tokens.
        # nhead must divide embed_dim; batch_first stays at its default
        # (False) because forward() transposes to [seq, batch, dim].
        layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=8, norm_first=True,
                                           dim_feedforward=1024, dropout=0.1)
        self.net = nn.TransformerEncoder(layer, 6, nn.LayerNorm(self.embed_dim))
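        # Design note: norm_first=True applies LayerNorm before each
        # attention/FFN block (pre-norm), which is generally more stable to
        # train than the post-norm default; the trailing nn.LayerNorm
        # normalizes the final encoder output.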

    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
        """
@@ -245,27 +237,48 @@ class BRepFeatureEmbedder(nn.Module):
        """
        B, F, E = edge_pos.shape[:3]

        # Sanity-check the input tensors
        input_tensors = {
            'edge_ncs': edge_ncs,
            'edge_pos': edge_pos,
            'surf_ncs': surf_ncs,
            'surf_pos': surf_pos,
            'vertex_pos': vertex_pos
        }

        logger.info("\n=== Input tensor check ===")
        for name, tensor in input_tensors.items():
            print_tensor_stats(name, tensor)

        # 1. Vertex features
        vertex_embed = self.vertp_embed(vertex_pos[..., :3])  # [B, F, E, 2, embed_dim]
        print_tensor_stats('vertex_embed', vertex_embed)
        vertex_embed = self.vertex_proj(vertex_embed)  # [B, F, E, 2, embed_dim]
        print_tensor_stats('vertex_embed(after proj)', vertex_embed)
        vertex_embed = vertex_embed.mean(dim=3)  # [B, F, E, embed_dim]; average the two endpoints

        # 2. Edge features
        edge_embeds = self.edgez_embed(edge_ncs)  # [B, F, E, embed_dim]
        print_tensor_stats('edge_embeds', edge_embeds)
        edge_p_embeds = self.edgep_embed(edge_pos)  # [B, F, E, embed_dim]
        print_tensor_stats('edge_p_embeds', edge_p_embeds)

        # 3. Surface features
        surf_embeds = self.surfz_embed(surf_ncs)  # [B, F, embed_dim]
        print_tensor_stats('surf_embeds', surf_embeds)
        surf_p_embeds = self.surfp_embed(surf_pos)  # [B, F, embed_dim]
        print_tensor_stats('surf_p_embeds', surf_p_embeds)

        # 4. Combine features
        if self.use_cf:
            # Combine edge features
            edge_features = edge_embeds + edge_p_embeds + vertex_embed  # [B, F, E, embed_dim]
            edge_features = edge_features.reshape(B, F * E, -1)  # [B, F*E, embed_dim]
            print_tensor_stats('edge_features', edge_features)

            # Combine surface features
            surf_features = surf_embeds + surf_p_embeds  # [B, F, embed_dim]
            print_tensor_stats('surf_features', surf_features)

            # Concatenate all features
            embeds = torch.cat([

@@ -287,9 +300,11 @@ class BRepFeatureEmbedder(nn.Module):

            mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, F*E+F]
        else:
            mask = None
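        # Note: nn.TransformerEncoder interprets True in src_key_padding_mask
        # as "ignore this position", so the masks must flag padded tokens with
        # True and valid tokens with False.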
logger.debug(f"embeds shape: {embeds.shape}") |
|
|
|
# 6. Transformer处理 |
|
|
|
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) |
|
|
|
print_tensor_stats('output', output) |
|
|
|
logger.debug(f"output shape: {output.shape}") |
|
|
|
return output.transpose(0, 1) # [B, F*E+F, embed_dim] |
|
|
|
|
|
|
|
class SDFTransformer(nn.Module):

@@ -501,6 +516,22 @@ class BRepToSDF(nn.Module):

logger.error(f" query_points: {query_points.shape}") |
|
|
|
raise |
|
|
|
|
|
|
|
def print_tensor_stats(name: str, tensor: torch.Tensor):
    """Log summary statistics for a tensor."""
    logger.info(f"\n=== {name} stats ===")
    logger.info(f"  shape: {tensor.shape}")
    logger.info(f"  norm: {tensor.norm().item():.6f}")
    logger.info(f"  mean: {tensor.mean().item():.6f}")
    logger.info(f"  std: {tensor.std().item():.6f}")
    logger.info(f"  min: {tensor.min().item():.6f}")
    logger.info(f"  max: {tensor.max().item():.6f}")
    logger.info(f"  requires_grad: {tensor.requires_grad}")
    if tensor.requires_grad:
        if not tensor.grad_fn:
            # Leaf tensors (e.g. inputs created with requires_grad=True) have
            # no grad_fn by design, so this warning is expected for raw inputs
            # and only signals a problem for intermediate tensors.
            logger.warning(f"⚠️ {name} has requires_grad=True but no grad_fn!")
    else:
        logger.warning(f"⚠️ {name} has requires_grad=False!")

def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """SDF loss function."""
    # Make sure the query points require gradients
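
    # The rest of sdf_loss lies outside this hunk. For orientation only, a
    # minimal sketch of one common formulation (MSE on the SDF values plus an
    # Eikonal-style gradient penalty); this is an assumption, not the code
    # hidden by the diff:
    #
    #     mse = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
    #     grad = torch.autograd.grad(pred_sdf.sum(), points, create_graph=True)[0]
    #     eikonal = ((grad.norm(dim=-1) - 1.0) ** 2).mean()
    #     return mse + grad_weight * eikonal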

@@ -580,5 +611,107 @@ def main():

logger.error(f"Error during forward pass: {str(e)}") |
|
|
|
raise |
|
|
|
|
|
|
|
def test_brep_embedder():
    """Test parameter initialization and gradient flow in BRepFeatureEmbedder."""
    # 1. Set up config and model
    config = get_default_config()
    embedder = BRepFeatureEmbedder(config)

    # 2. Build test data
    B, F, E = 2, 8, 16  # batch_size, max_face, max_edge
    test_data = {
        'edge_ncs': torch.randn(B, F, E, config.model.num_edge_points, 3, requires_grad=True),
        'edge_pos': torch.randn(B, F, E, 6, requires_grad=True),
        'edge_mask': torch.ones(B, F, E, dtype=torch.bool),
        'surf_ncs': torch.randn(B, F, config.model.num_surf_points, 3, requires_grad=True),
        'surf_pos': torch.randn(B, F, 6, requires_grad=True),
        'vertex_pos': torch.randn(B, F, E, 2, 3, requires_grad=True)
    }
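    # Caution (assumption about mask semantics): if edge_mask feeds
    # src_key_padding_mask unchanged, an all-True mask marks every edge token
    # as padding and masks it out; use torch.zeros(..., dtype=torch.bool) to
    # mark all tokens valid unless the embedder inverts the mask internally.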

    # 3. Inspect the initial parameters
    logger.info("\n=== Initial parameter check ===")
    for name, param in embedder.named_parameters():
        logger.info(f"\n{name}:")
        logger.info(f"  shape: {param.shape}")
        logger.info(f"  requires_grad: {param.requires_grad}")
        logger.info(f"  norm: {param.norm().item():.6f}")
        logger.info(f"  mean: {param.mean().item():.6f}")
        logger.info(f"  std: {param.std().item():.6f}")

    # 4. Forward pass
    logger.info("\n=== Forward pass ===")
    outputs = embedder(**test_data)
    logger.info(f"Output shape: {outputs.shape}")

    # 5. Helper for inspecting intermediate tensors (currently unused; call it
    # from instrumented code where needed)
    def check_tensor(tensor, name):
        logger.info(f"\n{name}:")
        logger.info(f"  shape: {tensor.shape}")
        logger.info(f"  requires_grad: {tensor.requires_grad}")
        logger.info(f"  has_grad_fn: {tensor.grad_fn is not None}")
        if tensor.grad_fn:
            logger.info(f"  grad_fn: {type(tensor.grad_fn).__name__}")
        logger.info(f"  norm: {tensor.norm().item():.6f}")
        logger.info(f"  mean: {tensor.mean().item():.6f}")
        logger.info(f"  std: {tensor.std().item():.6f}")

    # 6. Backward pass
    logger.info("\n=== Backward pass ===")
    loss = outputs.mean()  # trivial scalar loss so gradients reach every parameter
    loss.backward()

    # 7. Gradient checks
    logger.info("\n=== Gradient check ===")

    # 7.1 Input gradients
    for name, tensor in test_data.items():
        if tensor.requires_grad:
            logger.info(f"\n{name} gradient:")
            if tensor.grad is not None:
                logger.info(f"  grad norm: {tensor.grad.norm().item():.6f}")
                logger.info(f"  grad mean: {tensor.grad.mean().item():.6f}")
                logger.info(f"  grad std: {tensor.grad.std().item():.6f}")
            else:
                logger.info("  No gradient!")

    # 7.2 Model parameter gradients
    logger.info("\n=== Model parameter gradients ===")
    for name, param in embedder.named_parameters():
        logger.info(f"\n{name}:")
        if param.grad is not None:
            logger.info(f"  grad norm: {param.grad.norm().item():.6f}")
            logger.info(f"  grad mean: {param.grad.mean().item():.6f}")
            logger.info(f"  grad std: {param.grad.std().item():.6f}")
            # Flag NaN or Inf gradients
            if torch.isnan(param.grad).any():
                logger.warning("  Contains NaN gradients!")
            if torch.isinf(param.grad).any():
                logger.warning("  Contains Inf gradients!")
        else:
            logger.warning("  No gradient!")

    # 8. Inspect the transformer layers specifically
    logger.info("\n=== Transformer layer check ===")
    for i, layer in enumerate(embedder.net.layers):
        logger.info(f"\nTransformer Layer {i}:")

        # Self-attention projections
        logger.info("  Self Attention:")
        logger.info(f"    in_proj_weight norm: {layer.self_attn.in_proj_weight.norm().item():.6f}")
        logger.info(f"    in_proj_bias norm: {layer.self_attn.in_proj_bias.norm().item():.6f}")
        logger.info(f"    out_proj.weight norm: {layer.self_attn.out_proj.weight.norm().item():.6f}")
        logger.info(f"    out_proj.bias norm: {layer.self_attn.out_proj.bias.norm().item():.6f}")

        if layer.self_attn.in_proj_weight.grad is not None:
            logger.info(f"    in_proj_weight grad norm: {layer.self_attn.in_proj_weight.grad.norm().item():.6f}")
        else:
            logger.warning("    in_proj_weight has no gradient!")

        # LayerNorm parameters
        logger.info("  LayerNorm:")
        logger.info(f"    norm1.weight norm: {layer.norm1.weight.norm().item():.6f}")
        logger.info(f"    norm1.bias norm: {layer.norm1.bias.norm().item():.6f}")
        logger.info(f"    norm2.weight norm: {layer.norm2.weight.norm().item():.6f}")
        logger.info(f"    norm2.bias norm: {layer.norm2.bias.norm().item():.6f}")

if __name__ == "__main__":
    main()
    test_brep_embedder()