Browse Source

tofix: 概率性有梯度

main
mckay 4 months ago
parent
commit
596a7b6565
  1. 20
      brep2sdf/networks/encoder.py

20
brep2sdf/networks/encoder.py

@ -251,7 +251,6 @@ class BRepFeatureEmbedder(nn.Module):
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim]
# 2. 处理边特征
logger.info(f"edge_ncs shape: {edge_ncs.shape}")
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim]
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim]
@ -441,6 +440,13 @@ class BRepToSDF(nn.Module):
"""
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries
#logger.info(f"query_points requires_grad: {query_points.requires_grad}")
#logger.info(f"edge_ncs requires_grad: {edge_ncs.requires_grad}")
#logger.info(f"edge_pos requires_grad: {edge_pos.requires_grad}")
#logger.info(f"edge_mask requires_grad: {edge_mask.requires_grad}")
#logger.info(f"surf_ncs requires_grad: {surf_ncs.requires_grad}")
#logger.info(f"surf_pos requires_grad: {surf_pos.requires_grad}")
try:
# 确保query_points需要梯度
if not query_points.requires_grad:
@ -544,13 +550,13 @@ def main():
# 生成测试数据
test_data = {
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3),
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6),
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3, requires_grad=True),
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6, requires_grad=True),
'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool),
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3),
'surf_pos': torch.randn(batch_size, max_face, 6),
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3),
'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3, requires_grad=True),
'surf_pos': torch.randn(batch_size, max_face, 6, requires_grad=True),
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3, requires_grad=True),
'query_points': torch.randn(batch_size, 1000, 3, requires_grad=True) # 1000个查询点
}
# 打印输入数据形状

Loading…
Cancel
Save