From a586f2e5345cb6af968912ceece966e33d61020e Mon Sep 17 00:00:00 2001 From: mckay Date: Tue, 19 Nov 2024 03:50:42 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20sdf=5Fxyz=E5=92=8Csdf=5Fvalue=20?= =?UTF-8?q?=E5=88=86=E5=BC=80=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/data/data.py | 9 +++++++-- brep2sdf/train.py | 32 ++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 8f7798e..0e1b322 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -99,7 +99,11 @@ class BRepSDFDataset(Dataset): max_edge=self.max_edge, bbox_scaled=self.bbox_scaled ) - + # 打印数据形状 + logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:") + for value in brep_features: + if isinstance(value, torch.Tensor): + logger.debug(f" {value.shape}") # 检查返回值的类型和数量 if not isinstance(brep_features, tuple): logger.error(f"process_brep_data returned {type(brep_features)}, expected tuple") @@ -127,7 +131,8 @@ class BRepSDFDataset(Dataset): 'surf_ncs': surf_ncs, # [max_face, 100, 3] 'surf_pos': surf_pos, # [max_face, 6] 'vertex_pos': vertex_pos, # [max_face, max_edge, 6] - 'sdf': sdf_data # [N, 4] + 'points': sdf_data[:, :3], # [num_queries, 3] 所有点的xyz坐标 + 'sdf': sdf_data[:, 3:] # [num_queries, 1] 所有点的sdf值 } except Exception as e: diff --git a/brep2sdf/train.py b/brep2sdf/train.py index ffcb00f..924ccb6 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -118,21 +118,23 @@ class Trainer: surf_pos = batch['surf_pos'].to(self.device) edge_pos = batch['edge_pos'].to(self.device) vertex_pos = batch['vertex_pos'].to(self.device) - sdf = batch['sdf'].to(self.device) + edge_mask = batch['edge_mask'].to(self.device) + points = batch['points'].to(self.device) + gt_sdf = batch['sdf'].to(self.device) # 前向传播 pred_sdf = self.model( - surf_ncs, edge_ncs, - surf_pos, edge_pos, - vertex_pos, - sdf[:, :3] # 只使用点坐标,不包括SDF值 + surf_ncs=surf_ncs, edge_ncs=edge_ncs, + surf_pos=surf_pos, edge_pos=edge_pos, + vertex_pos=vertex_pos, edge_mask=edge_mask, + query_points=points # 只使用点坐标,不包括SDF值 ) # 计算损失 loss = sdf_loss( pred_sdf, - sdf[:, 3], # 使用SDF值 - sdf[:, :3], # 使用点坐标 + gt_sdf, + points, grad_weight=self.config.train.grad_weight ) @@ -175,21 +177,23 @@ class Trainer: surf_pos = batch['surf_pos'].to(self.device) edge_pos = batch['edge_pos'].to(self.device) vertex_pos = batch['vertex_pos'].to(self.device) - sdf = batch['sdf'].to(self.device) + edge_mask = batch['edge_mask'].to(self.device) + points = batch['points'].to(self.device) + gt_sdf = batch['sdf'].to(self.device) # 前向传播 pred_sdf = self.model( - surf_ncs, edge_ncs, - surf_pos, edge_pos, - vertex_pos, - sdf[:, :3] + surf_ncs=surf_ncs, edge_ncs=edge_ncs, + surf_pos=surf_pos, edge_pos=edge_pos, + vertex_pos=vertex_pos, edge_mask=edge_mask, + query_points=points ) # 计算损失 loss = sdf_loss( pred_sdf, - sdf[:, 3], - sdf[:, :3], + gt_sdf, + points, grad_weight=self.config.train.grad_weight )