Browse Source

feat: Add TorchScript model tracing and export functionality

- Implemented `tracing()` method in training pipeline to export model as TorchScript
- Created dedicated TorchScript directory in checkpoints path
- Added model tracing with example input tensor
- Saved traced model as `model_h.pt` in the TorchScript directory
NH-Rep
mckay 1 week ago
parent
commit
03b2858e81
  1. 18
      code/conversion/train.py

18
code/conversion/train.py

@ -88,6 +88,7 @@ class NHREPNet_Training:
except Exception as e:
logger.error(f"训练过程中发生错误: {str(e)}")
break
self.tracing()
def train_one_epoch(self, epoch, patch_id, patch_id_n, n_patch_batch, n_patch_last, n_branch, n_batchsize):
#logger.info(f"Epoch {epoch}/{self.nepochs} 开始")
@ -211,10 +212,12 @@ class NHREPNet_Training:
self.checkpoints_path = os.path.join("../exps/single_shape",name_prefix, "checkpoints")
self.ModelParameters_path = os.path.join(self.checkpoints_path, "ModelParameters")
self.OptimizerParameters_path = os.path.join(self.checkpoints_path, "OptimizerParameters")
self.TorchScript_path = os.path.join(self.checkpoints_path, "TorchScript")
# 创建目录
os.makedirs(self.ModelParameters_path, exist_ok=True)
os.makedirs(self.OptimizerParameters_path, exist_ok=True)
os.makedirs(self.TorchScript_path, exist_ok=True)
def save_checkpoints(self, epoch):
torch.save(
@ -230,6 +233,21 @@ class NHREPNet_Training:
{"epoch": epoch, "optimizer_state_dict": self.scheduler.optimizer.state_dict()},
os.path.join(self.OptimizerParameters_path, "latest.pth"))
def tracing(self):
csg_tree, flag_convex = self.dataset.get_csg_tree()
network = get_class(self.conf.get_string('train.network_class'))(
d_in=self.d_in,
n_branch=int(torch.max(self.feature_mask).item()),
csg_tree=csg_tree,
flag_convex=flag_convex,
**self.conf.get_config('network.inputs')
).to(self.device)
#trace
example = torch.rand(224,3).to(self.device)
traced_script_module = torch.jit.trace(network, example)
traced_script_module.save(os.path.join(self.TorchScript_path, "model_h.pt"))
if __name__ == "__main__":
name_prefix = 'broken_bullet_50k'
conf = ConfigFactory.parse_file('./conversion/setup.conf')

Loading…
Cancel
Save