Browse Source

todo: 配置修改

main
mckay 7 months ago
parent
commit
94678c5c17
  1. 21
      brep2sdf/config/default_config.py
  2. 4
      brep2sdf/train.py

21
brep2sdf/config/default_config.py

@ -4,18 +4,18 @@ from typing import Tuple, Optional
@dataclass @dataclass
class ModelConfig: class ModelConfig:
"""模型相关配置""" """模型相关配置"""
brep_feature_dim: int = 32 brep_feature_dim: int = 16
use_cf: bool = True use_cf: bool = True
embed_dim: int = 384 # 3 的 倍数 embed_dim: int = 192 # 3 的 倍数
latent_dim: int = 64 latent_dim: int = 16
# 点云采样配置 # 点云采样配置
num_surf_points: int = 16 # 每个面采样点数 num_surf_points: int = 8 # 每个面采样点数
num_edge_points: int = 4 # 每条边采样点数 num_edge_points: int = 2 # 每条边采样点数
# Transformer相关配置 # Transformer相关配置
num_transformer_layers: int = 6 num_transformer_layers: int = 4
num_attention_heads: int = 8 num_attention_heads: int = 6
transformer_dim_feedforward: int = 512 transformer_dim_feedforward: int = 512
transformer_dropout: float = 0.1 transformer_dropout: float = 0.1
@ -26,8 +26,9 @@ class ModelConfig:
@dataclass @dataclass
class DataConfig: class DataConfig:
"""数据相关配置""" """数据相关配置"""
max_face: int = 64 max_face: int = 8
max_edge: int = 64 max_edge: int = 16
num_query_points: int = 4096 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样
bbox_scaled: float = 1.0 bbox_scaled: float = 1.0
# 数据路径 # 数据路径
@ -50,7 +51,7 @@ class DataConfig:
class TrainConfig: class TrainConfig:
"""训练相关配置""" """训练相关配置"""
# 基本训练参数 # 基本训练参数
batch_size: int = 8 batch_size: int = 1
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 100 num_epochs: int = 100
learning_rate: float = 1e-4 learning_rate: float = 1e-4

4
brep2sdf/train.py

@ -72,14 +72,14 @@ class Trainer:
batch_size=config.train.batch_size, batch_size=config.train.batch_size,
shuffle=True, shuffle=True,
num_workers=config.train.num_workers, num_workers=config.train.num_workers,
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 pin_memory=False #use_pin_memory # 根据设备类型设置,是否将数据固定在内存中
) )
self.val_loader = DataLoader( self.val_loader = DataLoader(
self.val_dataset, self.val_dataset,
batch_size=config.train.batch_size, batch_size=config.train.batch_size,
shuffle=False, shuffle=False,
num_workers=config.train.num_workers, num_workers=config.train.num_workers,
pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 pin_memory=False #use_pin_memory # 根据设备类型设置,是否将数据固定在内存中
) )
# 初始化模型 # 初始化模型

Loading…
Cancel
Save