From 94678c5c174aa55722d57bdb71b0ae5a70e2ed30 Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 22 Nov 2024 00:40:49 +0800 Subject: [PATCH] =?UTF-8?q?todo:=20=E9=85=8D=E7=BD=AE=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/config/default_config.py | 21 +++++++++++---------- brep2sdf/train.py | 4 ++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index f5a0132..f2a392e 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -4,18 +4,18 @@ from typing import Tuple, Optional @dataclass class ModelConfig: """模型相关配置""" - brep_feature_dim: int = 32 + brep_feature_dim: int = 16 use_cf: bool = True - embed_dim: int = 384 # 3 的 倍数 - latent_dim: int = 64 + embed_dim: int = 192 # 3 的 倍数 + latent_dim: int = 16 # 点云采样配置 - num_surf_points: int = 16 # 每个面采样点数 - num_edge_points: int = 4 # 每条边采样点数 + num_surf_points: int = 8 # 每个面采样点数 + num_edge_points: int = 2 # 每条边采样点数 # Transformer相关配置 - num_transformer_layers: int = 6 - num_attention_heads: int = 8 + num_transformer_layers: int = 4 + num_attention_heads: int = 6 transformer_dim_feedforward: int = 512 transformer_dropout: float = 0.1 @@ -26,8 +26,9 @@ class ModelConfig: @dataclass class DataConfig: """数据相关配置""" - max_face: int = 64 - max_edge: int = 64 + max_face: int = 8 + max_edge: int = 16 + num_query_points: int = 4096 # 限制查询点数量,sdf 采样点数 本来是 128*128*128 ,在data load时随机采样 bbox_scaled: float = 1.0 # 数据路径 @@ -50,7 +51,7 @@ class DataConfig: class TrainConfig: """训练相关配置""" # 基本训练参数 - batch_size: int = 8 + batch_size: int = 1 num_workers: int = 4 num_epochs: int = 100 learning_rate: float = 1e-4 diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 909f941..1a85692 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -72,14 +72,14 @@ class Trainer: batch_size=config.train.batch_size, shuffle=True, num_workers=config.train.num_workers, - pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 + pin_memory=False #use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 ) self.val_loader = DataLoader( self.val_dataset, batch_size=config.train.batch_size, shuffle=False, num_workers=config.train.num_workers, - pin_memory=use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 + pin_memory=False #use_pin_memory # 根据设备类型设置,是否将数据固定在内存中 ) # 初始化模型