From 21f8369e2022fca2ff29b7f201cf0ab57d9e1354 Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 10 Apr 2025 12:47:07 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9torch.jit.script=EF=BC=8Cload?= =?UTF-8?q?=E5=92=8Csave=E6=96=B9=E4=BE=BF=E3=80=82=E4=BD=86=E6=98=AF?= =?UTF-8?q?=E7=89=BA=E7=89=B2=E4=BA=86=E9=9A=8F=E6=9C=BA=E8=AE=BF=E9=97=AE?= =?UTF-8?q?=EF=BC=8C=E6=80=A7=E8=83=BD=E9=99=8D=E4=BA=86=E5=BE=88=E5=A4=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 56 +++++++++++++++++++++++------------- brep2sdf/networks/octree.py | 44 +++++++++++++++++++++------- brep2sdf/train.py | 16 ++++++----- 3 files changed, 78 insertions(+), 38 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 02008d8..8c3d4d1 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -73,34 +73,51 @@ class Encoder(nn.Module): self.octree = octree self.feature_dim = feature_dim + # 初始化叶子节点参数 + self._leaf_parameters = nn.ParameterList() # 使用 ParameterList 存储参数 + self.param_key_to_index: Dict[str, int] = {} # 字典映射:param_key -> index # 为所有叶子节点注册可学习参数 self._init_parameters() def _init_parameters(self): """为所有叶子节点初始化特征参数""" - # 使用字典保存所有参数,避免动态属性 - self._leaf_parameters = nn.ParameterDict() - - # 递归遍历树结构 - def _register_params(node, path=""): - #logger.debug(node.is_leaf()) + # 使用栈模拟递归遍历(避免递归) + stack = [(self.octree, "root")] # (当前节点, 当前路径) + param_index = 0 # 参数索引计数器 + + while stack: + node, path = stack.pop() + if node.is_leaf(): + # 如果是叶子节点,初始化参数 param_name = f"leaf_{path}" - self._leaf_parameters[param_name] = nn.Parameter( - torch.randn(8, self.feature_dim) # 8个顶点的特征 - ) + self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim))) # 8个顶点的特征 + self.param_key_to_index[param_name] = param_index # 记录索引 node.set_param_key(param_name) # 为节点存储参数键 - #logger.debug(param_name) - #logger.debug(node.param_key) + param_index += 1 else: + # 如果不是叶子节点,继续遍历子节点 for i, child in enumerate(node.child_nodes): - _register_params(child, f"{path}_{i}") + if child is not None: + stack.append((child, f"{path}_{i}")) + + def get_leaf_parameter(self, param_key: str) -> torch.Tensor: + """ + 获取叶子节点的特征参数 + :param param_key: 叶子节点的参数键 + :return: 对应的参数 + """ + if param_key not in self.param_key_to_index: + raise KeyError(f"Invalid param_key: {param_key}") - _register_params(self.octree, "root") - - def get_leaf_parameter(self, node): - """获取叶子节点的特征参数""" - return self._leaf_parameters[node.param_key] + target_index = self.param_key_to_index[param_key] + + # 使用枚举代替动态索引 + for index, param in enumerate(self._leaf_parameters): + if index == target_index: + return param + + raise IndexError(f"Index {target_index} not found in ParameterList") def forward(self, query_points: torch.Tensor) -> torch.Tensor: """ @@ -116,12 +133,11 @@ class Encoder(nn.Module): for i in range(batch_size): # 1. 在八叉树中查找包含该点的叶子节点 - leaf_node = self.octree.find_leaf(query_points[i]) + bbox, param_key, _ = self.octree.find_leaf(query_points[i]) #logger.debug(leaf_node.param_key) # 2. 获取该节点的特征参数 - bbox = leaf_node.bbox - node_features = self.get_leaf_parameter(leaf_node) + node_features = self.get_leaf_parameter(param_key) # 3. 使用三线性插值计算特征 # (这里需要实现你的插值逻辑) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index c0d2444..4b2a308 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -1,6 +1,6 @@ -from typing import Tuple, List +from typing import Tuple, List, cast, Dict, Any, Tuple import torch import torch.nn as nn @@ -35,9 +35,9 @@ class OctreeNode(nn.Module): super().__init__() self.bbox = bbox # 节点的边界框 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 - self.child_nodes: List['OctreeNode'] = [] # 子节点列表 + self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表 self.face_indices = face_indices - self.param_key = None + self.param_key = "" #self.patch_feature_volume = None # 补丁特征体积,only leaf has self._is_leaf = True #print(f"box shape: {self.bbox.shape}") @@ -103,7 +103,6 @@ class OctreeNode(nn.Module): ]) # 为每个子包围盒创建子节点,并分配相交的面 - self.child_nodes = [] for bbox in child_bboxes: # 找到与子包围盒相交的面 intersecting_faces = [] @@ -142,14 +141,35 @@ class OctreeNode(nn.Module): # 使用布尔比较结果计算索引 index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() - return index.unsqueeze(0) + return index.item() - def find_leaf(self, query_point:torch.Tensor): - # 从根节点开始递归查找包含该点的叶子节点 + def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]: + """ + 查找包含给定点的叶子节点,并返回其信息(以元组形式) + :param query_point: 待查找的点 + :return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf) + """ + # 如果当前节点是叶子节点,返回其信息 if self._is_leaf: - return self - else: - index = self.get_child_index(query_point) + #logger.info(f"{self.bbox}, {self.param_key}, {True}") + return (self.bbox, self.param_key, True) + + # 计算查询点所在的子节点索引 + index = self.get_child_index(query_point) + + # 遍历子节点列表,找到对应的子节点 + for i, child_node in enumerate(self.child_nodes): + if i == index and child_node is not None: + # 递归调用子节点的 find_leaf 方法 + result = child_node.find_leaf(query_point) + + # 确保返回值是一个元组 + assert isinstance(result, tuple), f"Unexpected return type: {type(result)}" + return result + + # 如果找不到有效的子节点,抛出异常 + raise IndexError(f"Invalid child node index: {index}") + ''' try: # 直接访问子节点,不进行显式检查 return self.child_nodes[index].find_leaf(query_point) @@ -162,7 +182,9 @@ class OctreeNode(nn.Module): f"Depth info: {self.max_depth}" ) raise e + ''' + ''' def get_feature_vector(self, query_point:torch.Tensor): """ 预测给定点的 SDF 值 @@ -188,7 +210,7 @@ class OctreeNode(nn.Module): f"Depth info: {self.max_depth}" ) raise e - + ''' diff --git a/brep2sdf/train.py b/brep2sdf/train.py index d6e71b8..474f856 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -79,7 +79,11 @@ class Trainer: self.base_name = self.model_name + ".xyz" data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name) if os.path.exists(data_path) and not args.force_reprocess: - self.data = load_brep_file(data_path) + try: + self.data = load_brep_file(data_path) + except Exception as e: + logger.error(f"fail to load {data_path}, {str(e)}") + raise e if args.use_normal and self.data.get("surf_pnt_normals", None) is None: self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal) else: @@ -203,7 +207,7 @@ class Trainer: total_loss += loss.item() # 记录训练进度 - logger.info(f'Train Epoch: {epoch}]\t' + logger.info(f'Train Epoch: {epoch:4d}]\t' f'Loss: {loss.item():.6f}') return total_loss @@ -274,11 +278,9 @@ class Trainer: def _tracing_model(self): """保存模型""" self.model.eval() - example_input = torch.rand(10, 3, device=self.device) - # 3. 在no_grad上下文中执行追踪 - with torch.no_grad(): - traced_model = torch.jit.trace(self.model, example_input) - torch.jit.save(traced_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") + # 确保模型中的所有逻辑都兼容 TorchScript + scripted_model = torch.jit.script(self.model) + torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") def _load_checkpoint(self, checkpoint_path): """从检查点恢复训练状态"""