From a7541c9da3fbfdd0c632c4c6f1278a7558b309d3 Mon Sep 17 00:00:00 2001 From: mckay Date: Thu, 3 Apr 2025 16:42:12 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=86=99=E6=B3=95=EF=BC=8C?= =?UTF-8?q?=E5=85=BC=E5=AE=B9=20torch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/octree.py | 221 ++++++++++++++++++++++-------------- brep2sdf/train.py | 37 +++--- 2 files changed, 154 insertions(+), 104 deletions(-) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 4816bac..fe63543 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -69,24 +69,37 @@ class OctreeNode: def subdivide(self): - min_x, min_y, min_z, max_x, max_y, max_z = self.bbox - + #min_x, min_y, min_z, max_x, max_y, max_z = self.bbox + # 使用索引操作替代解包 + min_coords = self.bbox[:3] # [min_x, min_y, min_z] + max_coords = self.bbox[3:] # [max_x, max_y, max_z] + # 计算中间点 - mid_x = (min_x + max_x) / 2 - mid_y = (min_y + max_y) / 2 - mid_z = (min_z + max_z) / 2 + mid_coords = (min_coords + max_coords) / 2 - # 生成8个子包围盒 - child_bboxes = torch.tensor([ - [min_x, min_y, min_z, mid_x, mid_y, mid_z], # 前下左 - [mid_x, min_y, min_z, max_x, mid_y, mid_z], # 前下右 - [min_x, mid_y, min_z, mid_x, max_y, mid_z], # 前上左 - [mid_x, mid_y, min_z, max_x, max_y, mid_z], # 前上右 - [min_x, min_y, mid_z, mid_x, mid_y, max_z], # 后下左 - [mid_x, min_y, mid_z, max_x, mid_y, max_z], # 后下右 - [min_x, mid_y, mid_z, mid_x, max_y, max_z], # 后上左 - [mid_x, mid_y, mid_z, max_x, max_y, max_z] # 后上右 - ], dtype=torch.float32, device=OctreeNode.device) + # 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z + min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2] + mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2] + max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2] + + # 生成 8 个子包围盒 + child_bboxes = torch.stack([ + torch.cat([min_coords, mid_coords]), # 前下左 + torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device), + torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右 + torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device), + torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左 + torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device), + torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右 + torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device), + torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左 + torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device), + torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右 + torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device), + torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左 + torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device), + torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右 + ]) # 为每个子包围盒创建子节点,并分配相交的面 self.children = [] @@ -98,43 +111,37 @@ class OctreeNode: if bbox_intersect(bbox, face_bbox): intersecting_faces.append(face_idx) #print(f"{bbox}: {intersecting_faces}") - if intersecting_faces: - child_node = OctreeNode( - bbox=bbox, - face_indices=np.array(intersecting_faces), - max_depth=self.max_depth - 1 - ) - child_node.conduct_tree() - self.children.append(child_node) + + child_node = OctreeNode( + bbox=bbox, + face_indices=np.array(intersecting_faces), + max_depth=self.max_depth - 1 + ) + child_node.conduct_tree() + self.children.append(child_node) self._is_leaf = False - def get_child_index(self, query_point:torch.Tensor) -> int: + def get_child_index(self, query_point: torch.Tensor) -> int: """ 计算点所在子节点的索引 - :param point: 待检查的点,格式为 (x, y, z) + :param query_point: 待检查的点,格式为 (x, y, z) :return: 子节点的索引,范围从 0 到 7 """ - #print(query_point) - x, y, z = query_point - #logger.info(f"query_point: {query_point}") - #logger.info(f"box: {self.bbox}") - min_x, min_y, min_z, max_x, max_y, max_z = self.bbox - + # 确保 query_point 和 bbox 在同一设备上 + query_point = query_point.to(self.bbox.device) + + # 提取 bbox 的最小和最大坐标 + min_coords = self.bbox[:3] # [min_x, min_y, min_z] + max_coords = self.bbox[3:] # [max_x, max_y, max_z] + # 计算中间点 - mid_x = (min_x + max_x) / 2 - mid_y = (min_y + max_y) / 2 - mid_z = (min_z + max_z) / 2 - - index = 0 - if x >= mid_x: # 修正变量名 - index += 1 - if y >= mid_y: # 修正变量名 - index += 2 - if z >= mid_z: # 修正变量名 - index += 4 - #logger.info(f"index: {index}") - return index + mid_coords = (min_coords + max_coords) / 2 + + # 使用布尔比较结果计算索引 + index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum() + + return index.unsqueeze(0) def get_feature_vector(self, query_point:torch.Tensor): """ @@ -150,58 +157,59 @@ class OctreeNode: else: index = self.get_child_index(query_point) try: - if index < 0 or index >= len(self.children): - raise IndexError( - f"Child index {index} out of range (0-{len(self.children)-1}) " - f"for query point {query_point.cpu().numpy().tolist()}. " - f"Node bbox: {self.bbox.cpu().numpy().tolist()}" - f"dept info: {self.max_depth}" - ) + # 直接访问子节点,不进行显式检查 return self.children[index].get_feature_vector(query_point) except IndexError as e: - logger.error(str(e)) - raise e - - def trilinear_interpolation(self, query_point:torch.Tensor) -> torch.Tensor: - """ - 使用三线性插值从补丁特征体积中获取查询点的特征向量。 - - :param query_point: 查询点的位置坐标 - :return: 插值后的特征向量 - """ - """三线性插值""" + # 记录错误日志并重新抛出异常 + logger.error( + f"Error accessing child node: {e}. " + f"Query point: {query_point.cpu().numpy().tolist()}, " + f"Node bbox: {self.bbox.cpu().numpy().tolist()}, " + f"Depth info: {self.max_depth}" + ) + raise e + + def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor: """ 实现三线性插值 + :param query_point: 待插值的点,格式为 (x, y, z) + :return: 插值结果,形状为 (D,) """ - # 获取包围盒的边界 - min_x, min_y, min_z, max_x, max_y, max_z = self.bbox - + # 确保 query_point 和 bbox 在同一设备上 + #query_point = query_point.to(self.bbox.device) + + # 获取包围盒的最小和最大坐标 + min_coords = self.bbox[:3] # [min_x, min_y, min_z] + max_coords = self.bbox[3:] # [max_x, max_y, max_z] + # 计算归一化坐标 - x = (query_point[0] - min_x) / (max_x - min_x) - y = (query_point[1] - min_y) / (max_y - min_y) - z = (query_point[2] - min_z) / (max_z - min_z) + normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # 防止除零错误 + x, y, z = normalized_coords.unbind(dim=-1) + # 使用torch.stack避免Python标量转换 + wx = torch.stack([1 - x, x], dim=-1) # 保持自动微分 + wy = torch.stack([1 - y, y], dim=-1) + wz = torch.stack([1 - z, z], dim=-1) + # 获取8个顶点的特征向量 - c000 = self.patch_feature_volume[0] - c100 = self.patch_feature_volume[1] - c010 = self.patch_feature_volume[2] - c110 = self.patch_feature_volume[3] - c001 = self.patch_feature_volume[4] - c101 = self.patch_feature_volume[5] - c011 = self.patch_feature_volume[6] - c111 = self.patch_feature_volume[7] - + c = self.patch_feature_volume # 形状为 (8, D) + # 执行三线性插值 - c00 = c000 * (1 - x) + c100 * x - c01 = c001 * (1 - x) + c101 * x - c10 = c010 * (1 - x) + c110 * x - c11 = c011 * (1 - x) + c111 * x - - c0 = c00 * (1 - y) + c10 * y - c1 = c01 * (1 - y) + c11 * y - - return c0 * (1 - z) + c1 * z + # 先对 x 轴插值 + c00 = c[0] * wx[0] + c[1] * wx[1] + c01 = c[2] * wx[0] + c[3] * wx[1] + c10 = c[4] * wx[0] + c[5] * wx[1] + c11 = c[6] * wx[0] + c[7] * wx[1] + + # 再对 y 轴插值 + c0 = c00 * wy[0] + c10 * wy[1] + c1 = c01 * wy[0] + c11 * wy[1] + + # 最后对 z 轴插值 + result = c0 * wz[0] + c1 * wz[1] + + return result def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: @@ -228,4 +236,41 @@ class OctreeNode: # 递归打印子节点 for i, child in enumerate(self.children): print(f"{indent} Child {i}:") - child.print_tree(depth + 1, max_print_depth) \ No newline at end of file + child.print_tree(depth + 1, max_print_depth) + +# 保存 + def state_dict(self): + """返回节点及其子树的state_dict""" + state = { + 'bbox': self.bbox, + 'max_depth': self.max_depth, + 'face_indices': self.face_indices, + 'is_leaf': self._is_leaf + } + + if self._is_leaf: + state['patch_feature_volume'] = self.patch_feature_volume + else: + state['children'] = [child.state_dict() for child in self.children] + + return state + + def load_state_dict(self, state_dict): + """从state_dict加载节点状态""" + self.bbox = state_dict['bbox'] + self.max_depth = state_dict['max_depth'] + self.face_indices = state_dict['face_indices'] + self._is_leaf = state_dict['is_leaf'] + + if self._is_leaf: + self.patch_feature_volume = nn.Parameter(state_dict['patch_feature_volume']) + else: + self.children = [] + for child_state in state_dict['children']: + child = OctreeNode( + bbox=child_state['bbox'], + face_indices=child_state['face_indices'], + max_depth=child_state['max_depth'] + ) + child.load_state_dict(child_state) + self.children.append(child) \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 609eb1b..00be3e0 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -35,7 +35,7 @@ class Trainer: self.config = config self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.model_name = os.path.basename(input_step).replace(".step", "") + self.model_name = os.path.basename(input_step).split('_')[0] self.base_name = self.model_name + ".xyz" data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name) if os.path.exists(data_path): @@ -110,6 +110,7 @@ class Trainer: # 获取数据并移动到设备 points = self.sdf_data[:,0:3] + points.requires_grad_(True) gt_sdf = self.sdf_data[:,3] # 前向传播 @@ -180,24 +181,28 @@ class Trainer: total_time = time.time() - start_time logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') logger.info(f'Best validation loss: {best_val_loss:.6f}') + self._tracing_model() - def _save_model(self, epoch: int, val_loss: float): - """保存最佳模型""" - save_path = os.path.join( - self.config.train.checkpoint_dir, - self.config.train.best_model_name.format( - model_name=config.train.model_name - ) - ) - torch.save({ - 'epoch': epoch, - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'loss': val_loss, - 'config': self.config - }, save_path) + def _tracing_model(self): + """保存模型""" + self.model.eval() + example_input = torch.rand(10, 3, device=self.device) + # 3. 在no_grad上下文中执行追踪 + with torch.no_grad(): + traced_model = torch.jit.trace(self.model, example_input) + torch.jit.save(traced_model, f"{self.model_name}.pt") + def _load_checkpoint(self, checkpoint_path): + """从检查点恢复训练状态""" + checkpoint = torch.load(checkpoint_path) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + start_epoch = checkpoint['epoch'] + 1 # 从下一轮开始 + best_loss = checkpoint['loss'] + logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}") + return start_epoch, best_loss + def _save_checkpoint(self, epoch: int, train_loss: float): """保存训练检查点""" checkpoint_dir = os.path.join(