diff --git a/code/conversion/train.py b/code/conversion/train.py index e261604..712cd69 100644 --- a/code/conversion/train.py +++ b/code/conversion/train.py @@ -88,6 +88,7 @@ class NHREPNet_Training: except Exception as e: logger.error(f"训练过程中发生错误: {str(e)}") break + self.tracing() def train_one_epoch(self, epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize): #logger.info(f"Epoch {epoch}/{self.nepochs} 开始") @@ -211,10 +212,12 @@ class NHREPNet_Training: self.checkpoints_path = os.path.join("../exps/single_shape",name_prefix, "checkpoints") self.ModelParameters_path = os.path.join(self.checkpoints_path, "ModelParameters") self.OptimizerParameters_path = os.path.join(self.checkpoints_path, "OptimizerParameters") + self.TorchScript_path = os.path.join(self.checkpoints_path, "TorchScript") # 创建目录 os.makedirs(self.ModelParameters_path, exist_ok=True) os.makedirs(self.OptimizerParameters_path, exist_ok=True) + os.makedirs(self.TorchScript_path, exist_ok=True) def save_checkpoints(self, epoch): torch.save( @@ -229,6 +232,21 @@ class NHREPNet_Training: torch.save( {"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()}, os.path.join(self.OptimizerParameters_path, "latest.pth")) + + def tracing(self): + csg_tree, flag_convex = self.dataset.get_csg_tree() + network = get_class(self.conf.get_string('train.network_class'))( + d_in=self.d_in, + n_branch=int(torch.max(self.feature_mask).item()), + csg_tree=csg_tree, + flag_convex=flag_convex, + **self.conf.get_config('network.inputs') + ).to(self.device) + #trace + example = torch.rand(224,3).to(self.device) + traced_script_module = torch.jit.trace(network, example) + traced_script_module.save(os.path.join(self.TorchScript_path, "model_h.pt")) + if __name__ == "__main__": name_prefix = 'broken_bullet_50k'