|
|
@ -4,18 +4,18 @@ from typing import Tuple, Optional |
|
|
|
@dataclass |
|
|
|
class ModelConfig: |
|
|
|
"""模型相关配置""" |
|
|
|
brep_feature_dim: int = 32 |
|
|
|
brep_feature_dim: int = 16 |
|
|
|
use_cf: bool = True |
|
|
|
embed_dim: int = 384 # 3 的 倍数 |
|
|
|
latent_dim: int = 64 |
|
|
|
embed_dim: int = 192 # 3 的 倍数 |
|
|
|
latent_dim: int = 16 |
|
|
|
|
|
|
|
# 点云采样配置 |
|
|
|
num_surf_points: int = 16 # 每个面采样点数 |
|
|
|
num_edge_points: int = 4 # 每条边采样点数 |
|
|
|
num_surf_points: int = 8 # 每个面采样点数 |
|
|
|
num_edge_points: int = 2 # 每条边采样点数 |
|
|
|
|
|
|
|
# Transformer相关配置 |
|
|
|
num_transformer_layers: int = 6 |
|
|
|
num_attention_heads: int = 8 |
|
|
|
num_transformer_layers: int = 4 |
|
|
|
num_attention_heads: int = 6 |
|
|
|
transformer_dim_feedforward: int = 512 |
|
|
|
transformer_dropout: float = 0.1 |
|
|
|
|
|
|
@ -26,8 +26,9 @@ class ModelConfig: |
|
|
|
@dataclass |
|
|
|
class DataConfig: |
|
|
|
"""数据相关配置""" |
|
|
|
max_face: int = 64 |
|
|
|
max_edge: int = 64 |
|
|
|
max_face: int = 8 |
|
|
|
max_edge: int = 16 |
|
|
|
num_query_points: int = 4096 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样 |
|
|
|
bbox_scaled: float = 1.0 |
|
|
|
|
|
|
|
# 数据路径 |
|
|
@ -50,7 +51,7 @@ class DataConfig: |
|
|
|
class TrainConfig: |
|
|
|
"""训练相关配置""" |
|
|
|
# 基本训练参数 |
|
|
|
batch_size: int = 8 |
|
|
|
batch_size: int = 1 |
|
|
|
num_workers: int = 4 |
|
|
|
num_epochs: int = 100 |
|
|
|
learning_rate: float = 1e-4 |
|
|
|