From cabc98090e4dfe35948166e87493f91b2ea32c3e Mon Sep 17 00:00:00 2001 From: mckay Date: Sun, 23 Mar 2025 15:43:51 +0800 Subject: [PATCH] BREAKPOINT: first attremp based on brepgen encoder and deepsef decoder --- brep2sdf/config/default_config.py | 16 ++++++++-------- brep2sdf/data/data.py | 2 ++ brep2sdf/networks/loss.py | 2 ++ brep2sdf/train.py | 2 ++ 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index 2c19bde..ac9aab0 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -7,11 +7,11 @@ class ModelConfig: brep_feature_dim: int = 16 use_cf: bool = True embed_dim: int = 768 # 3 的 倍数 - latent_dim: int = 16 + latent_dim: int = 32 # 点云采样配置 - num_surf_points: int = 8 # 每个面采样点数 - num_edge_points: int = 2 # 每条边采样点数 + num_surf_points: int = 16 # 每个面采样点数 + num_edge_points: int = 4 # 每条边采样点数 # Transformer相关配置 num_transformer_layers: int = 4 @@ -27,8 +27,8 @@ class ModelConfig: @dataclass class DataConfig: """数据相关配置""" - max_face: int = 32 - max_edge: int = 128 + max_face: int = 16 + max_edge: int = 64 num_query_points: int = 32*32*32 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样 bbox_scaled: float = 1.0 @@ -57,9 +57,9 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 100 - learning_rate: float = 1 - min_lr: float = 1e-1 + num_epochs: int = 500 + learning_rate: float = 0.01 + min_lr: float = 1e-5 weight_decay: float = 0.01 # 梯度和损失相关 diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index d970cdf..3d8e7bb 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -88,6 +88,8 @@ class BRepSDFDataset(Dataset): idx for idx in range(len(self.brep_data_list)) if (self._get_brep_face_and_edge(self.brep_data_list[idx]) <= (self.max_face, self.max_edge)) ] + + #filtered_indices = filtered_indices[0:8] # TODO rm # Use filtered_indices to update brep_data_list and sdf_data_list self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices] diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py index 540488c..034ec62 100644 --- a/brep2sdf/networks/loss.py +++ b/brep2sdf/networks/loss.py @@ -16,6 +16,7 @@ class Brep2SDFLoss(nn.Module): self.warmup_epochs = warmup_epochs self.l1_loss = nn.L1Loss(reduction='mean') + self.mse_loss = nn.MSELoss(reduction='mean') def forward(self, pred_sdf, gt_sdf, points=None, epoch=None): """计算SDF预测的损失 @@ -33,6 +34,7 @@ class Brep2SDFLoss(nn.Module): # 2. 计算基础L1损失 l1_loss = self.l1_loss(pred_sdf, gt_sdf) + #mse_loss = self.mse_loss(pred_sdf, gt_sdf) # 3. 计算梯度损失(如果提供了points) grad_loss = 0 diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 2a65832..f93472e 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -29,6 +29,8 @@ def main(): init_timeout=180, # 增加超时时间 _disable_stats=True, # 禁用统计 _disable_meta=True, # 禁用元数据 + #http_proxy='https://sub.toline.link/s/fch9w', + #https_proxy='https://sub.toline.link/s/fch9w' ), mode="offline" # 使用离线模式 )