diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 924ccb6..909f941 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -83,12 +83,7 @@ class Trainer: ) # 初始化模型 - self.model = BRepToSDF( - brep_feature_dim=config.model.brep_feature_dim, - use_cf=config.model.use_cf, - embed_dim=config.model.embed_dim, - latent_dim=config.model.latent_dim - ).to(self.device) + self.model = BRepToSDF(config).to(self.device) # 初始化优化器 self.optimizer = optim.AdamW(