import torch import torch.optim as optim import time import os import numpy as np import argparse from brep2sdf.config.default_config import get_default_config from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor from brep2sdf.data.pre_process_by_mesh import process_single_step from brep2sdf.networks.network import Net from brep2sdf.networks.octree import OctreeNode from brep2sdf.networks.loss import LossManager from brep2sdf.networks.patch_graph import PatchGraph from brep2sdf.networks.sample import NormalPerPoint from brep2sdf.networks.learning_rate import LearningRateScheduler from brep2sdf.utils.logger import logger # 配置命令行参数 parser = argparse.ArgumentParser(description='STEP文件批量处理工具') parser.add_argument('-i', '--input', required=True, help='待处理 brep (.step) 路径') parser.add_argument( '--use-normal', action='store_true', # 默认为 False,如果用户指定该参数,则为 True help='强制采样点有法向量' ) parser.add_argument( '--only-zero-surface', action='store_true', # 默认为 False,如果用户指定该参数,则为 True help='只采样零表面点 SDF 训练' ) parser.add_argument( '--force-reprocess','-f', action='store_true', # 默认为 False,如果用户指定该参数,则为 True help='强制重新进行数据预处理,忽略缓存或已有结果' ) parser.add_argument( '--resume-checkpoint-path', type=str, default=None, help='从指定的checkpoint文件继续训练' ) parser.add_argument( '--octree-cuda', action='store_true', # 默认为 False,如果用户指定该参数,则为 True help='使用CUDA加速Octree构建' ) args = parser.parse_args() class Trainer: def __init__(self, config, input_step): logger.gpu_memory_stats("初始化开始") self.config = config self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.debug_mode = config.train.debug_mode self.model_name = os.path.basename(input_step).split('_')[0] self.base_name = self.model_name + ".xyz" data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name) if os.path.exists(data_path) and not args.force_reprocess: try: self.data = load_brep_file(data_path) except Exception as e: logger.error(f"fail to load {data_path}, {str(e)}") raise e if args.use_normal and self.data.get("surf_pnt_normals", None) is None: self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal,sample_sdf_points=not args.only_zero_surface) else: self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal,sample_sdf_points=not args.only_zero_surface) logger.gpu_memory_stats("数据预处理后") self.train_surf_ncs = torch.tensor(self.data["train_surf_ncs"],dtype=torch.float32,device=self.device) # # 将曲面点云列表转换为 (N*M, 4) 数组 surfs = self.data["surf_ncs"] # 准备表面点的SDF数据 surface_sdf_data = prepare_sdf_data( surfs, normals=self.data["surf_pnt_normals"], max_points=50000, device=self.device ) # 如果不是仅使用零表面,则合并采样点数据 if not args.only_zero_surface: # 加载采样点数据 sampled_sdf_data = torch.tensor( self.data['sampled_points_normals_sdf'], dtype=torch.float32, device=self.device ) # 合并表面点数据和采样点数据 self.sdf_data = torch.cat([surface_sdf_data, sampled_sdf_data], dim=0) else: self.sdf_data = surface_sdf_data print_data_distribution(self.sdf_data) logger.debug(self.sdf_data.shape) logger.gpu_memory_stats("SDF数据准备后") # 初始化数据集 #self.brep_data = load_brep_file(self.config.data.pkl_path) #logger.info( self.brep_data ) #self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device) # 构建面片邻接图 graph = PatchGraph.from_preprocessed_data( surf_ncs=self.data['surf_ncs'], edgeFace_adj=self.data['edgeFace_adj'], edge_types=self.data['edge_types'], device='cuda' if args.octree_cuda else 'cpu' ) # 初始化网络 surf_bbox=torch.tensor( self.data['surf_bbox_ncs'], dtype=torch.float32, device=self.device ) max_depth = config.model.octree_max_depth if not args.force_reprocess: if not self._load_octree(): self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth) elif self.root.max_depth != max_depth: self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth) else: self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth) logger.gpu_memory_stats("树初始化后") self.model = Net( octree=self.root, volume_bboxs=surf_bbox, feature_dim=64 ).to(self.device) logger.gpu_memory_stats("模型初始化后") self.scheduler = LearningRateScheduler(config.train.learning_rate_schedule, config.train.weight_decay, self.model.parameters()) self.loss_manager = LossManager(ablation="none") logger.gpu_memory_stats("训练器初始化后") self.sampler = NormalPerPoint( global_sigma=0.1, # 全局采样标准差 local_sigma=0.01 # 局部采样标准差 ) logger.info(f"初始化完成,正在处理模型 {self.model_name}") def build_tree(self,surf_bbox, graph, max_depth=9): logger.info("开始构造八叉树...") num_faces = surf_bbox.shape[0] bbox = self._calculate_global_bbox(surf_bbox) self.root = OctreeNode( bbox=bbox, face_indices=np.arange(num_faces), # 初始包含所有面 patch_graph=graph, max_depth=max_depth, surf_bbox=surf_bbox, surf_ncs=self.data['surf_ncs'] ) #print(surf_bbox) logger.info("starting octree conduction") self.root.build_static_tree() logger.info("complete octree conduction") self.root.print_tree() self._save_octree() def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor: """ 返回一个固定的全局边界框(单位立方体)。 参数: surf_bbox: (此参数在此实现中未使用) 返回: 形状为 (6,) 的Tensor,表示固定的边界框 [-0.5, -0.5, -0.5, 0.5, 0.5, 0.5] """ # 直接定义固定的单位立方体边界框 # 注意:确保张量在正确的设备上创建 fixed_bbox = torch.tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5], dtype=torch.float32) # 假设 self.device 存储了目标设备 logger.debug(f"使用固定的全局边界框: {fixed_bbox.cpu().numpy()}") return fixed_bbox # --- 旧的计算逻辑 (注释掉或删除) --- # # 验证输入 # if not isinstance(surf_bbox, torch.Tensor): # raise TypeError(f"surf_bbox 必须是 torch.Tensor,但得到 {type(surf_bbox)}") # if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6: # raise ValueError(f"surf_bbox 形状应为 (num_edges, 6),但得到 {surf_bbox.shape}") # # # 计算表面包围盒的全局范围 # global_min = surf_bbox[:, :3].min(dim=0).values # global_max = surf_bbox[:, 3:].max(dim=0).values # # # 返回合并后的边界框 # return torch.cat([global_min, global_max]) # return [-0.5,] # 这个是错误的 def train_epoch_stage1_(self, epoch: int): total_loss = 0.0 total_loss_details = { "manifold": 0.0, "normals": 0.0, "eikonal": 0.0, "offsurface": 0.0 } accumulated_loss = 0.0 # 新增:用于累积多个step的loss # 新增:在每个epoch开始时清零梯度 self.optimizer.zero_grad() for step, surf_points in enumerate(self.data['surf_ncs']): mnfld_points = torch.tensor(surf_points, device=self.device) nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0) # 生成非流形点 gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device) normals = None if args.use_normal: normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) logger.debug(normals) # --- 准备模型输入,启用梯度 --- mnfld_points.requires_grad_(True) # 在检查之后启用梯度 nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- self.optimizer.zero_grad() mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) if self.debug_mode: # --- 检查前向传播的输出 --- logger.print_tensor_stats("mnfld_pred",mnfld_pred) logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) logger.gpu_memory_stats("前向传播后") # --- 计算损失 --- try: if args.use_normal: loss, loss_details = self.loss_manager.compute_loss( mnfld_points, nonmnfld_pnts, normals, gt_sdf, mnfld_pred, nonmnfld_pred ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) loss_details = {} # 确保变量初始化 # 修改:累积loss而不是立即backward accumulated_loss += loss / self.config.train.accumulation_steps # 假设配置中有accumulation_steps current_loss = loss.item() total_loss += current_loss for key in total_loss_details: if key in loss_details: total_loss_details[key] += loss_details[key].item() # 新增:达到累积步数时执行反向传播 if (step + 1) % self.config.train.accumulation_steps == 0: # 反向传播 self.scheduler.optimizer.zero_grad() # 清空梯度 loss.backward() # 反向传播 self.scheduler.optimizer.step() # 更新参数 self.scheduler.step(accumulated_loss,epoch) # 记录日志保持不变 ... except Exception as loss_e: logger.error(f"Error in step {step}: {loss_e}") continue # --- 内存管理 --- del loss torch.cuda.empty_cache() # 新增:处理最后未达到累积步数的剩余loss if accumulated_loss != 0: # 反向传播 self.scheduler.optimizer.zero_grad() # 清空梯度 loss.backward() # 反向传播 self.scheduler.optimizer.step() # 更新参数 self.scheduler.step(accumulated_loss,epoch) # 计算并记录epoch损失 logger.info(f'Train Epoch: {epoch:4d}]\t' f'Loss: {total_loss:.6f}') logger.info(f"Loss Details: {total_loss_details}") return total_loss # 返回平均损失而非累计值 def train_epoch_stage1(self, epoch: int) -> float: # --- 1. 检查输入数据 --- # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 并且 SDF 值总是在最后一列 if self.train_surf_ncs is None: logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.") return float('inf') self.model.train() total_loss = 0.0 step = 0 # 如果你的训练是分批次的,这里应该用批次索引 batch_size = 8192 # 设置合适的batch大小 # 数据处理 # manfld _mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # 取前3列作为点 _normals = self.train_surf_ncs[:, 3:6].clone().detach() # 取中间3列作为法线 _gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # 取最后一列作为SDF真值 # 检查是否需要重新计算缓存 if epoch % 10 == 1 or self.cached_train_data is None: # 计算流形点的掩码和操作符 # 生成非流形点 _psdf_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals,local_sigma=0.001) _nonmnfld_pnts = self.sampler.get_points(_mnfld_pnts, local_sigma=0.01): # 更新缓存 self.cached_train_data = { "nonmnfld_pnts": _nonmnfld_pnts, "psdf_pnts": _psdf_pnts, "psdf": _psdf, } else: # 从缓存中读取数据 _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] _psdf_pnts = self.cached_train_data["psdf_pnts"] _psdf = self.cached_train_data["psdf"] logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) # 将数据分成多个batch num_points = self.train_surf_ncs.shape[0] num_batches = (num_points + batch_size - 1) // batch_size for batch_idx in range(num_batches): start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, num_points) # 获取当前batch的数据 mnfld_pnts = _mnfld_pnts[start_idx:end_idx] # 流形点 gt_sdf = _gt_sdf[start_idx:end_idx] # SDF真值 normals = _normals[start_idx:end_idx] if args.use_normal else None # 法线 # 非流形点使用缓存数据(整个batch共享) nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] psdf_pnts = _psdf_pnts[start_idx:end_idx] psdf = _psdf[start_idx:end_idx] # --- 准备模型输入,启用梯度 --- mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 psdf_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- mnfld_pred = self.model.forward_background( mnfld_pnts ) nonmnfld_pred = self.model.forward_background( nonmnfld_pnts ) psdf_pred = self.model.forward_background( psdf_pnts ) # --- 计算损失 --- loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 loss_details = {} try: # --- 3. 检查损失计算前的输入 --- # (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf) #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss") #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss") if args.use_normal: # 检查法线和带梯度的点 #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") logger.gpu_memory_stats("计算损失前") loss, loss_details = self.loss_manager.compute_loss( mnfld_pnts, nonmnfld_pnts, psdf_pnts, normals, # 传递检查过的 normals gt_sdf, mnfld_pred, nonmnfld_pred, psdf ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) # --- 4. 检查损失计算结果 --- if self.debug_mode: logger.print_tensor_stats("psdf",psdf) logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) if check_tensor(loss, "Calculated Loss", epoch, step): logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") if loss_details: logger.error(f"Loss Details: {loss_details}") return float('inf') # 如果损失无效,停止这个epoch except Exception as loss_e: logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) return float('inf') # 如果计算出错,停止这个epoch logger.gpu_memory_stats("损失计算后") # --- 反向传播和优化 --- try: # 反向传播 self.scheduler.optimizer.zero_grad() # 清空梯度 loss.backward() # 反向传播 self.scheduler.optimizer.step() # 更新参数 self.scheduler.step(loss,epoch) except Exception as backward_e: logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True) # 如果你想看是哪个操作导致的,可以启用 anomaly detection # torch.autograd.set_detect_anomaly(True) # 放在训练开始前 return float('inf') # 如果反向传播或优化出错,停止这个epoch # --- 记录和累加损失 --- current_loss = loss.item() if not np.isfinite(current_loss): # 再次确认损失是有效的数值 logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).") return float('inf') total_loss += current_loss del loss torch.cuda.empty_cache() if epoch % 100 == 0: # 记录训练进度 (只记录有效的损失) logger.info(f'Train Epoch: {epoch:4d}]\t' f'Loss: {current_loss:.6f}') if loss_details: logger.info(f"Loss Details: {loss_details}") return total_loss # 对于单批次训练,直接返回当前损失 def train_epoch_stage2(self, epoch: int) -> float: # --- 1. 检查输入数据 --- # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 并且 SDF 值总是在最后一列 if self.sdf_data is None: logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") return float('inf') self.model.train() total_loss = 0.0 step = 0 # 如果你的训练是分批次的,这里应该用批次索引 batch_size = 8192 * 2 # 设置合适的batch大小 # 数据处理 # manfld _mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点 _normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线 _gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值 # 检查是否需要重新计算缓存 if epoch % 10 == 1 or self.cached_train_data is None: # 计算流形点的掩码和操作符 _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts) # 生成非流形点 _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals) _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts) # 更新缓存 self.cached_train_data = { "mnfld_face_indices_mask": _mnfld_face_indices_mask, "mnfld_operator": _mnfld_operator, "nonmnfld_pnts": _nonmnfld_pnts, "psdf": _psdf, "nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask, "nonmnfld_operator": _nonmnfld_operator } else: # 从缓存中读取数据 _mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"] _mnfld_operator = self.cached_train_data["mnfld_operator"] _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"] _psdf = self.cached_train_data["psdf"] _nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"] _nonmnfld_operator = self.cached_train_data["nonmnfld_operator"] logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) # 将数据分成多个batch num_points = self.sdf_data.shape[0] num_batches = (num_points + batch_size - 1) // batch_size for batch_idx in range(num_batches): start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, num_points) # 获取当前batch的数据 mnfld_pnts = _mnfld_pnts[start_idx:end_idx] # 流形点 gt_sdf = _gt_sdf[start_idx:end_idx] # SDF真值 normals = _normals[start_idx:end_idx] if args.use_normal else None # 法线 # 非流形点使用缓存数据(整个batch共享) nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx] psdf = _psdf[start_idx:end_idx] # --- 准备模型输入,启用梯度 --- mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- mnfld_pred = self.model.forward_without_octree( mnfld_pnts, _mnfld_face_indices_mask[start_idx:end_idx], _mnfld_operator[start_idx:end_idx] ) nonmnfld_pred = self.model.forward_without_octree( nonmnfld_pnts, _nonmnfld_face_indices_mask[start_idx:end_idx], _nonmnfld_operator[start_idx:end_idx] ) #logger.print_tensor_stats("psdf",psdf) #logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) # --- 计算损失 --- loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 loss_details = {} try: # --- 3. 检查损失计算前的输入 --- # (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf) #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss") #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss") if args.use_normal: # 检查法线和带梯度的点 #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") logger.gpu_memory_stats("计算损失前") loss, loss_details = self.loss_manager.compute_loss( mnfld_pnts, nonmnfld_pnts, normals, # 传递检查过的 normals gt_sdf, mnfld_pred, nonmnfld_pred, psdf ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) # --- 4. 检查损失计算结果 --- if self.debug_mode: if check_tensor(loss, "Calculated Loss", epoch, step): logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") if loss_details: logger.error(f"Loss Details: {loss_details}") return float('inf') # 如果损失无效,停止这个epoch except Exception as loss_e: logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) return float('inf') # 如果计算出错,停止这个epoch logger.gpu_memory_stats("损失计算后") # --- 反向传播和优化 --- try: # 反向传播 self.scheduler.optimizer.zero_grad() # 清空梯度 loss.backward() # 反向传播 self.scheduler.optimizer.step() # 更新参数 self.scheduler.step(loss,epoch) except Exception as backward_e: logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True) # 如果你想看是哪个操作导致的,可以启用 anomaly detection # torch.autograd.set_detect_anomaly(True) # 放在训练开始前 return float('inf') # 如果反向传播或优化出错,停止这个epoch # --- 记录和累加损失 --- current_loss = loss.item() if not np.isfinite(current_loss): # 再次确认损失是有效的数值 logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).") return float('inf') total_loss += current_loss del loss torch.cuda.empty_cache() # 记录训练进度 (只记录有效的损失) logger.info(f'Train Epoch: {epoch:4d}]\t' f'Loss: {current_loss:.6f}') if loss_details: logger.info(f"Loss Details: {loss_details}") return total_loss # 对于单批次训练,直接返回当前损失 def train_epoch(self, epoch: int,resample:bool=True) -> float: # --- 1. 检查输入数据 --- # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 并且 SDF 值总是在最后一列 if self.sdf_data is None: logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") return float('inf') self.model.train() total_loss = 0.0 step = 0 # 如果你的训练是分批次的,这里应该用批次索引 batch_size = 8192 # 设置合适的batch大小 # 将数据分成多个batch num_points = self.sdf_data.shape[0] num_batches = (num_points + batch_size - 1) // batch_size for batch_idx in range(num_batches): start_idx = batch_idx * batch_size end_idx = min((batch_idx + 1) * batch_size, num_points) mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 normals = None if args.use_normal: if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线 logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") return float('inf') normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线 nonmnfld_pnts,psdf = self.sampler.get_norm_points(mnfld_pnts,normals) # 生成非流形点 logger.debug((mnfld_pnts,nonmnfld_pnts,psdf)) else: nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点 # 执行检查 if self.debug_mode: if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') if args.use_normal: # 只有在请求法线时才检查 normals if check_tensor(normals, "Input Normals", epoch, step): return float('inf') logger.debug(normals) logger.print_tensor_stats("normals-x",normals[0]) logger.print_tensor_stats("normals-y",normals[1]) logger.print_tensor_stats("normals-z",normals[2]) # --- 准备模型输入,启用梯度 --- mnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- mnfld_pred = self.model(mnfld_pnts) nonmnfld_pred = self.model(nonmnfld_pnts) if self.debug_mode: # --- 检查前向传播的输出 --- logger.print_tensor_stats("mnfld_pred",mnfld_pred) logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred) logger.gpu_memory_stats("前向传播后") # --- 2. 检查模型输出 --- #if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf') # --- 计算损失 --- loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 loss_details = {} try: # --- 3. 检查损失计算前的输入 --- # (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf) #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss") #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss") if args.use_normal: # 检查法线和带梯度的点 #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") logger.gpu_memory_stats("计算损失前") loss, loss_details = self.loss_manager.compute_loss( mnfld_pnts, nonmnfld_pnts, normals, # 传递检查过的 normals gt_sdf, mnfld_pred, nonmnfld_pred, psdf ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) # --- 4. 检查损失计算结果 --- if self.debug_mode: if check_tensor(loss, "Calculated Loss", epoch, step): logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") if loss_details: logger.error(f"Loss Details: {loss_details}") return float('inf') # 如果损失无效,停止这个epoch except Exception as loss_e: logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) return float('inf') # 如果计算出错,停止这个epoch logger.gpu_memory_stats("损失计算后") # --- 反向传播和优化 --- try: # 反向传播 self.scheduler.optimizer.zero_grad() # 清空梯度 loss.backward() # 反向传播 self.scheduler.optimizer.step() # 更新参数 self.scheduler.step(loss,epoch) except Exception as backward_e: logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True) # 如果你想看是哪个操作导致的,可以启用 anomaly detection # torch.autograd.set_detect_anomaly(True) # 放在训练开始前 return float('inf') # 如果反向传播或优化出错,停止这个epoch # --- 记录和累加损失 --- current_loss = loss.item() if not np.isfinite(current_loss): # 再次确认损失是有效的数值 logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).") return float('inf') total_loss += current_loss del loss torch.cuda.empty_cache() # 记录训练进度 (只记录有效的损失) logger.info(f'Train Epoch: {epoch:4d}]\t' f'Loss: {current_loss:.6f}') if loss_details: logger.info(f"Loss Details: {loss_details}") return total_loss # 对于单批次训练,直接返回当前损失 def validate(self, epoch: int) -> float: self.model.eval() total_loss = 0.0 with torch.no_grad(): for batch in self.val_loader: points = batch['points'].to(self.device) gt_sdf = batch['sdf'].to(self.device) pred_sdf = self.model(points) loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) total_loss += loss.item() avg_loss = total_loss / len(self.val_loader) logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}') return avg_loss def train(self): best_val_loss = float('inf') logger.info("Starting training...") start_time = time.time() self.cached_train_data=None start_epoch = 1 if args.resume_checkpoint_path: start_epoch = self._load_checkpoint(args.resume_checkpoint_path) logger.info(f"Loaded model from {args.resume_checkpoint_path}") self.model.encoder.freeze_stage1() for epoch in range(start_epoch, self.config.train.num_epochs + 1): # 训练一个epoch train_loss = self.train_epoch_stage1(epoch) #train_loss = self.train_epoch_stage2(epoch) #train_loss = self.train_epoch(epoch) # 验证 ''' if epoch % self.config.train.val_freq == 0: val_loss = self.validate(epoch) logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}') # 保存最佳模型 if val_loss < best_val_loss: best_val_loss = val_loss self._save_model(epoch, val_loss) logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}') else: logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}') ''' # 保存检查点 if epoch % self.config.train.save_freq == 0: self._save_checkpoint(epoch, train_loss) logger.info(f'Checkpoint saved at epoch {epoch}') self.model.encoder.unfreeze() # 训练完成 total_time = time.time() - start_time self._tracing_model_by_script() #self._tracing_model() logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s') logger.info(f'Best validation loss: {best_val_loss:.6f}') #self.test_load() def test_load(self): model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt") model.eval() logger.debug(model) example_input = torch.rand(10, 3, device=self.device) #logger.debug(model.encoder.octree.bbox) logger.debug(f"points: {example_input}") sdfs= model(example_input) logger.debug(f"sdfs:{sdfs}") def _tracing_model_by_script(self): """保存模型""" self.model.eval() # 确保模型中的所有逻辑都兼容 TorchScript scripted_model = torch.jit.script(self.model) #optimized_model = optimize_for_mobile(scripted_model) torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt") def _tracing_model(self): """保存模型""" self.model.eval() # 创建示例输入 example_input = torch.rand(1, 3, device=self.device) # 使用 trace 方式导出模型 traced_model = torch.jit.trace(self.model, example_input) # 保存模型 save_path = f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt" torch.jit.save(traced_model, save_path) # 验证保存的模型 try: loaded_model = torch.jit.load(save_path) test_input = torch.rand(1, 3, device=self.device) _ = loaded_model(test_input) logger.info(f"模型已保存并验证成功:{save_path}") except Exception as e: logger.error(f"模型验证失败:{e}") def _save_checkpoint(self, epoch: int, train_loss: float): """保存训练检查点""" checkpoint_dir = os.path.join( self.config.train.checkpoint_dir, self.model_name ) os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth") # 只保存状态字典 torch.save({ 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.scheduler.optimizer.state_dict(), 'loss': train_loss, }, checkpoint_path) def _load_checkpoint(self, checkpoint_path): """从检查点恢复训练状态""" try: checkpoint = torch.load(checkpoint_path) self.model.load_state_dict(checkpoint['model_state_dict']) self.scheduler.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) return checkpoint['epoch'] + 1 except Exception as e: logger.error(f"加载checkpoint失败: {str(e)}") raise def _save_octree(self): """ 保存八叉树到文件。 八叉树保存路径基于模型名称和配置中的检查点目录。 """ checkpoint_dir = os.path.join( self.config.train.checkpoint_dir, self.model_name ) os.makedirs(checkpoint_dir, exist_ok=True) octree_path = os.path.join(checkpoint_dir, "octree.pth") try: # 保存八叉树的根节点 torch.save(self.root, octree_path) logger.info(f"八叉树已保存到 {octree_path}") except Exception as e: logger.error(f"保存八叉树失败: {str(e)}") def _load_octree(self)->bool: """ 从文件加载八叉树。 尝试从基于模型名称和配置检查点目录的路径加载八叉树。 """ checkpoint_dir = os.path.join( self.config.train.checkpoint_dir, self.model_name ) octree_path = os.path.join(checkpoint_dir, "octree.pth") try: if os.path.exists(octree_path): # 加载八叉树的根节点 self.root = torch.load(octree_path, weights_only=False) logger.info(f"八叉树已从 {octree_path} 加载") return True else: logger.warning(f"八叉树文件 {octree_path} 不存在,无法加载。") except Exception as e: logger.error(f"加载八叉树失败: {str(e)}") return False def main(): # 这里需要初始化配置 config = get_default_config() # 初始化训练器并开始训练 trainer = Trainer(config, input_step=args.input) trainer.train() if __name__ == '__main__': main()