diff --git a/brep2sdf/data/pre_process_by_mesh.py b/brep2sdf/data/pre_process_by_mesh.py index 1fe95af..3a0f063 100644 --- a/brep2sdf/data/pre_process_by_mesh.py +++ b/brep2sdf/data/pre_process_by_mesh.py @@ -433,7 +433,7 @@ def parse_solid(step_path,sample_normal_vector=False,sample_sdf_points=False): data['sampled_points_normals_sdf'] = sample_sdf_points_and_normals( trimesh_mesh_ncs=trimesh_mesh_ncs, surf_bbox_ncs=data['surf_bbox_ncs'], - num_sdf_samples=4096, # <-- 传递固定数量 + num_sdf_samples=50000, # <-- 传递固定数量 sdf_sampling_std_dev=0.0001 ) else: diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 14ac897..7133e92 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -39,7 +39,7 @@ parser.add_argument( help='强制重新进行数据预处理,忽略缓存或已有结果' ) parser.add_argument( - '--resume-checkpoint-path', + '--resume-checkpoint-path', '-r', type=str, default=None, help='从指定的checkpoint文件继续训练' @@ -99,6 +99,7 @@ class Trainer: ) # 合并表面点数据和采样点数据 self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0) + self.train_surf_ncs = sampled_sdf_data else: self.sdf_data = surface_sdf_data print_data_distribution(self.sdf_data) @@ -403,7 +404,7 @@ class Trainer: logger.debug(f"Patch [{patch_id:2d}] | train pnt number {points.shape[0]}") mnfld_pnts = points[:, 0:3] - gt_sdf = torch.zeros(mnfld_pnts.shape[0], device=self.device) + gt_sdf = points[:, -1] normals = points[:, 3:6]