Browse Source

兼容torch.jit.script,load和save方便。但是牺牲了随机访问,性能降了很多

final
mckay 2 months ago
parent
commit
21f8369e20
  1. 52
      brep2sdf/networks/encoder.py
  2. 42
      brep2sdf/networks/octree.py
  3. 14
      brep2sdf/train.py

52
brep2sdf/networks/encoder.py

@ -73,34 +73,51 @@ class Encoder(nn.Module):
self.octree = octree
self.feature_dim = feature_dim
# 初始化叶子节点参数
self._leaf_parameters = nn.ParameterList() # 使用 ParameterList 存储参数
self.param_key_to_index: Dict[str, int] = {} # 字典映射:param_key -> index
# 为所有叶子节点注册可学习参数
self._init_parameters()
def _init_parameters(self):
"""为所有叶子节点初始化特征参数"""
# 使用字典保存所有参数,避免动态属性
self._leaf_parameters = nn.ParameterDict()
# 使用栈模拟递归遍历(避免递归)
stack = [(self.octree, "root")] # (当前节点, 当前路径)
param_index = 0 # 参数索引计数器
while stack:
node, path = stack.pop()
# 递归遍历树结构
def _register_params(node, path=""):
#logger.debug(node.is_leaf())
if node.is_leaf():
# 如果是叶子节点,初始化参数
param_name = f"leaf_{path}"
self._leaf_parameters[param_name] = nn.Parameter(
torch.randn(8, self.feature_dim) # 8个顶点的特征
)
self._leaf_parameters.append(nn.Parameter(torch.randn(8, self.feature_dim))) # 8个顶点的特征
self.param_key_to_index[param_name] = param_index # 记录索引
node.set_param_key(param_name) # 为节点存储参数键
#logger.debug(param_name)
#logger.debug(node.param_key)
param_index += 1
else:
# 如果不是叶子节点,继续遍历子节点
for i, child in enumerate(node.child_nodes):
_register_params(child, f"{path}_{i}")
if child is not None:
stack.append((child, f"{path}_{i}"))
def get_leaf_parameter(self, param_key: str) -> torch.Tensor:
"""
获取叶子节点的特征参数
:param param_key: 叶子节点的参数键
:return: 对应的参数
"""
if param_key not in self.param_key_to_index:
raise KeyError(f"Invalid param_key: {param_key}")
target_index = self.param_key_to_index[param_key]
_register_params(self.octree, "root")
# 使用枚举代替动态索引
for index, param in enumerate(self._leaf_parameters):
if index == target_index:
return param
def get_leaf_parameter(self, node):
"""获取叶子节点的特征参数"""
return self._leaf_parameters[node.param_key]
raise IndexError(f"Index {target_index} not found in ParameterList")
def forward(self, query_points: torch.Tensor) -> torch.Tensor:
"""
@ -116,12 +133,11 @@ class Encoder(nn.Module):
for i in range(batch_size):
# 1. 在八叉树中查找包含该点的叶子节点
leaf_node = self.octree.find_leaf(query_points[i])
bbox, param_key, _ = self.octree.find_leaf(query_points[i])
#logger.debug(leaf_node.param_key)
# 2. 获取该节点的特征参数
bbox = leaf_node.bbox
node_features = self.get_leaf_parameter(leaf_node)
node_features = self.get_leaf_parameter(param_key)
# 3. 使用三线性插值计算特征
# (这里需要实现你的插值逻辑)

42
brep2sdf/networks/octree.py

@ -1,6 +1,6 @@
from typing import Tuple, List
from typing import Tuple, List, cast, Dict, Any, Tuple
import torch
import torch.nn as nn
@ -35,9 +35,9 @@ class OctreeNode(nn.Module):
super().__init__()
self.bbox = bbox # 节点的边界框
self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点
self.child_nodes: List['OctreeNode'] = [] # 子节点列表
self.child_nodes: torch.nn.ModuleList = torch.nn.ModuleList() # 子节点列表
self.face_indices = face_indices
self.param_key = None
self.param_key = ""
#self.patch_feature_volume = None # 补丁特征体积,only leaf has
self._is_leaf = True
#print(f"box shape: {self.bbox.shape}")
@ -103,7 +103,6 @@ class OctreeNode(nn.Module):
])
# 为每个子包围盒创建子节点,并分配相交的面
self.child_nodes = []
for bbox in child_bboxes:
# 找到与子包围盒相交的面
intersecting_faces = []
@ -142,14 +141,35 @@ class OctreeNode(nn.Module):
# 使用布尔比较结果计算索引
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum()
return index.unsqueeze(0)
return index.item()
def find_leaf(self, query_point:torch.Tensor):
# 从根节点开始递归查找包含该点的叶子节点
def find_leaf(self, query_point: torch.Tensor) -> Tuple[torch.Tensor, str, bool]:
"""
查找包含给定点的叶子节点并返回其信息以元组形式
:param query_point: 待查找的点
:return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf)
"""
# 如果当前节点是叶子节点,返回其信息
if self._is_leaf:
return self
else:
#logger.info(f"{self.bbox}, {self.param_key}, {True}")
return (self.bbox, self.param_key, True)
# 计算查询点所在的子节点索引
index = self.get_child_index(query_point)
# 遍历子节点列表,找到对应的子节点
for i, child_node in enumerate(self.child_nodes):
if i == index and child_node is not None:
# 递归调用子节点的 find_leaf 方法
result = child_node.find_leaf(query_point)
# 确保返回值是一个元组
assert isinstance(result, tuple), f"Unexpected return type: {type(result)}"
return result
# 如果找不到有效的子节点,抛出异常
raise IndexError(f"Invalid child node index: {index}")
'''
try:
# 直接访问子节点,不进行显式检查
return self.child_nodes[index].find_leaf(query_point)
@ -162,7 +182,9 @@ class OctreeNode(nn.Module):
f"Depth info: {self.max_depth}"
)
raise e
'''
'''
def get_feature_vector(self, query_point:torch.Tensor):
"""
预测给定点的 SDF
@ -188,7 +210,7 @@ class OctreeNode(nn.Module):
f"Depth info: {self.max_depth}"
)
raise e
'''

14
brep2sdf/train.py

@ -79,7 +79,11 @@ class Trainer:
self.base_name = self.model_name + ".xyz"
data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
if os.path.exists(data_path) and not args.force_reprocess:
try:
self.data = load_brep_file(data_path)
except Exception as e:
logger.error(f"fail to load {data_path}, {str(e)}")
raise e
if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
else:
@ -203,7 +207,7 @@ class Trainer:
total_loss += loss.item()
# 记录训练进度
logger.info(f'Train Epoch: {epoch}]\t'
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {loss.item():.6f}')
return total_loss
@ -274,11 +278,9 @@ class Trainer:
def _tracing_model(self):
"""保存模型"""
self.model.eval()
example_input = torch.rand(10, 3, device=self.device)
# 3. 在no_grad上下文中执行追踪
with torch.no_grad():
traced_model = torch.jit.trace(self.model, example_input)
torch.jit.save(traced_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
# 确保模型中的所有逻辑都兼容 TorchScript
scripted_model = torch.jit.script(self.model)
torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""

Loading…
Cancel
Save