Browse Source

兼容torch.jit.script,load和save方便。但是牺牲了随机访问,性能降了很多

final
mckay 2 months ago
parent
commit
21f8369e20
  1. 52
      brep2sdf/networks/encoder.py
  2. 44
      brep2sdf/networks/octree.py
  3. 16
      brep2sdf/train.py

52
brep2sdf/networks/encoder.py

@ -73,34 +73,51 @@ class Encoder(nn.Module):
self.octree = octree self.octree = octree
self.feature_dim = feature_dim self.feature_dim = feature_dim
# 初始化叶子节点参数
self._leaf_parameters = nn.ParameterList() # 使用 ParameterList 存储参数
self.param_key_to_index: Dict[str, int] = {} # 字典映射:param_key -> index
# 为所有叶子节点注册可学习参数 # 为所有叶子节点注册可学习参数
self._init_parameters() self._init_parameters()
def _init_parameters(self): def _init_parameters(self):
"""为所有叶子节点初始化特征参数""" """为所有叶子节点初始化特征参数"""
# 使用字典保存所有参数,避免动态属性 # 使用栈模拟递归遍历(避免递归)
self._leaf_parameters = nn.ParameterDict() stack = [(self.octree, "root")] # (当前节点, 当前路径)
param_index = 0 # 参数索引计数器
while stack:
node, path = stack.pop()
# 递归遍历树结构
def _register_params(node, path=""):
#logger.debug(node.is_leaf())
if node.is_leaf(): if node.is_leaf():
# 如果是叶子节点,初始化参数
param_name = f"leaf_{path}" param_name = f"leaf_{path}"
self._leaf_parameters[param_name] = nn.Parameter( self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim))) # 8个顶点的特征
torch.randn(8, self.feature_dim) # 8个顶点的特征 self.param_key_to_index[param_name] = param_index # 记录索引
)
node.set_param_key(param_name) # 为节点存储参数键 node.set_param_key(param_name) # 为节点存储参数键
#logger.debug(param_name) param_index += 1
#logger.debug(node.param_key)
else: else:
# 如果不是叶子节点,继续遍历子节点
for i, child in enumerate(node.child_nodes): for i, child in enumerate(node.child_nodes):
_register_params(child, f"{path}_{i}") if child is not None:
stack.append((child, f"{path}_{i}"))
def get_leaf_parameter(self, param_key: str) -> torch.Tensor:
"""
获取叶子节点的特征参数
:param param_key: 叶子节点的参数键
:return: 对应的参数
"""
if param_key not in self.param_key_to_index:
raise KeyError(f"Invalid param_key: {param_key}")
target_index = self.param_key_to_index[param_key]
_register_params(self.octree, "root") # 使用枚举代替动态索引
for index, param in enumerate(self._leaf_parameters):
if index == target_index:
return param
def get_leaf_parameter(self, node): raise IndexError(f"Index {target_index} not found in ParameterList")
"""获取叶子节点的特征参数"""
return self._leaf_parameters[node.param_key]
def forward(self, query_points: torch.Tensor) -> torch.Tensor: def forward(self, query_points: torch.Tensor) -> torch.Tensor:
""" """
@ -116,12 +133,11 @@ class Encoder(nn.Module):
for i in range(batch_size): for i in range(batch_size):
# 1. 在八叉树中查找包含该点的叶子节点 # 1. 在八叉树中查找包含该点的叶子节点
leaf_node = self.octree.find_leaf(query_points[i]) bbox, param_key, _ = self.octree.find_leaf(query_points[i])
#logger.debug(leaf_node.param_key) #logger.debug(leaf_node.param_key)
# 2. 获取该节点的特征参数 # 2. 获取该节点的特征参数
bbox = leaf_node.bbox node_features = self.get_leaf_parameter(param_key)
node_features = self.get_leaf_parameter(leaf_node)
# 3. 使用三线性插值计算特征 # 3. 使用三线性插值计算特征
# (这里需要实现你的插值逻辑) # (这里需要实现你的插值逻辑)

44
brep2sdf/networks/octree.py

