diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index 1018b45..8dd4537 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -9,6 +9,7 @@ class ModelConfig: embed_dim: int = 768 # 3 的 倍数 latent_dim: int = 32 + octree_max_depth = 6 # 点云采样配置 num_surf_points: int = 64 # 每个面采样点数 num_edge_points: int = 8 # 每条边采样点数 @@ -48,7 +49,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 1000 + num_epochs: int = 1 learning_rate: float = 0.001 min_lr: float = 1e-5 weight_decay: float = 0.01 diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index f7d4f9b..dd0de84 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -141,7 +141,6 @@ def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'): return torch.tensor(sdf_array, dtype=torch.float32, device=device) - def print_data_distribution(data: torch.Tensor) -> None: """打印数据分布统计信息 diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 088e9b0..9b82050 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -56,6 +56,7 @@ def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor: return padded_tensor + def normalize(surfs, edges, corners): """ 将CAD模型归一化到单位立方体空间 diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py index 7e8d142..660b895 100644 --- a/brep2sdf/networks/decoder.py +++ b/brep2sdf/networks/decoder.py @@ -88,7 +88,34 @@ class Decoder(nn.Module): return f_i + @torch.jit.export + def forward_training_volumes(self, feature_matrix: torch.Tensor) -> torch.Tensor: + ''' + :param feature_matrix: 形状为(S, D) 的特征矩阵 + S: 采样数量 + D: 特征维度 + :return: + f: 各patch的SDF值 (S) + ''' + # 直接使用输入的特征矩阵,因为形状已经是 (S, D) + x = feature_matrix + + for layer in range(0, self.sdf_layers - 1): + lin = getattr(self, "sdf_" + str(layer)) + if layer in self.skip_in: + x = torch.cat([x, x], -1) / np.sqrt(2) + x = lin(x) + if layer < self.sdf_layers - 2: + x = self.activation(x) + + output_value = x # 所有 f 的值 + + # 调整输出形状为 (S) + f = output_value.squeeze(-1) + + return f +""" # 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h #注意考虑如何批量处理 (B, P) 和 [csg tree] class CSGCombiner: @@ -96,7 +123,8 @@ class CSGCombiner: self.flag_convex = flag_convex self.rho = rho - def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor: + def forward(self, f_i: torch.Tensor, csg_tree + ) -> torch.Tensor: ''' :param f_i: 形状为 (B, P) 的各patch SDF值 :param csg_tree: CSG树结构 @@ -216,4 +244,5 @@ def test_csg_combiner(): print(f"rho={rho}:", h_soft) if __name__ == "__main__": - test_csg_combiner() \ No newline at end of file + test_csg_combiner() + """ \ No newline at end of file diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 03d2c53..653513a 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -75,18 +75,34 @@ class Encoder(nn.Module): current_indices = volume_indices[:, k] # 遍历所有存在的volume - for vol_id in range(len(self.feature_volumes)): + for vol_id, volume in enumerate(self.feature_volumes): # 创建掩码 (B,) mask = (current_indices == vol_id) if mask.any(): # 获取对应volume的特征 (M, D) - features = self.feature_volumes[vol_id](query_points[mask]) + features = volume.forward(query_points[mask]) all_features[mask, k] = features return all_features + @torch.jit.export + def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor: + """ + 处理表面采样点的特征提取 + 参数: + surf_points: 表面采样点 (S, 3), S: #sampled point per feature_volumes. + 返回: + 特征张量 (S, D) + """ + # 使用枚举遍历 feature_volumes,避免直接索引 + for idx, volume in enumerate(self.feature_volumes): + if idx == patch_id: + return volume.forward(surf_points) + return torch.zeros(surf_points.shape[0], self.feature_dim, device=surf_points.device) + + return features - def _optimized_trilinear(self, points, bboxes, features): + def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor: """优化后的向量化三线性插值""" # 添加显式类型转换确保计算稳定性 min_coords = bboxes[..., :3].to(torch.float32) diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 15e0c64..e2a45b6 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -49,7 +49,7 @@ import torch import torch.nn as nn from torch.autograd import grad from .encoder import Encoder -from .decoder import Decoder, CSGCombiner +from .decoder import Decoder from brep2sdf.utils.logger import logger class Net(nn.Module): @@ -82,8 +82,9 @@ class Net(nn.Module): beta=100 ) - self.csg_combiner = CSGCombiner(flag_convex=True) + #self.csg_combiner = CSGCombiner(flag_convex=True) + @torch.jit.export def forward(self, query_points): """ 前向传播 @@ -94,18 +95,55 @@ class Net(nn.Module): output: 解码后的输出结果 """ # 批量查询所有点的索引和bbox - _,face_indices,csg_trees = self.octree_module.forward(query_points) + _,face_indices_mask,operator = self.octree_module.forward(query_points) # 编码 - feature_vectors = self.encoder.forward(query_points,face_indices) - #print("feature_vector:", feature_vectors.requires_grad) + feature_vectors = self.encoder.forward(query_points,face_indices_mask) + print("feature_vector:", feature_vectors.shape) # 解码 logger.gpu_memory_stats("encoder farward后") - f_i = self.decoder(feature_vectors) + f_i = self.decoder(feature_vectors) # (B, P) logger.gpu_memory_stats("decoder farward后") - output = self.csg_combiner.forward(f_i, csg_trees) + + + output = f_i[:, 0] + # 提取有效值并填充到固定大小 (B, max_patches) + padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches) + for i in range(f_i.shape[0]): + sample_valid_values = f_i[i][face_indices_mask[i]] # (N,), N <= P + num_valid = min(len(sample_valid_values), 2) + padded_f_i[i, :num_valid] = sample_valid_values[:num_valid] + + + # 找到需要组合的行 + mask_concave = (operator == 0) + mask_convex = (operator == 1) + + # 对 operator == 0 的样本取最大值 + if mask_concave.any(): + output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values + + # 对 operator == 1 的样本取最小值 + if mask_convex.any(): + output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values + + logger.gpu_memory_stats("combine后") return output + @torch.jit.export + def forward_training_volumes(self, surf_points, patch_id:int): + """ + only surf sampled points + surf_points (P, S): + return (P, S) + """ + feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id) + f_i = self.decoder.forward_training_volumes(feature_mat) + return f_i.squeeze() + + + + def gradient(inputs, outputs): d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device) diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 18a46ac..1726914 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple,List, Optional import torch import torch.nn as nn @@ -54,7 +54,7 @@ def bbox_intersect( surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor, - surf_points: torch.Tensor = None + surf_points: Optional[torch.Tensor]=None ) -> torch.Tensor: ''' args: @@ -69,15 +69,15 @@ def bbox_intersect( # 初始化全为 False 的结果掩码 [8, B] B = surf_bboxes.size(0) result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device) - logger.debug(result_mask.shape) - logger.debug(indices.shape) + #logger.debug(result_mask.shape) + #logger.debug(indices.shape) # 提取选中的边界框 selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] - logger.debug(selected_bboxes.shape) + #logger.debug(selected_bboxes.shape) # 计算子包围盒与选中包围盒的交集 intersect_mask = torch.all( (max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] @@ -109,11 +109,11 @@ def bbox_intersect( # 合并交集条件和点云条件 result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask - logger.debug(result_mask.shape) + #logger.debug(result_mask.shape) return result_mask class OctreeNode(nn.Module): - def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None): + def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: Optional[torch.Tensor]=None, patch_graph: Optional[PatchGraph] = None,surf_ncs:Optional[np.ndarray] = None,device:Optional[torch.device]=None): super().__init__() self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 改为普通张量属性 @@ -185,7 +185,7 @@ class OctreeNode(nn.Module): queue.append((child_idx, child_bbox, intersecting_faces.clone().detach())) current_idx += 8 - def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool: + def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool: """判断节点是否需要分裂""" # 检查是否达到最大深度 if current_idx + 8 >= max_node: @@ -229,7 +229,7 @@ class OctreeNode(nn.Module): """ 修改后的查找叶子节点方法,返回face indices :param query_points: 待查找的点,形状为 (3,) - :return: (bbox, param_key, face_indices, is_leaf) + :return: (bbox, face_indices, is_leaf) """ # 确保输入是单个点 if query_points.dim() != 1 or query_points.shape[0] != 3: @@ -248,10 +248,9 @@ class OctreeNode(nn.Module): #logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.") if parent_idx == -1: # 根节点没有父节点,返回根节点的信息 - #logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.") return ( self.node_bboxes[current_idx], - None, # 新增返回face indices + torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device), # 新增返回face indices False ) return ( @@ -280,7 +279,7 @@ class OctreeNode(nn.Module): iteration += 1 # 如果达到最大迭代次数,返回当前节点的信息 - return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item()) + return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item()) @torch.jit.export def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: @@ -288,23 +287,26 @@ class OctreeNode(nn.Module): mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2 return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1) - def forward(self, query_points): + def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: with torch.no_grad(): - bboxes, face_indices_mask, csg_trees = [], [], [] + bboxes: List[torch.Tensor] = [] + face_indices_mask: List[torch.Tensor] = [] + operator: List[int] = [] for point in query_points: - bbox, faces_mask, _ = self.find_leaf(point) + bbox, faces_mask, _ = self.find_leaf(point) bboxes.append(bbox) face_indices_mask.append(faces_mask) # 获取当前节点的CSG树结构 - csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None - csg_trees.append(csg_tree) # 保持原始列表结构 + #csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None + #csg_trees.append(csg_tree) # 保持原始列表结构 + operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1) return ( torch.stack(bboxes), torch.stack(face_indices_mask), - csg_trees # 直接返回列表,不转换为张量 + torch.tensor(operator, dtype=torch.int) # 直接返回列表,不转换为张量 ) - def print_tree(self, max_print_depth: int = None) -> None: + def print_tree(self, max_print_depth: Optional[int] = None) -> None: """ 使用深度优先遍历 (DFS) 打印树结构,父子关系通过缩进体现。 @@ -353,36 +355,45 @@ class OctreeNode(nn.Module): dfs(0, 0) # 统一输出所有日志 - logger.debug("\n".join(log_lines)) + #logger.debug("\n".join(log_lines)) def __getstate__(self): """支持pickle序列化""" state = { - 'bbox': self.bbox, - 'node_bboxes': self.node_bboxes, - 'parent_indices': self.parent_indices, - 'child_indices': self.child_indices, - 'is_leaf_mask': self.is_leaf_mask, - 'face_indices': self.face_indices, - 'surf_bbox': self.surf_bbox, - 'patch_graph': self.patch_graph, + 'bbox': self.bbox.cpu(), # 转换为CPU张量 + 'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None, + 'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None, + 'child_indices': self.child_indices.cpu() if self.child_indices is not None else None, + 'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None, + 'all_face_indices': self.all_face_indices.cpu(), + 'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None, + 'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None, + 'surf_ncs': self.surf_ncs.cpu() if self.surf_ncs is not None else None, + 'patch_graph': self.patch_graph, # 假设PatchGraph支持序列化 'max_depth': self.max_depth, - '_is_leaf': self._is_leaf + 'device': str(self.device) # 保存设备信息 } return state def __setstate__(self, state): """支持pickle反序列化""" - self.bbox = state['bbox'] - self.node_bboxes = state['node_bboxes'] - self.parent_indices = state['parent_indices'] - self.child_indices = state['child_indices'] - self.is_leaf_mask = state['is_leaf_mask'] - self.face_indices = state['face_indices'] - self.surf_bbox = state['surf_bbox'] - self.patch_graph = state['patch_graph'] - self.max_depth = state['max_depth'] - self._is_leaf = state['_is_leaf'] + # 手动调用 __init__ 方法 + self.__init__( + bbox=state['bbox'], + face_indices=state['all_face_indices'].cpu().numpy(), + patch_graph=state['patch_graph'], + max_depth=state['max_depth'], + surf_bbox=state['surf_bbox'], + surf_ncs=state['surf_ncs'], + device=torch.device(state['device']) + ) + # 可以在这里设置其他不需要在 __init__ 中处理的属性 + self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None + self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None + self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None + self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None + self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None + def to(self, device=None, dtype=None, non_blocking=False): # 调用父类方法迁移基础参数 diff --git a/brep2sdf/networks/patch_graph.py b/brep2sdf/networks/patch_graph.py index 0239942..c9e47a8 100644 --- a/brep2sdf/networks/patch_graph.py +++ b/brep2sdf/networks/patch_graph.py @@ -71,7 +71,7 @@ class PatchGraph(nn.Module): return [] node_faces = node_faces_mask.nonzero() node_faces = node_faces.flatten().to('cpu').numpy() - logger.debug(f"node_faces: {node_faces}") + #logger.debug(f"node_faces: {node_faces}") node_set = set(node_faces) # 创建输入面片的集合用于快速查找 visited = set() csg_tree = [] @@ -89,6 +89,32 @@ class PatchGraph(nn.Module): csg_tree.extend(remaining) return csg_tree + + def get_operator(self, node_faces: torch.Tensor): + # node_faces: shape (<=2,) + # 返回 0: 凹边, 1: 凸边, + node_faces = node_faces.flatten().to(self.device) + num_faces = node_faces.numel() + if num_faces == 1: + # 这里设置凸边是因为 后续会补一个 f2 = inf, h = min(f1, f2) + # 因为 f2 = inf, 所以 h = f1 + return 1 # 只有一个面 + if num_faces > 2: + #logger.warning("get_operator 输入数量为{} > 2,将只取前两个面片进行处理".format(num_faces)) + node_faces = node_faces[:2] + #if self.edge_index is None or self.edge_type is None: + #logger.warning("edge_index 或 edge_type 未设置") + # 查找这两个面之间的边 + mask = ((self.edge_index[0] == node_faces[0]) & (self.edge_index[1] == node_faces[1])) | \ + ((self.edge_index[0] == node_faces[1]) & (self.edge_index[1] == node_faces[0])) + if not mask.any(): + #logger.warning("没有面可以用") + return 3 + edge_types = self.edge_type[mask] + # 如果有多条边,返回第一个 + return int(edge_types[0].item()) + + def is_clique(self, node_faces: torch.Tensor) -> bool: """检查给定面片集合是否构成完全图 @@ -150,7 +176,7 @@ class PatchGraph(nn.Module): @staticmethod def from_preprocessed_data( - surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 + surf_ncs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组 edge_types: np.ndarray, # 形状为(num_edges,)的int32数组 device: torch.device = None @@ -158,7 +184,7 @@ class PatchGraph(nn.Module): """从预处理的数据直接构建面片邻接图 参数: - surf_wcs: 世界坐标系下的曲面几何数据,形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 + surf_ncs: 归一化坐标系下的曲面几何数据,形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组 edgeFace_adj: 边-面邻接矩阵,形状为(num_edges, num_faces)的int32数组,1表示边与面相邻 edge_types: 边的类型数组,形状为(num_edges,)的int32数组,0表示凹边,1表示凸边 @@ -167,7 +193,7 @@ class PatchGraph(nn.Module): - edge_index: 形状为(2, num_edges*2)的torch.long张量,表示双向边的连接关系 - edge_type: 形状为(num_edges*2,)的torch.long张量,表示每条边的类型 """ - num_faces = len(surf_wcs) + num_faces = len(surf_ncs) graph = PatchGraph(num_faces,device) # 构建边的索引和类型 diff --git a/brep2sdf/train.py b/brep2sdf/train.py index a665afb..d0caa33 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -104,7 +104,7 @@ class Trainer: # 构建面片邻接图 graph = PatchGraph.from_preprocessed_data( - surf_wcs=self.data['surf_wcs'], + surf_ncs=self.data['surf_ncs'], edgeFace_adj=self.data['edgeFace_adj'], edge_types=self.data['edge_types'], device='cuda' if args.octree_cuda else 'cpu' @@ -115,9 +115,15 @@ class Trainer: dtype=torch.float32, device=self.device ) - - self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=6) - logger.gpu_memory_stats("数初始化后") + max_depth = config.model.octree_max_depth + if not args.force_reprocess: + if not self._load_octree(): + self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth) + elif self.root.max_depth != max_depth: + self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth) + else: + self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth) + logger.gpu_memory_stats("树初始化后") self.model = Net( octree=self.root, @@ -140,6 +146,7 @@ class Trainer: def build_tree(self,surf_bbox, graph, max_depth=9): + logger.info("开始构造八叉树...") num_faces = surf_bbox.shape[0] bbox = self._calculate_global_bbox(surf_bbox) self.root = OctreeNode( @@ -155,6 +162,7 @@ class Trainer: self.root.build_static_tree() logger.info("complete octree conduction") self.root.print_tree() + self._save_octree() def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor: """ @@ -188,8 +196,80 @@ class Trainer: # # 返回合并后的边界框 # return torch.cat([global_min, global_max]) # return [-0.5,] # 这个是错误的 + def train_epoch_stage1(self, epoch: int): + total_loss = 0.0 # 初始化总损失 + for step, surf_points in enumerate(self.data['surf_ncs']): # 定义 step 变量 + points = torch.tensor(surf_points, device=self.device) + gt_sdf = torch.zeros(points.shape[0], device=self.device) + normals = None + if args.use_normal: + normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) + + # --- 准备模型输入,启用梯度 --- + points.requires_grad_(True) # 在检查之后启用梯度 + + # --- 前向传播 --- + self.optimizer.zero_grad() + pred_sdf = self.model.forward_training_volumes(points, step) + + if self.debug_mode: + # --- 检查前向传播的输出 --- + logger.gpu_memory_stats("前向传播后") + + # --- 计算损失 --- + loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 + loss_details = {} + try: + if args.use_normal: + loss, loss_details = self.loss_manager.compute_loss( + points, + normals, + gt_sdf, + pred_sdf + ) + else: + loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) + + if self.debug_mode: + if check_tensor(loss, "Calculated Loss", epoch, step): + logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") + if loss_details: + logger.error(f"Loss Details: {loss_details}") + return float('inf') # 如果损失无效,停止这个epoch + except Exception as loss_e: + logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) + return float('inf') # 如果计算出错,停止这个epoch + + # --- 反向传播和优化 --- + try: + loss.backward() + # --- (推荐) 添加梯度裁剪 --- + # 防止梯度爆炸,这可能是导致 inf/nan 的原因之一 + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪 + + self.optimizer.step() + except Exception as backward_e: + logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True) + return float('inf') # 如果反向传播或优化出错,停止这个epoch + + # --- 记录和累加损失 --- + current_loss = loss.item() + if not np.isfinite(current_loss): # 再次确认损失是有效的数值 + logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).") + return float('inf') + + total_loss += current_loss + del loss + torch.cuda.empty_cache() + + # 记录训练进度 (只记录有效的损失) + logger.info(f'Train Epoch: {epoch:4d}]\t' + f'Loss: {current_loss:.6f}') + if loss_details: logger.info(f"Loss Details: {loss_details}") + + return total_loss + - def train_epoch(self, epoch: int) -> float: # --- 1. 检查输入数据 --- # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) @@ -347,7 +427,8 @@ class Trainer: for epoch in range(start_epoch, self.config.train.num_epochs + 1): # 训练一个epoch - train_loss = self.train_epoch(epoch) + train_loss = self.train_epoch_stage1(epoch) + #train_loss = self.train_epoch(epoch) # 验证 ''' @@ -445,6 +526,50 @@ class Trainer: except Exception as e: logger.error(f"加载checkpoint失败: {str(e)}") raise + + # ... existing code ... + + def _save_octree(self): + """ + 保存八叉树到文件。 + 八叉树保存路径基于模型名称和配置中的检查点目录。 + """ + checkpoint_dir = os.path.join( + self.config.train.checkpoint_dir, + self.model_name + ) + octree_path = os.path.join(checkpoint_dir, "octree.pth") + + try: + # 保存八叉树的根节点 + torch.save(self.root, octree_path) + logger.info(f"八叉树已保存到 {octree_path}") + except Exception as e: + logger.error(f"保存八叉树失败: {str(e)}") + + def _load_octree(self)->bool: + """ + 从文件加载八叉树。 + 尝试从基于模型名称和配置检查点目录的路径加载八叉树。 + """ + checkpoint_dir = os.path.join( + self.config.train.checkpoint_dir, + self.model_name + ) + octree_path = os.path.join(checkpoint_dir, "octree.pth") + + try: + if os.path.exists(octree_path): + # 加载八叉树的根节点 + self.root = torch.load(octree_path, weights_only=False) + logger.info(f"八叉树已从 {octree_path} 加载") + return True + else: + logger.warning(f"八叉树文件 {octree_path} 不存在,无法加载。") + except Exception as e: + logger.error(f"加载八叉树失败: {str(e)}") + return False + def main(): # 这里需要初始化配置 config = get_default_config()