|
|
@ -251,7 +251,6 @@ class BRepFeatureEmbedder(nn.Module): |
|
|
|
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] |
|
|
|
|
|
|
|
# 2. 处理边特征 |
|
|
|
logger.info(f"edge_ncs shape: {edge_ncs.shape}") |
|
|
|
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] |
|
|
|
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] |
|
|
|
|
|
|
@ -441,6 +440,13 @@ class BRepToSDF(nn.Module): |
|
|
|
""" |
|
|
|
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries |
|
|
|
|
|
|
|
#logger.info(f"query_points requires_grad: {query_points.requires_grad}") |
|
|
|
#logger.info(f"edge_ncs requires_grad: {edge_ncs.requires_grad}") |
|
|
|
#logger.info(f"edge_pos requires_grad: {edge_pos.requires_grad}") |
|
|
|
#logger.info(f"edge_mask requires_grad: {edge_mask.requires_grad}") |
|
|
|
#logger.info(f"surf_ncs requires_grad: {surf_ncs.requires_grad}") |
|
|
|
#logger.info(f"surf_pos requires_grad: {surf_pos.requires_grad}") |
|
|
|
|
|
|
|
try: |
|
|
|
# 确保query_points需要梯度 |
|
|
|
if not query_points.requires_grad: |
|
|
@ -544,13 +550,13 @@ def main(): |
|
|
|
|
|
|
|
# 生成测试数据 |
|
|
|
test_data = { |
|
|
|
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), |
|
|
|
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), |
|
|
|
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3, requires_grad=True), |
|
|
|
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6, requires_grad=True), |
|
|
|
'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), |
|
|
|
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), |
|
|
|
'surf_pos': torch.randn(batch_size, max_face, 6), |
|
|
|
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), |
|
|
|
'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点 |
|
|
|
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3, requires_grad=True), |
|
|
|
'surf_pos': torch.randn(batch_size, max_face, 6, requires_grad=True), |
|
|
|
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3, requires_grad=True), |
|
|
|
'query_points': torch.randn(batch_size, 1000, 3, requires_grad=True) # 1000个查询点 |
|
|
|
} |
|
|
|
|
|
|
|
# 打印输入数据形状 |
|
|
|