
todo: config changes

main
mckay committed 4 months ago
commit 94678c5c17
  1. brep2sdf/config/default_config.py (21 changes)
  2. brep2sdf/train.py (4 changes)

brep2sdf/config/default_config.py (21 changes)

@@ -4,18 +4,18 @@ from typing import Tuple, Optional
 @dataclass
 class ModelConfig:
     """Model-related configuration"""
-    brep_feature_dim: int = 32
+    brep_feature_dim: int = 16
     use_cf: bool = True
-    embed_dim: int = 384  # multiple of 3
-    latent_dim: int = 64
+    embed_dim: int = 192  # multiple of 3
+    latent_dim: int = 16
     # Point-cloud sampling settings
-    num_surf_points: int = 16  # sample points per face
-    num_edge_points: int = 4   # sample points per edge
+    num_surf_points: int = 8  # sample points per face
+    num_edge_points: int = 2  # sample points per edge
     # Transformer settings
-    num_transformer_layers: int = 6
-    num_attention_heads: int = 8
+    num_transformer_layers: int = 4
+    num_attention_heads: int = 6
     transformer_dim_feedforward: int = 512
     transformer_dropout: float = 0.1
@@ -26,8 +26,9 @@ class ModelConfig:
 @dataclass
 class DataConfig:
     """Data-related configuration"""
-    max_face: int = 64
-    max_edge: int = 64
+    max_face: int = 8
+    max_edge: int = 16
+    num_query_points: int = 4096  # cap on query points; the full SDF sample set is 128*128*128 and is randomly subsampled at data-load time
     bbox_scaled: float = 1.0
     # Data paths
@@ -50,7 +51,7 @@ class DataConfig:
 class TrainConfig:
     """Training-related configuration"""
     # Basic training parameters
-    batch_size: int = 8
+    batch_size: int = 1
     num_workers: int = 4
     num_epochs: int = 100
     learning_rate: float = 1e-4
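
Taken together, the new defaults roughly halve the feature and embedding widths, halve the per-face and per-edge sample counts, shrink max_face/max_edge to 8/16, and drop batch_size to 1, which reads as a memory or debugging shrink, although the commit message only says "todo: config changes". Below is a minimal sketch of the slimmed-down ModelConfig defaults; the check_model_config helper is hypothetical and not part of this repository, it only illustrates the two constraints implied by the config (embed_dim a multiple of 3 per the inline comment, and divisible by the attention head count so multi-head attention can split it evenly):

from dataclasses import dataclass

@dataclass
class ModelConfig:
    # New defaults introduced by this commit
    brep_feature_dim: int = 16
    embed_dim: int = 192        # kept a multiple of 3, as the inline comment requires
    latent_dim: int = 16
    num_surf_points: int = 8
    num_edge_points: int = 2
    num_transformer_layers: int = 4
    num_attention_heads: int = 6

def check_model_config(cfg: ModelConfig) -> None:
    # Hypothetical sanity check (not in the repo): the embedding width must be a
    # multiple of 3 and must split evenly across the attention heads.
    assert cfg.embed_dim % 3 == 0, "embed_dim must be a multiple of 3"
    assert cfg.embed_dim % cfg.num_attention_heads == 0, (
        "embed_dim must be divisible by num_attention_heads"
    )

check_model_config(ModelConfig())  # 192 is divisible by both 3 and 6, so the new defaults pass

The previous defaults (embed_dim=384 with 8 heads) satisfied the same constraints, so the change looks like a capacity reduction rather than a constraint fix.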

brep2sdf/train.py (4 changes)

@@ -72,14 +72,14 @@ class Trainer:
             batch_size=config.train.batch_size,
             shuffle=True,
             num_workers=config.train.num_workers,
-            pin_memory=use_pin_memory  # set per device type: whether to pin data in memory
+            pin_memory=False  # use_pin_memory  # set per device type: whether to pin data in memory
         )
         self.val_loader = DataLoader(
             self.val_dataset,
             batch_size=config.train.batch_size,
             shuffle=False,
             num_workers=config.train.num_workers,
-            pin_memory=use_pin_memory  # set per device type: whether to pin data in memory
+            pin_memory=False  # use_pin_memory  # set per device type: whether to pin data in memory
         )
         # Initialize the model
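
The only functional change in train.py is that pin_memory is hard-coded to False for both loaders instead of following use_pin_memory. In PyTorch, pin_memory=True makes the DataLoader copy batches into page-locked host memory, which enables faster (and, with non_blocking=True, asynchronous) host-to-GPU transfers at the cost of pinned RAM; disabling it is a common workaround when pinning causes errors or is pointless on CPU-only runs. A minimal sketch of the pattern, assuming use_pin_memory was previously derived from CUDA availability (that derivation is an assumption; it is not visible in this diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the project's train/val datasets.
dataset = TensorDataset(torch.randn(32, 3), torch.randn(32, 1))

# Assumed original logic: pin only when a CUDA device will consume the batches.
use_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    dataset,
    batch_size=1,        # matches the new TrainConfig.batch_size
    shuffle=True,
    num_workers=0,       # keep the sketch single-process
    pin_memory=False,    # what this commit forces; previously use_pin_memory
)

for x, y in train_loader:
    # With pinning enabled one would typically move batches via
    # x.to(device, non_blocking=True); with pin_memory=False the copy is synchronous.
    pass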
