From 1e4c36040312b9bed1e239bed501e3b1433c9293 Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 3 Apr 2025 22:37:11 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=AF=E4=BB=A5=E5=8A=A0=E8=BD=BD=E5=92=8C?= =?UTF-8?q?=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 126 +++++++++++++++++++++++++++++- brep2sdf/networks/network.py | 8 +- brep2sdf/networks/octree.py | 145 ++++++++++++++--------------------- brep2sdf/train.py | 55 +++++++++---- 4 files changed, 222 insertions(+), 112 deletions(-) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index c433a29..02008d8 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -8,7 +8,7 @@ from .octree import OctreeNode from brep2sdf.config.default_config import get_default_config from brep2sdf.utils.logger import logger import numpy as np - +''' class Encoder: def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64): """ @@ -59,11 +59,129 @@ class Encoder: +''' +class Encoder(nn.Module): + def __init__(self, octree: OctreeNode, feature_dim: int = 32): + """ + 分离后的编码器,接收预构建的八叉树 + + 参数: + octree: 预构建的八叉树结构 + feature_dim: 特征维度 + """ + super().__init__() + self.octree = octree + self.feature_dim = feature_dim + + # 为所有叶子节点注册可学习参数 + self._init_parameters() + + def _init_parameters(self): + """为所有叶子节点初始化特征参数""" + # 使用字典保存所有参数,避免动态属性 + self._leaf_parameters = nn.ParameterDict() + + # 递归遍历树结构 + def _register_params(node, path=""): + #logger.debug(node.is_leaf()) + if node.is_leaf(): + param_name = f"leaf_{path}" + self._leaf_parameters[param_name] = nn.Parameter( + torch.randn(8, self.feature_dim) # 8个顶点的特征 + ) + node.set_param_key(param_name) # 为节点存储参数键 + #logger.debug(param_name) + #logger.debug(node.param_key) + else: + for i, child in enumerate(node.child_nodes): + _register_params(child, f"{path}_{i}") + + _register_params(self.octree, "root") + + def get_leaf_parameter(self, node): + """获取叶子节点的特征参数""" + return self._leaf_parameters[node.param_key] + + def forward(self, query_points: torch.Tensor) -> torch.Tensor: + """ + 前向传播,处理批量查询点 + + 参数: + query_points: 查询点的位置坐标,形状为(batch_size, 3) + 返回: + feature_vectors: 查询点的特征向量,形状为(batch_size, feature_dim) + """ + batch_size = query_points.shape[0] + features = [] + + for i in range(batch_size): + # 1. 在八叉树中查找包含该点的叶子节点 + leaf_node = self.octree.find_leaf(query_points[i]) + #logger.debug(leaf_node.param_key) + + # 2. 获取该节点的特征参数 + bbox = leaf_node.bbox + node_features = self.get_leaf_parameter(leaf_node) + + # 3. 使用三线性插值计算特征 + # (这里需要实现你的插值逻辑) + interpolated = self.trilinear_interpolation( + query_points[i], bbox, node_features) + + features.append(interpolated) + + return torch.stack(features, dim=0) + + def trilinear_interpolation(self, query_point: torch.Tensor, bbox, features) -> torch.Tensor: + """ + 实现三线性插值 + :param query_point: 待插值的点,格式为 (x, y, z) + :return: 插值结果,形状为 (D,) + """ + # 确保 query_point 和 bbox 在同一设备上 + #query_point = query_point.to(self.bbox.device) + # 获取包围盒的最小和最大坐标 + min_coords = bbox[:3] # [min_x, min_y, min_z] + max_coords = bbox[3:] # [max_x, max_y, max_z] - - - + # 计算归一化坐标 + normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # 防止除零错误 + x, y, z = normalized_coords.unbind(dim=-1) + + # 使用torch.stack避免Python标量转换 + wx = torch.stack([1 - x, x], dim=-1) # 保持自动微分 + wy = torch.stack([1 - y, y], dim=-1) + wz = torch.stack([1 - z, z], dim=-1) + + # 获取8个顶点的特征向量 + c = features # 形状为 (8, D) + + # 执行三线性插值 + # 先对 x 轴插值 + c00 = c[0] * wx[0] + c[1] * wx[1] + c01 = c[2] * wx[0] + c[3] * wx[1] + c10 = c[4] * wx[0] + c[5] * wx[1] + c11 = c[6] * wx[0] + c[7] * wx[1] + + # 再对 y 轴插值 + c0 = c00 * wy[0] + c10 * wy[1] + c1 = c01 * wy[0] + c11 * wy[1] + + # 最后对 z 轴插值 + result = c0 * wz[0] + c1 * wz[1] + + return result + + def to(self, device): + super().to(device) + def _move_node(node): + if isinstance(node.bbox, torch.Tensor): + node.bbox = node.bbox.to(device) + for child in node.children: + _move_node(child) + _move_node(self.octree.root) + return self diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index fd0de8a..7477418 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -53,9 +53,7 @@ from .decoder import Decoder class Net(nn.Module): def __init__(self, - surf_bbox, - origin_bbox, - max_depth=4, + octree, feature_dim=64, decoder_input_dim=64, decoder_output_dim=1, @@ -68,9 +66,7 @@ class Net(nn.Module): # 初始化 Encoder self.encoder = Encoder( - surf_bbox=surf_bbox, # 使用传入的bbox作为表面包围盒 - origin_bbox=origin_bbox, # 使用相同的bbox作为原点包围盒 - max_depth=max_depth, + octree=octree, feature_dim=feature_dim ) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index fe63543..c0d2444 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -28,21 +28,20 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool: # 向量化比较 return torch.all((max1 >= min2) & (max2 >= min1)) -class OctreeNode: - feature_dim=None +class OctreeNode(nn.Module): device=None surf_bbox = None - def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, feature_dim:int = None, surf_bbox:torch.Tensor = None): + def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None): + super().__init__() self.bbox = bbox # 节点的边界框 self.max_depth: int = max_depth # 最大深度,当这个为0时,表示已经到达最大深度,不可再分子节点 - self.children: List['OctreeNode'] = [] # 子节点列表 + self.child_nodes: List['OctreeNode'] = [] # 子节点列表 self.face_indices = face_indices - self.patch_feature_volume = None # 补丁特征体积,only leaf has + self.param_key = None + #self.patch_feature_volume = None # 补丁特征体积,only leaf has self._is_leaf = True #print(f"box shape: {self.bbox.shape}") - if feature_dim is not None: - OctreeNode.feature_dim = feature_dim if surf_bbox is not None: if not isinstance(surf_bbox, torch.Tensor): raise TypeError( @@ -56,13 +55,15 @@ class OctreeNode: OctreeNode.device = bbox.device def is_leaf(self): - # Check if self.children is None before calling len() + # Check if self.child——nodes is None before calling len() return self._is_leaf + def set_param_key(self, k): + self.param_key = k + def conduct_tree(self): if self.max_depth <= 0 or self.face_indices.shape[0] <= 2: # 达到最大深度 or 一个单元格至多只有两个面 - self.patch_feature_volume = nn.Parameter(torch.randn(8, OctreeNode.feature_dim, device=OctreeNode.device)) return self.subdivide() @@ -102,7 +103,7 @@ class OctreeNode: ]) # 为每个子包围盒创建子节点,并分配相交的面 - self.children = [] + self.child_nodes = [] for bbox in child_bboxes: # 找到与子包围盒相交的面 intersecting_faces = [] @@ -118,7 +119,7 @@ class OctreeNode: max_depth=self.max_depth - 1 ) child_node.conduct_tree() - self.children.append(child_node) + self.child_nodes.append(child_node) self._is_leaf = False @@ -143,6 +144,25 @@ class OctreeNode: return index.unsqueeze(0) + def find_leaf(self, query_point:torch.Tensor): + # 从根节点开始递归查找包含该点的叶子节点 + if self._is_leaf: + return self + else: + index = self.get_child_index(query_point) + try: + # 直接访问子节点,不进行显式检查 + return self.child_nodes[index].find_leaf(query_point) + except IndexError as e: + # 记录错误日志并重新抛出异常 + logger.error( + f"Error accessing child node: {e}. " + f"Query point: {query_point.cpu().numpy().tolist()}, " + f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " + f"Depth info: {self.max_depth}" + ) + raise e + def get_feature_vector(self, query_point:torch.Tensor): """ 预测给定点的 SDF 值 @@ -158,7 +178,7 @@ class OctreeNode: index = self.get_child_index(query_point) try: # 直接访问子节点,不进行显式检查 - return self.children[index].get_feature_vector(query_point) + return self.child_nodes[index].get_feature_vector(query_point) except IndexError as e: # 记录错误日志并重新抛出异常 logger.error( @@ -170,46 +190,7 @@ class OctreeNode: raise e - def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor: - """ - 实现三线性插值 - :param query_point: 待插值的点,格式为 (x, y, z) - :return: 插值结果,形状为 (D,) - """ - # 确保 query_point 和 bbox 在同一设备上 - #query_point = query_point.to(self.bbox.device) - - # 获取包围盒的最小和最大坐标 - min_coords = self.bbox[:3] # [min_x, min_y, min_z] - max_coords = self.bbox[3:] # [max_x, max_y, max_z] - - # 计算归一化坐标 - normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # 防止除零错误 - x, y, z = normalized_coords.unbind(dim=-1) - - # 使用torch.stack避免Python标量转换 - wx = torch.stack([1 - x, x], dim=-1) # 保持自动微分 - wy = torch.stack([1 - y, y], dim=-1) - wz = torch.stack([1 - z, z], dim=-1) - - # 获取8个顶点的特征向量 - c = self.patch_feature_volume # 形状为 (8, D) - - # 执行三线性插值 - # 先对 x 轴插值 - c00 = c[0] * wx[0] + c[1] * wx[1] - c01 = c[2] * wx[0] + c[3] * wx[1] - c10 = c[4] * wx[0] + c[5] * wx[1] - c11 = c[6] * wx[0] + c[7] * wx[1] - - # 再对 y 轴插值 - c0 = c00 * wy[0] + c10 * wy[1] - c1 = c01 * wy[0] + c11 * wy[1] - - # 最后对 z 轴插值 - result = c0 * wz[0] + c1 * wz[1] - - return result + def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: @@ -231,46 +212,32 @@ class OctreeNode: # 打印面片信息(如果有) if self.face_indices is not None: print(f"{indent} Face indices: {self.face_indices.tolist()}") - print(f"{indent} len children: {len(self.children)}") + print(f"{indent} len child_nodes: {len(self.child_nodes)}") # 递归打印子节点 - for i, child in enumerate(self.children): + for i, child in enumerate(self.child_nodes): print(f"{indent} Child {i}:") child.print_tree(depth + 1, max_print_depth) -# 保存 - def state_dict(self): - """返回节点及其子树的state_dict""" - state = { - 'bbox': self.bbox, - 'max_depth': self.max_depth, - 'face_indices': self.face_indices, - 'is_leaf': self._is_leaf - } - - if self._is_leaf: - state['patch_feature_volume'] = self.patch_feature_volume - else: - state['children'] = [child.state_dict() for child in self.children] - - return state - - def load_state_dict(self, state_dict): - """从state_dict加载节点状态""" - self.bbox = state_dict['bbox'] - self.max_depth = state_dict['max_depth'] - self.face_indices = state_dict['face_indices'] - self._is_leaf = state_dict['is_leaf'] + def __getstate__(self): + """支持pickle序列化""" + return self._serialize_node(self) + + def __setstate__(self, state): + """支持pickle反序列化""" + self = self._deserialize_node(state) - if self._is_leaf: - self.patch_feature_volume = nn.Parameter(state_dict['patch_feature_volume']) - else: - self.children = [] - for child_state in state_dict['children']: - child = OctreeNode( - bbox=child_state['bbox'], - face_indices=child_state['face_indices'], - max_depth=child_state['max_depth'] - ) - child.load_state_dict(child_state) - self.children.append(child) \ No newline at end of file + def _serialize_node(self, node): + return { + 'bbox': node.bbox, + 'is_leaf': node._is_leaf, + 'child_nodes': [self._serialize_node(c) for c in node.child_nodes], + 'param_key': node.param_key + } + + def _deserialize_node(self, data): + node = OctreeNode(data['bbox'], 0) # max_depth会在encoder中重建 + node._is_leaf = data['is_leaf'] + node.param_key = data['param_key'] + node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']] + return node \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 00be3e0..4031e46 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -1,4 +1,5 @@ import torch +from torch.serialization import add_safe_globals import torch.optim as optim import time import os @@ -9,6 +10,7 @@ from brep2sdf.config.default_config import get_default_config from brep2sdf.data.data import load_brep_file,load_sdf_file from brep2sdf.data.pre_process import process_single_step from brep2sdf.networks.network import Net +from brep2sdf.networks.octree import OctreeNode from brep2sdf.utils.logger import logger def prepare_sdf_data(surf_data, max_points=100000, device='cuda'): @@ -62,10 +64,13 @@ class Trainer: dtype=torch.float32, device=self.device ) - bbox = self._calculate_global_bbox(surf_bbox) + + + self.build_tree(surf_bbox=surf_bbox, max_depth=4) + + self.model = Net( - surf_bbox=surf_bbox, - origin_bbox=bbox, + octree=self.root, feature_dim=64 ).to(self.device) @@ -77,7 +82,21 @@ class Trainer: ) - + def build_tree(self,surf_bbox, max_depth=6): + num_faces = surf_bbox.shape[0] + bbox = self._calculate_global_bbox(surf_bbox) + self.root = OctreeNode( + bbox=bbox, + face_indices=np.arange(num_faces), # 初始包含所有面 + max_depth=max_depth, + surf_bbox=surf_bbox + ) + #print(surf_bbox) + logger.info("starting octree conduction") + self.root.conduct_tree() + logger.info("complete octree conduction") + #self.root.print_tree(0) + def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor: """ 计算整个数据集的全局边界框,综合考虑表面包围盒和采样点 @@ -102,6 +121,7 @@ class Trainer: # 返回合并后的边界框 return torch.cat([global_min, global_max]) + def train_epoch(self, epoch: int) -> float: self.model.train() @@ -153,7 +173,7 @@ class Trainer: best_val_loss = float('inf') logger.info("Starting training...") start_time = time.time() - + """ for epoch in range(1, self.config.train.num_epochs + 1): # 训练一个epoch train_loss = self.train_epoch(epoch) @@ -182,7 +202,18 @@ class Trainer: logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') logger.info(f'Best validation loss: {best_val_loss:.6f}') self._tracing_model() - + """ + self.test_load() + + def test_load(self): + model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt") + model.eval() + logger.debug(model) + example_input = torch.rand(10, 3, device=self.device) + #logger.debug(model.encoder.octree.bbox) + logger.debug(f"points: {example_input}") + sdfs= model(example_input) + logger.debug(f"sdfs:{sdfs}") def _tracing_model(self): """保存模型""" @@ -195,13 +226,8 @@ class Trainer: def _load_checkpoint(self, checkpoint_path): """从检查点恢复训练状态""" - checkpoint = torch.load(checkpoint_path) - self.model.load_state_dict(checkpoint['model_state_dict']) - self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - start_epoch = checkpoint['epoch'] + 1 # 从下一轮开始 - best_loss = checkpoint['loss'] - logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}") - return start_epoch, best_loss + model = torch.load(checkpoint_path) + return model def _save_checkpoint(self, epoch: int, train_loss: float): """保存训练检查点""" @@ -211,6 +237,7 @@ class Trainer: ) os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir,f"epoch_{epoch:03d}.pth") + ''' torch.save({ 'epoch': epoch, 'model_state_dict': self.model.state_dict(), @@ -218,6 +245,8 @@ class Trainer: 'loss': train_loss, 'config': self.config }, checkpoint_path) + ''' + torch.save(self.model,checkpoint_path) def main(): # 这里需要初始化配置