Browse Source

BREAKPOINT: first attremp based on brepgen encoder and deepsef decoder

final
mckay 3 months ago
parent
commit
cabc98090e
  1. 16
      brep2sdf/config/default_config.py
  2. 2
      brep2sdf/data/data.py
  3. 2
      brep2sdf/networks/loss.py
  4. 2
      brep2sdf/train.py

16
brep2sdf/config/default_config.py

@ -7,11 +7,11 @@ class ModelConfig:
brep_feature_dim: int = 16
use_cf: bool = True
embed_dim: int = 768 # 3 的 倍数
latent_dim: int = 16
latent_dim: int = 32
# 点云采样配置
num_surf_points: int = 8 # 每个面采样点数
num_edge_points: int = 2 # 每条边采样点数
num_surf_points: int = 16 # 每个面采样点数
num_edge_points: int = 4 # 每条边采样点数
# Transformer相关配置
num_transformer_layers: int = 4
@ -27,8 +27,8 @@ class ModelConfig:
@dataclass
class DataConfig:
"""数据相关配置"""
max_face: int = 32
max_edge: int = 128
max_face: int = 16
max_edge: int = 64
num_query_points: int = 32*32*32 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样
bbox_scaled: float = 1.0
@ -57,9 +57,9 @@ class TrainConfig:
# 基本训练参数
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 100
learning_rate: float = 1
min_lr: float = 1e-1
num_epochs: int = 500
learning_rate: float = 0.01
min_lr: float = 1e-5
weight_decay: float = 0.01
# 梯度和损失相关

2
brep2sdf/data/data.py

@ -88,6 +88,8 @@ class BRepSDFDataset(Dataset):
idx for idx in range(len(self.brep_data_list))
if (self._get_brep_face_and_edge(self.brep_data_list[idx]) <= (self.max_face, self.max_edge))
]
#filtered_indices = filtered_indices[0:8] # TODO rm
# Use filtered_indices to update brep_data_list and sdf_data_list
self.brep_data_list = [self.brep_data_list[idx] for idx in filtered_indices]

2
brep2sdf/networks/loss.py

@ -16,6 +16,7 @@ class Brep2SDFLoss(nn.Module):
self.warmup_epochs = warmup_epochs
self.l1_loss = nn.L1Loss(reduction='mean')
self.mse_loss = nn.MSELoss(reduction='mean')
def forward(self, pred_sdf, gt_sdf, points=None, epoch=None):
"""计算SDF预测的损失
@ -33,6 +34,7 @@ class Brep2SDFLoss(nn.Module):
# 2. 计算基础L1损失
l1_loss = self.l1_loss(pred_sdf, gt_sdf)
#mse_loss = self.mse_loss(pred_sdf, gt_sdf)
# 3. 计算梯度损失(如果提供了points)
grad_loss = 0

2
brep2sdf/train.py

@ -29,6 +29,8 @@ def main():
init_timeout=180, # 增加超时时间
_disable_stats=True, # 禁用统计
_disable_meta=True, # 禁用元数据
#http_proxy='https://sub.toline.link/s/fch9w',
#https_proxy='https://sub.toline.link/s/fch9w'
),
mode="offline" # 使用离线模式
)

Loading…
Cancel
Save