Browse Source

fix: sdf_xyz和sdf_value 分开加载

main
mckay 4 months ago
parent
commit
a586f2e534
  1. 9
      brep2sdf/data/data.py
  2. 32
      brep2sdf/train.py

9
brep2sdf/data/data.py

@ -99,7 +99,11 @@ class BRepSDFDataset(Dataset):
max_edge=self.max_edge,
bbox_scaled=self.bbox_scaled
)
# 打印数据形状
logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:")
for value in brep_features:
if isinstance(value, torch.Tensor):
logger.debug(f" {value.shape}")
# 检查返回值的类型和数量
if not isinstance(brep_features, tuple):
logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple")
@ -127,7 +131,8 @@ class BRepSDFDataset(Dataset):
'surf_ncs': surf_ncs, # [max_face, 100, 3]
'surf_pos': surf_pos, # [max_face, 6]
'vertex_pos': vertex_pos, # [max_face, max_edge, 6]
'sdf': sdf_data # [N, 4]
'points': sdf_data[:, :3], # [num_queries, 3] 所有点的xyz坐标
'sdf': sdf_data[:, 3:] # [num_queries, 1] 所有点的sdf值
}
except Exception as e:

32
brep2sdf/train.py

@ -118,21 +118,23 @@ class Trainer:
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
sdf = batch['sdf'].to(self.device)
edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(
surf_ncs, edge_ncs,
surf_pos, edge_pos,
vertex_pos,
sdf[:, :3] # 只使用点坐标,不包括SDF值
surf_ncs=surf_ncs, edge_ncs=edge_ncs,
surf_pos=surf_pos, edge_pos=edge_pos,
vertex_pos=vertex_pos, edge_mask=edge_mask,
query_points=points # 只使用点坐标,不包括SDF值
)
# 计算损失
loss = sdf_loss(
pred_sdf,
sdf[:, 3], # 使用SDF值
sdf[:, :3], # 使用点坐标
gt_sdf,
points,
grad_weight=self.config.train.grad_weight
)
@ -175,21 +177,23 @@ class Trainer:
surf_pos = batch['surf_pos'].to(self.device)
edge_pos = batch['edge_pos'].to(self.device)
vertex_pos = batch['vertex_pos'].to(self.device)
sdf = batch['sdf'].to(self.device)
edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播
pred_sdf = self.model(
surf_ncs, edge_ncs,
surf_pos, edge_pos,
vertex_pos,
sdf[:, :3]
surf_ncs=surf_ncs, edge_ncs=edge_ncs,
surf_pos=surf_pos, edge_pos=edge_pos,
vertex_pos=vertex_pos, edge_mask=edge_mask,
query_points=points
)
# 计算损失
loss = sdf_loss(
pred_sdf,
sdf[:, 3],
sdf[:, :3],
gt_sdf,
points,
grad_weight=self.config.train.grad_weight
)

Loading…
Cancel
Save