import torch
from torch.serialization import add_safe_globals
import torch.optim as optim
import time
import os
import numpy as np
import argparse

from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file, load_sdf_file
from brep2sdf.data.pre_process_by_mesh import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.utils.logger import logger

# Command-line arguments
parser = argparse.ArgumentParser(description='Batch processing tool for STEP files')
parser.add_argument('-i', '--input', required=True, help='path to the B-rep (.step) file to process')
parser.add_argument(
    '--use-normal',
    action='store_true',  # defaults to False; True when the flag is given
    help='force sampled points to carry normal vectors'
)
parser.add_argument(
    '--force-reprocess',
    action='store_true',  # defaults to False; True when the flag is given
    help='force data preprocessing to run again, ignoring cached results'
)
args = parser.parse_args()


def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
    """Flatten per-surface point lists into a single SDF sample array.

    Columns are [x, y, z, sdf] without normals, or [x, y, z, nx, ny, nz, sdf]
    with normals. The sdf column is left at 0 because all samples lie on the
    surface. If the total point count exceeds max_points, a random subset is kept.
    """
    total_points = sum(len(s) for s in surf_data)

    if total_points > max_points:
        # Downsampling: build (surface_index, point_index) pairs, shuffle, keep the first max_points
        indices = []
        for i, points in enumerate(surf_data):
            indices.extend([(i, j) for j in range(len(points))])
        np.random.shuffle(indices)
        selected_indices = indices[:max_points]

        if normals is None:
            sdf_array = np.zeros((max_points, 4), dtype=np.float32)
            for idx, (i, j) in enumerate(selected_indices):
                sdf_array[idx, :3] = surf_data[i][j]
        else:
            sdf_array = np.zeros((max_points, 7), dtype=np.float32)
            for idx, (i, j) in enumerate(selected_indices):
                sdf_array[idx, :3] = surf_data[i][j]
                sdf_array[idx, 3:6] = normals[i][j]
    else:
        if normals is None:
            sdf_array = np.zeros((total_points, 4), dtype=np.float32)
            sdf_array[:, :3] = np.concatenate(surf_data)
        else:
            sdf_array = np.zeros((total_points, 7), dtype=np.float32)
            sdf_array[:, :3] = np.concatenate(surf_data)
            sdf_array[:, 3:6] = np.concatenate(normals)

    return torch.tensor(sdf_array, dtype=torch.float32, device=device)


class Trainer:
    def __init__(self, config, input_step):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_name = os.path.basename(input_step).split('_')[0]
        self.base_name = self.model_name + ".xyz"
        data_path = os.path.join("/home/wch/brep2sdf/data/output_data", self.base_name)

        if os.path.exists(data_path) and not args.force_reprocess:
            try:
                self.data = load_brep_file(data_path)
            except Exception as e:
                logger.error(f"failed to load {data_path}, {str(e)}")
                raise e
            # Reprocess if normals were requested but the cached data has none
            if args.use_normal and self.data.get("surf_pnt_normals", None) is None:
                self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)
        else:
            self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal)

        # Convert the list of per-surface point clouds into one flat SDF sample array
        surfs = self.data["surf_ncs"]
        self.sdf_data = prepare_sdf_data(
            surfs,
            normals=self.data.get("surf_pnt_normals", None),
            max_points=4096,
            device=self.device
        )

        # Dataset initialization (kept for reference)
        #self.brep_data = load_brep_file(self.config.data.pkl_path)
        #logger.info( self.brep_data )
        #self.sdf_data = load_sdf_file(sdf_path=self.config.data.sdf_path, num_query_points=self.config.data.num_query_points).to(self.device)

        # Build the octree and the network
        surf_bbox = torch.tensor(
            self.data['surf_bbox_ncs'],
            dtype=torch.float32,
            device=self.device
        )
        self.build_tree(surf_bbox=surf_bbox, max_depth=4)

        self.model = Net(
            octree=self.root,
            feature_dim=64
        ).to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.train.learning_rate,
            weight_decay=config.train.weight_decay
        )

        self.loss_manager = LossManager(ablation="none")
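    # Layout of self.sdf_data, as a reference sketch (it mirrors prepare_sdf_data
    # above and the slicing in train_epoch below; the sdf column stays 0 because
    # all samples lie on the surface):
    #   with --use-normal:    [:, 0:3] = xyz, [:, 3:6] = normal, [:, 6] = sdf
    #   without --use-normal: [:, 0:3] = xyz, [:, 3] = sdf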
    def build_tree(self, surf_bbox, max_depth=6):
        num_faces = surf_bbox.shape[0]
        bbox = self._calculate_global_bbox(surf_bbox)
        self.root = OctreeNode(
            bbox=bbox,
            face_indices=np.arange(num_faces),  # initially contains every face
            max_depth=max_depth,
            surf_bbox=surf_bbox
        )
        #print(surf_bbox)
        logger.info("starting octree construction")
        self.root.conduct_tree()
        logger.info("complete octree construction")
        #self.root.print_tree(0)

    def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
        """
        Compute the global bounding box of the dataset from the per-surface boxes.

        Args:
            surf_bbox: Tensor of shape (num_surfaces, 6); each row is a surface's
                bounding box [xmin, ymin, zmin, xmax, ymax, zmax]

        Returns:
            Tensor of shape (6,) in the format [x_min, y_min, z_min, x_max, y_max, z_max]
        """
        # Validate input
        if not isinstance(surf_bbox, torch.Tensor):
            raise TypeError(f"surf_bbox must be a torch.Tensor, got {type(surf_bbox)}")
        if surf_bbox.dim() != 2 or surf_bbox.shape[1] != 6:
            raise ValueError(f"surf_bbox should have shape (num_surfaces, 6), got {surf_bbox.shape}")

        # Global extent over all surface bounding boxes
        global_min = surf_bbox[:, :3].min(dim=0).values
        global_max = surf_bbox[:, 3:].max(dim=0).values

        # Merged bounding box
        return torch.cat([global_min, global_max])

    def train_epoch(self, epoch: int) -> float:
        self.model.train()
        total_loss = 0.0

        # The data already lives on the device; slice out points, normals and ground truth
        points = self.sdf_data[:, 0:3]
        points.requires_grad_(True)
        if args.use_normal:
            normals = self.sdf_data[:, 3:6]
            gt_sdf = self.sdf_data[:, 6]
        else:
            gt_sdf = self.sdf_data[:, 3]

        # Forward pass
        self.optimizer.zero_grad()
        pred_sdf = self.model(points)

        # Compute the loss
        if args.use_normal:
            loss, loss_details = self.loss_manager.compute_loss(
                points,
                normals,
                gt_sdf,
                pred_sdf
            )
        else:
            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

        # Backward pass and optimization step
        loss.backward()
        self.optimizer.step()

        total_loss += loss.item()

        # Log training progress
        logger.info(f'Train Epoch: {epoch:4d}\t'
                    f'Loss: {loss.item():.6f}')

        return total_loss

    def validate(self, epoch: int) -> float:
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in self.val_loader:
                points = batch['points'].to(self.device)
                gt_sdf = batch['sdf'].to(self.device)

                pred_sdf = self.model(points)
                loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
                total_loss += loss.item()

        avg_loss = total_loss / len(self.val_loader)
        logger.info(f'Validation Epoch: {epoch}\tAverage Loss: {avg_loss:.6f}')
        return avg_loss

    def train(self):
        best_val_loss = float('inf')
        logger.info("Starting training...")
        start_time = time.time()

        for epoch in range(1, self.config.train.num_epochs + 1):
            # Train for one epoch
            train_loss = self.train_epoch(epoch)

            # Validation (currently disabled)
            '''
            if epoch % self.config.train.val_freq == 0:
                val_loss = self.validate(epoch)
                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')

                # Save the best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self._save_model(epoch, val_loss)
                    logger.info(f'New best model saved at epoch {epoch} with val loss {val_loss:.6f}')
            else:
                logger.info(f'Epoch {epoch}: Train Loss = {train_loss:.6f}')
            '''

            # Save a checkpoint
            if epoch % self.config.train.save_freq == 0:
                self._save_checkpoint(epoch, train_loss)
                logger.info(f'Checkpoint saved at epoch {epoch}')

        # Training complete
        total_time = time.time() - start_time
        logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
        logger.info(f'Best validation loss: {best_val_loss:.6f}')
        self._tracing_model()
        #self.test_load()

    def test_load(self):
        model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
        model.eval()
        logger.debug(model)
        example_input = torch.rand(10, 3, device=self.device)
        #logger.debug(model.encoder.octree.bbox)
        logger.debug(f"points: {example_input}")
        sdfs = model(example_input)
        logger.debug(f"sdfs:{sdfs}")
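    # Minimal usage sketch for the exported model (an assumed workflow, not
    # exercised by this script): the TorchScript file written by _tracing_model
    # below can be reloaded and queried without the Python class definitions, e.g.
    #   scripted = torch.jit.load(f"/home/wch/brep2sdf/data/output_data/{model_name}.pt")
    #   sdf = scripted(torch.rand(10, 3))  # (N, 3) query points -> predicted SDF values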
    def _tracing_model(self):
        """Export the trained model as a TorchScript file."""
        self.model.eval()
        # All logic in the model must be TorchScript-compatible
        scripted_model = torch.jit.script(self.model)
        torch.jit.save(scripted_model, f"/home/wch/brep2sdf/data/output_data/{self.model_name}.pt")

    def _load_checkpoint(self, checkpoint_path):
        """Restore a model from a saved checkpoint."""
        model = torch.load(checkpoint_path)
        return model

    def _save_checkpoint(self, epoch: int, train_loss: float):
        """Save a training checkpoint."""
        checkpoint_dir = os.path.join(
            self.config.train.checkpoint_dir,
            self.model_name
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch:03d}.pth")
        '''
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': train_loss,
            'config': self.config
        }, checkpoint_path)
        '''
        torch.save(self.model, checkpoint_path)


def main():
    # Initialize the configuration
    config = get_default_config()

    # Build the trainer and start training
    trainer = Trainer(config, input_step=args.input)
    trainer.train()


if __name__ == '__main__':
    main()
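# Example invocation (the script name and .step path are illustrative only):
#   python train.py -i /path/to/00000054_solid.step --use-normal
# The model name is taken from the part of the file name before the first '_',
# so the example above would train and export "00000054.pt".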