@ -1,6 +1,6 @@
from typing import Tuple, List from typing import Tuple, List, cast, Dict, Any, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -35,9 +35,9 @@ class OctreeNode(nn.Module):
super().__init__() super().__init__()
self.bbox = bbox # 节点的边界框 self.bbox = bbox # 节点的边界框
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点
self.child_nodes: List['OctreeNode'] = [] # 子节点列表 self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表
self.face_indices = face_indices self.face_indices = face_indices
self.param_key = None self.param_key = ""
#self.patch_feature_volume = None # 补丁特征体积,only leaf has #self.patch_feature_volume = None # 补丁特征体积,only leaf has
self._is_leaf = True self._is_leaf = True
#print(f"box shape: {self.bbox.shape}") #print(f"box shape: {self.bbox.shape}")
@ -103,7 +103,6 @@ class OctreeNode(nn.Module):
]) ])
# 为每个子包围盒创建子节点,并分配相交的面 # 为每个子包围盒创建子节点,并分配相交的面
self.child_nodes = []
for bbox in child_bboxes: for bbox in child_bboxes:
# 找到与子包围盒相交的面 # 找到与子包围盒相交的面
intersecting_faces = [] intersecting_faces = []
@ -142,14 +141,35 @@ class OctreeNode(nn.Module):
# 使用布尔比较结果计算索引 # 使用布尔比较结果计算索引
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum()
return index.unsqueeze(0) return index.item()
def find_leaf(self, query_point:torch.Tensor): def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]:
# 从根节点开始递归查找包含该点的叶子节点 """
查找包含给定点的叶子节点并返回其信息以元组形式
:param query_point: 待查找的点
:return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf)
"""
# 如果当前节点是叶子节点,返回其信息
if self._is_leaf: if self._is_leaf:
return self #logger.info(f"{self.bbox}, {self.param_key}, {True}")
else: return (self.bbox, self.param_key, True)
index = self.get_child_index(query_point)
# 计算查询点所在的子节点索引
index = self.get_child_index(query_point)
# 遍历子节点列表,找到对应的子节点
for i, child_node in enumerate(self.child_nodes):
if i == index and child_node is not None:
# 递归调用子节点的 find_leaf 方法
result = child_node.find_leaf(query_point)
# 确保返回值是一个元组
assert isinstance(result, tuple), f"Unexpected return type: {type(result)}"
return result
# 如果找不到有效的子节点,抛出异常
raise IndexError(f"Invalid child node index: {index}")
'''
try: try:
# 直接访问子节点,不进行显式检查 # 直接访问子节点,不进行显式检查
return self.child_nodes[index].find_leaf(query_point) return self.child_nodes[index].find_leaf(query_point)
@ -162,7 +182,9 @@ class OctreeNode(nn.Module):
f"Depth info: {self.max_depth}" f"Depth info: {self.max_depth}"
) )
raise e raise e
'''
'''
def get_feature_vector(self, query_point:torch.Tensor): def get_feature_vector(self, query_point:torch.Tensor):
""" """
预测给定点的 SDF 预测给定点的 SDF
@ -188,7 +210,7 @@ class OctreeNode(nn.Module):
f"Depth info: {self.max_depth}" f"Depth info: {self.max_depth}"
) )
raise e raise e
'''

16
brep2sdf/train.py

@ -79,7 +79,11 @@ class Trainer:
self.base_name = self.model_name + ".xyz" self.base_name = self.model_name + ".xyz"
data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name) data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
if os.path.exists(data_path) and not args.force_reprocess: if os.path.exists(data_path) and not args.force_reprocess:
self.data = load_brep_file(data_path) try:
self.data = load_brep_file(data_path)
except Exception as e:
logger.error(f"fail to load {data_path}, {str(e)}")
raise e
if args.use_normal and self.data.get("surf_pnt_normals", None) is None: if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal) self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
else: else:
@ -203,7 +207,7 @@ class Trainer:
total_loss += loss.item() total_loss += loss.item()
# 记录训练进度 # 记录训练进度
logger.info(f'Train Epoch: {epoch}]\t' logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {loss.item():.6f}') f'Loss: {loss.item():.6f}')
return total_loss return total_loss
@ -274,11 +278,9 @@ class Trainer:
def _tracing_model(self): def _tracing_model(self):
"""保存模型""" """保存模型"""
self.model.eval() self.model.eval()
example_input = torch.rand(10, 3, device=self.device) # 确保模型中的所有逻辑都兼容 TorchScript
# 3. 在no_grad上下文中执行追踪 scripted_model = torch.jit.script(self.model)
with torch.no_grad(): torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
traced_model = torch.jit.trace(self.model, example_input)
torch.jit.save(traced_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
def _load_checkpoint(self, checkpoint_path): def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态""" """从检查点恢复训练状态"""

Loading…
Cancel
Save