diff --git a/brep2sdf/batch_train.py b/brep2sdf/batch_train.py
index 86d2678..94cb2e1 100644
--- a/brep2sdf/batch_train.py
+++ b/brep2sdf/batch_train.py
@@ -70,7 +70,8 @@ def batch_train(args):
     common_train_args = [
         "--use-normal",
         "--only-zero-surface",
-        #"--force-reprocess",
+        "--octree-cuda",
+        "--force-reprocess",
        # more shared arguments can be added here
     ]
     if args.train_args:
@@ -249,8 +250,8 @@ def batch_Iso(args):
     logger.info(f"Processing finished. Succeeded: {success_count}, failed: {failure_count}")

 def main(args):
-    #batch_train(args)
-    batch_Iso(args)
+    batch_train(args)
+    #batch_Iso(args)


 if __name__ == '__main__':
diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py
index 0874893..2d0c074 100644
--- a/brep2sdf/networks/encoder.py
+++ b/brep2sdf/networks/encoder.py
@@ -29,10 +29,17 @@ class Encoder(nn.Module):
             )
             for i, bbox in enumerate(volume_bboxs)
         ])
-        self.background = PatchFeatureVolume(
-            bbox=torch.Tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]),  # corrected normalized bbox
-            resolution=int(resolutions.max()) * 2,
-            feature_dim=feature_dim
+        self.background = nn.Sequential(
+            nn.Linear(3, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(),
+            nn.Linear(256, 512),
+            nn.BatchNorm1d(512),
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(),
+            nn.Linear(256, feature_dim)
         )
         print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
         print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB")
@@ -87,6 +94,20 @@ class Encoder(nn.Module):
             all_features[mask, vol_id] = 0.7 * features + 0.3 * background_features[mask]

         return all_features
+
+
+    def forward_background(self, query_points: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass that evaluates only the background network.
+
+        Args:
+            query_points: query point coordinates (B, 3)
+
+        Returns:
+            feature tensor (B, D)
+        """
+        background_features = self.background.forward(query_points)  # (B, D)
+        return background_features

     @torch.jit.export
     def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor:
@@ -138,3 +159,24 @@ class Encoder(nn.Module):
                 _move_node(child)
         _move_node(self.octree.root)
         return self
+
+    def freeze_stage1(self):
+        for volume in self.feature_volumes:
+            for param in volume.parameters():
+                param.requires_grad = False
+        for param in self.background.parameters():
+            param.requires_grad = False
+
+    def freeze_stage2(self):
+        for volume in self.feature_volumes:
+            for param in volume.parameters():
+                param.requires_grad = True
+        for param in self.background.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for volume in self.feature_volumes:
+            for param in volume.parameters():
+                param.requires_grad = True
+        for param in self.background.parameters():
+            param.requires_grad = True
\ No newline at end of file
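
For orientation, a minimal sketch of how these freeze helpers could drive the two training stages; `model`, `trainer`, and the epoch counts are assumptions for illustration, not part of this patch:

    # Hypothetical two-stage schedule built on the helpers above.
    # Stage 1 freezes the whole encoder so only the decoder trains;
    # stage 2 opens the per-patch volumes but keeps the background MLP fixed.
    model.encoder.freeze_stage1()
    for epoch in range(1, stage1_epochs + 1):
        trainer.train_epoch_stage1(epoch)

    model.encoder.freeze_stage2()
    for epoch in range(1, stage2_epochs + 1):
        trainer.train_epoch_stage2(epoch)

    model.encoder.unfreeze()  # optional joint fine-tuning afterwards
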
diff --git a/brep2sdf/networks/feature_volume.py b/brep2sdf/networks/feature_volume.py
index 7257a14..8b13df2 100644
--- a/brep2sdf/networks/feature_volume.py
+++ b/brep2sdf/networks/feature_volume.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn

 class PatchFeatureVolume(nn.Module):
-    def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=64, padding_ratio=0.05):
+    def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=8, padding_ratio=0.05):
         super(PatchFeatureVolume, self).__init__()
         # Convert the input bbox to [min, max] format
         self.resolution = resolution
@@ -19,8 +19,9 @@ class PatchFeatureVolume(nn.Module):
         grid_x, grid_y, grid_z = torch.meshgrid(x, y, z)
         self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1))

-        # Initialize the feature vectors
-        self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim))
+        # Initialize the feature vectors close to zero, using a small standard deviation
+        self.feature_volume = nn.Parameter(torch.empty(resolution, resolution, resolution, feature_dim))
+        torch.nn.init.normal_(self.feature_volume, mean=0.0, std=0.01)  # std 0.01; adjust as needed

     def _expand_bbox(self, min_coords, max_coords, ratio):
         # Expand the bounding box
diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py
index 9c48abc..1f282e1 100644
--- a/brep2sdf/networks/loss.py
+++ b/brep2sdf/networks/loss.py
@@ -5,14 +5,15 @@ from brep2sdf.utils.logger import logger
 class LossManager:
     def __init__(self, ablation, **condition_kwargs):
         self.weights = {
-            "manifold": 10,
+            "manifold": 1,
             "feature_manifold": 1,  # same weight as manifold in the original paper
             "normals": 1,
             "eikonal": 1,
             "offsurface": 1,
             "consistency": 1,
             "correction": 1,
-            "psdf": 10
+            "psdf": 1,
+            "psdf_sign_loss": 0
         }
         self.condition_kwargs = condition_kwargs
         self.ablation = ablation  # for ablation studies
@@ -111,6 +112,21 @@ class LossManager:
         correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean()  # correction loss

         return correction_loss
+
+    def psdf_loss(self, pred_sdfs, gt_sdfs):
+        # Weights for matching vs. mismatching signs
+        weight_same_sign = 1.0        # weight when the signs agree
+        weight_different_sign = 10.0  # weight when the signs disagree
+
+        # Check whether the predicted and ground-truth signs agree
+        same_sign = (pred_sdfs * gt_sdfs) >= 0
+
+        # Select the per-point weight based on sign agreement
+        weights = torch.where(same_sign, weight_same_sign, weight_different_sign)
+
+        squared_diff = torch.pow(pred_sdfs - gt_sdfs, 2)
+        weighted_squared_diff = weights * squared_diff
+        return torch.mean(weighted_squared_diff)

     def compute_loss(self,
                      mnfld_pnts,
@@ -150,7 +166,7 @@ class LossManager:
         # Compute the correction loss
         #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)

-        psdf_loss = self.position_loss(nonmnfld_pred, psdfs)
+        psdf_loss = self.psdf_loss(nonmnfld_pred, psdfs)

         # Aggregate the losses
         loss_details = {
@@ -158,7 +174,7 @@ class LossManager:
             "normals": self.weights["normals"] * normals_loss,
             "eikonal": self.weights["eikonal"] * eikonal_loss,
             "offsurface": self.weights["offsurface"] * offsurface_loss,
-            "psdf": self.weights["psdf"] * psdf_loss
+            "psdf": self.weights["psdf"] * psdf_loss,
         }

         # Compute the total loss
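
A quick standalone check of the new sign-weighted `psdf_loss`; the tensor values are illustrative only:

    import torch

    pred = torch.tensor([0.10, -0.05])
    gt   = torch.tensor([0.12,  0.05])   # the second point has the wrong sign

    same_sign = (pred * gt) >= 0                                   # [True, False]
    weights = torch.where(same_sign, torch.tensor(1.0), torch.tensor(10.0))
    loss = torch.mean(weights * (pred - gt) ** 2)                  # mismatch dominates: ~0.0502

The 10x weight on sign mismatches makes the single wrongly-signed point contribute 250x more than the correctly-signed one here, which is the intended pressure toward correct inside/outside classification.
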
diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index dcfb81a..0d82a43 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -56,7 +56,7 @@ class Net(nn.Module):
     def __init__(self,
                  octree,
                  volume_bboxs,
-                 feature_dim=64,
+                 feature_dim=8,
                  decoder_output_dim=1,
                  decoder_hidden_dim=256,
                  decoder_num_layers=4,
@@ -87,8 +87,7 @@ class Net(nn.Module):
         output = f_i[:,0]
         # Extract the valid values and pad to a fixed size (B, max_patches)
         padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device)  # (B, max_patches)
-        valid_mask = face_indices_mask.bool()  # make sure the mask is boolean (B, P)
-        masked_f_i = torch.where(valid_mask, f_i, torch.tensor(float('inf'), device=f_i.device))  # set invalid entries to inf
+        masked_f_i = torch.where(face_indices_mask, f_i, torch.tensor(float('inf'), device=f_i.device))  # set invalid entries to inf

         # Take the first max_patches valid values per sample (B, max_patches)
         valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False)  # take the two smallest valid values
@@ -108,7 +107,6 @@ class Net(nn.Module):

         if mask_convex.any():
             output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
-        logger.debug("step over")
         #logger.gpu_memory_stats("after combine")
         return output
@@ -137,6 +135,45 @@ class Net(nn.Module):
         #logger.debug("step combine")
         return self.process_sdf(f_i, face_indices_mask, operator)
+
+    @torch.jit.export
+    def forward_background(self, query_points):
+        """
+        Forward pass through the background network only.
+
+        Args:
+            query_points: query point coordinates
+        Returns:
+            output: decoded output
+        """
+        # Encode (the background path needs no octree lookup)
+        feature_vectors = self.encoder.forward_background(query_points)
+        # Decode
+        h = self.decoder.forward_training_volumes(feature_vectors)  # (B, D)
+        return h
+
+    @torch.jit.export
+    def forward_without_octree(self, query_points, face_indices_mask, operator):
+        """
+        Forward pass that reuses precomputed octree outputs.
+
+        Args:
+            query_points: query point coordinates
+        Returns:
+            output: decoded output
+        """
+        # The per-point indices and bboxes are precomputed by the caller
+        #logger.debug("step encode")
+        # Encode
+        feature_vectors = self.encoder.forward(query_points, face_indices_mask)
+        #print("feature_vector:", feature_vectors.shape)
+        # Decode
+        f_i = self.decoder(feature_vectors)  # (B, P)
+        #logger.gpu_memory_stats("after decoder forward")
+
+        #logger.debug("step combine")
+        return self.process_sdf(f_i, face_indices_mask, operator)

     @torch.jit.export
     def forward_training_volumes(self, surf_points, patch_id:int):
@@ -154,12 +191,27 @@ class Net(nn.Module):


 def gradient(inputs, outputs):
+    # Caveat 1: inputs may contain non-coordinate features
+    # Caveat 2: special batch-dimension cases are not handled
     d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
+
+    # More defensive gradient computation
     points_grad = grad(
         outputs=outputs,
         inputs=inputs,
         grad_outputs=d_points,
         create_graph=True,
         retain_graph=True,
-        only_inputs=True)[0][:, -3:]
-    return points_grad
\ No newline at end of file
+        only_inputs=True,
+        allow_unused=True  # tolerate inputs that are unused in the graph
+    )[0]
+
+    # Handle missing gradients before slicing
+    if points_grad is None:
+        return torch.zeros_like(inputs[:, -3:])  # no gradient flowed back to the inputs
+
+    # Safely slice out the coordinate gradient and normalize it
+    coord_grad = points_grad[:, -3:] if points_grad.shape[1] >= 3 else points_grad
+    coord_grad = coord_grad / (coord_grad.norm(dim=-1, keepdim=True) + 1e-6)  # safe normalization
+
+    return coord_grad
\ No newline at end of file
diff --git a/brep2sdf/networks/sample.py b/brep2sdf/networks/sample.py
index 2e12f74..f8dbbaf 100644
--- a/brep2sdf/networks/sample.py
+++ b/brep2sdf/networks/sample.py
@@ -3,7 +3,7 @@ import torch

 class NormalPerPoint():

-    def __init__(self, global_sigma, local_sigma=0.5):
+    def __init__(self, global_sigma, local_sigma=0.001):
         self.global_sigma = global_sigma
         self.local_sigma = local_sigma
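
A sketch of how the revised `gradient` helper above is typically consumed; `net` and `gt_normals` are assumptions for illustration. Note that because the helper now normalizes its output, it suits normal-alignment terms, while an eikonal term would need the raw, unnormalized `autograd.grad` slice:

    import torch

    pnts = torch.randn(128, 3, requires_grad=True)
    sdf_pred = net(pnts)                     # any model mapping (B, 3) -> (B, 1)
    pred_normals = gradient(pnts, sdf_pred)  # (B, 3), unit length by construction
    normal_loss = (1.0 - torch.nn.functional.cosine_similarity(
        pred_normals, gt_normals, dim=-1)).mean()
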
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 5464004..543c082 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -75,6 +75,8 @@ class Trainer:
             self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal, sample_sdf_points=not args.only_zero_surface)
         logger.gpu_memory_stats("after data preprocessing")
+
+        self.train_surf_ncs = torch.tensor(self.data["train_surf_ncs"], dtype=torch.float32, device=self.device)

         # Convert the list of surface point clouds into an (N*M, 4) array
         surfs = self.data["surf_ncs"]
@@ -98,6 +100,7 @@ class Trainer:
         else:
             self.sdf_data = surface_sdf_data
         print_data_distribution(self.sdf_data)
+        logger.debug(self.sdf_data.shape)
         logger.gpu_memory_stats("after SDF data preparation")
         # Initialize the dataset
         #self.brep_data = load_brep_file(self.config.data.pkl_path)
@@ -200,7 +203,7 @@ class Trainer:
    #     # Return the merged bounding box
    #     return torch.cat([global_min, global_max])
    #     return [-0.5,]  # this is wrong
-    def train_epoch_stage1(self, epoch: int):
+    def train_epoch_stage1_(self, epoch: int):
         total_loss = 0.0
         total_loss_details = {
             "manifold": 0.0,
@@ -293,8 +296,303 @@ class Trainer:

         return total_loss  # return the average loss rather than the accumulated value

+    def train_epoch_stage1(self, epoch: int) -> float:
+        # --- 1. Validate the input data ---
+        # Note: self.train_surf_ncs is assumed to hold [x, y, z, nx, ny, nz, sdf] (7 columns)
+        # or [x, y, z, sdf] (4 columns), with the SDF value always in the last column
+        if self.train_surf_ncs is None:
+            logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.")
+            return float('inf')
+
+        self.model.train()
+        total_loss = 0.0
+        step = 0  # should be the batch index when training is batched
+        batch_size = 8192  # pick a suitable batch size
+
+        # Data preparation
+        # manifold
+        _mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach()  # first 3 columns: points
+        _normals = self.train_surf_ncs[:, 3:6].clone().detach()    # middle 3 columns: normals
+        _gt_sdf = self.train_surf_ncs[:, -1].clone().detach()      # last column: ground-truth SDF
+
+        # Decide whether the cached samples need to be rebuilt
+        if epoch % 10 == 1 or self.cached_train_data is None:
+            # Generate non-manifold points
+            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
+
+            # Update the cache
+            self.cached_train_data = {
+                "nonmnfld_pnts": _nonmnfld_pnts,
+                "psdf": _psdf,
+            }
+        else:
+            # Read from the cache
+            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
+            _psdf = self.cached_train_data["psdf"]
+
+        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
+
+        # Split the data into batches
+        num_points = self.train_surf_ncs.shape[0]
+        num_batches = (num_points + batch_size - 1) // batch_size
+
+        for batch_idx in range(num_batches):
+            start_idx = batch_idx * batch_size
+            end_idx = min((batch_idx + 1) * batch_size, num_points)
+
+            # Slice out the current batch
+            mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold points
+            gt_sdf = _gt_sdf[start_idx:end_idx]          # ground-truth SDF
+            normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals
+
+            # Non-manifold points come from the cache (shared across the whole epoch)
+            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
+            psdf = _psdf[start_idx:end_idx]
+
+            # --- Prepare the model inputs and enable gradients ---
+            mnfld_pnts.requires_grad_(True)     # enable gradients after the checks
+            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
+
+            # --- Forward pass ---
+            mnfld_pred = self.model.forward_background(
+                mnfld_pnts
+            )
+            nonmnfld_pred = self.model.forward_background(
+                nonmnfld_pnts
+            )
+
+            # --- Compute the loss ---
+            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
+            loss_details = {}
+            try:
+                # --- 3. Validate the loss inputs ---
+                # (the points already require grad; no need to re-check for inf/nan, but pred/gt SDF could be checked)
+                #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
+                #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
+                if args.use_normal:
+                    # Check the normals and the gradient-enabled points
+                    #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
+                    #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
+                    logger.gpu_memory_stats("before loss computation")
+                    loss, loss_details = self.loss_manager.compute_loss(
+                        mnfld_pnts,
+                        nonmnfld_pnts,
+                        normals,  # pass the validated normals
+                        gt_sdf,
+                        mnfld_pred,
+                        nonmnfld_pred,
+                        psdf
+                    )
+                else:
+                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)
+
+                # --- 4. Validate the computed loss ---
+                if self.debug_mode:
+                    logger.print_tensor_stats("psdf", psdf)
+                    logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
+                    if check_tensor(loss, "Calculated Loss", epoch, step):
+                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
+                        if loss_details: logger.error(f"Loss Details: {loss_details}")
+                        return float('inf')  # stop this epoch if the loss is invalid
+
+            except Exception as loss_e:
+                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
+                return float('inf')  # stop this epoch if the computation failed
+            logger.gpu_memory_stats("after loss computation")
+
+            # --- Backward pass and optimization ---
+            try:
+                self.scheduler.optimizer.zero_grad()  # clear gradients
+                loss.backward()                       # backpropagate
+                self.scheduler.optimizer.step()       # update parameters
+                self.scheduler.step(loss, epoch)
+            except Exception as backward_e:
+                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
+                # To find the offending op, enable anomaly detection:
+                # torch.autograd.set_detect_anomaly(True)  # place before training starts
+                return float('inf')  # stop this epoch if the backward pass or optimizer step failed
+
+            # --- Log and accumulate the loss ---
+            current_loss = loss.item()
+            if not np.isfinite(current_loss):  # double-check the loss is a finite number
+                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
+                return float('inf')
+
+            total_loss += current_loss
+            del loss
+            torch.cuda.empty_cache()
+
+            if epoch % 100 == 0:
+                # Log training progress (valid losses only)
+                logger.info(f'Train Epoch: {epoch:4d}]\t'
+                            f'Loss: {current_loss:.6f}')
+                if loss_details: logger.info(f"Loss Details: {loss_details}")
+
+        return total_loss  # for single-batch training this is simply the current loss
+
+    def train_epoch_stage2(self, epoch: int) -> float:
+        # --- 1. Validate the input data ---
+        # Note: self.sdf_data is assumed to hold [x, y, z, nx, ny, nz, sdf] (7 columns)
+        # or [x, y, z, sdf] (4 columns), with the SDF value always in the last column
+        if self.sdf_data is None:
+            logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
+            return float('inf')
+
+        self.model.train()
+        total_loss = 0.0
+        step = 0  # should be the batch index when training is batched
+        batch_size = 8192 * 2  # pick a suitable batch size
+
+        # Data preparation
+        # manifold
+        _mnfld_pnts = self.sdf_data[:, 0:3].clone().detach()  # first 3 columns: points
+        _normals = self.sdf_data[:, 3:6].clone().detach()     # middle 3 columns: normals
+        _gt_sdf = self.sdf_data[:, -1].clone().detach()       # last column: ground-truth SDF
+
+        # Decide whether the cached samples need to be rebuilt
+        if epoch % 10 == 1 or self.cached_train_data is None:
+            # Compute the masks and operators for the manifold points
+            _, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts)
+
+            # Generate non-manifold points
+            _nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
+            _, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts)
+
+            # Update the cache
+            self.cached_train_data = {
+                "mnfld_face_indices_mask": _mnfld_face_indices_mask,
+                "mnfld_operator": _mnfld_operator,
+                "nonmnfld_pnts": _nonmnfld_pnts,
+                "psdf": _psdf,
+                "nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask,
+                "nonmnfld_operator": _nonmnfld_operator
+            }
+        else:
+            # Read from the cache
+            _mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"]
+            _mnfld_operator = self.cached_train_data["mnfld_operator"]
+            _nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
+            _psdf = self.cached_train_data["psdf"]
+            _nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"]
+            _nonmnfld_operator = self.cached_train_data["nonmnfld_operator"]
+
+        logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
+
+        # Split the data into batches
+        num_points = self.sdf_data.shape[0]
+        num_batches = (num_points + batch_size - 1) // batch_size
+
+        for batch_idx in range(num_batches):
+            start_idx = batch_idx * batch_size
+            end_idx = min((batch_idx + 1) * batch_size, num_points)
+
+            # Slice out the current batch
+            mnfld_pnts = _mnfld_pnts[start_idx:end_idx]  # manifold points
+            gt_sdf = _gt_sdf[start_idx:end_idx]          # ground-truth SDF
+            normals = _normals[start_idx:end_idx] if args.use_normal else None  # normals
+
+            # Non-manifold points come from the cache (shared across the whole epoch)
+            nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
+            psdf = _psdf[start_idx:end_idx]
+
+            # --- Prepare the model inputs and enable gradients ---
+            mnfld_pnts.requires_grad_(True)     # enable gradients after the checks
+            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks
+
+            # --- Forward pass ---
+            mnfld_pred = self.model.forward_without_octree(
+                mnfld_pnts,
+                _mnfld_face_indices_mask[start_idx:end_idx],
+                _mnfld_operator[start_idx:end_idx]
+            )
+            nonmnfld_pred = self.model.forward_without_octree(
+                nonmnfld_pnts,
+                _nonmnfld_face_indices_mask[start_idx:end_idx],
+                _nonmnfld_operator[start_idx:end_idx]
+            )

-    def train_epoch(self, epoch: int) -> float:
+            #logger.print_tensor_stats("psdf", psdf)
+            #logger.print_tensor_stats("nonmnfld_pred", nonmnfld_pred)
+
+            # --- Compute the loss ---
+            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
+            loss_details = {}
+            try:
+                # --- 3. Validate the loss inputs ---
+                # (the points already require grad; no need to re-check for inf/nan, but pred/gt SDF could be checked)
+                #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
+                #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
+                if args.use_normal:
+                    # Check the normals and the gradient-enabled points
+                    #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
+                    #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
+                    logger.gpu_memory_stats("before loss computation")
+                    loss, loss_details = self.loss_manager.compute_loss(
+                        mnfld_pnts,
+                        nonmnfld_pnts,
+                        normals,  # pass the validated normals
+                        gt_sdf,
+                        mnfld_pred,
+                        nonmnfld_pred,
+                        psdf
+                    )
+                else:
+                    loss = torch.nn.functional.mse_loss(mnfld_pred, gt_sdf)
+
+                # --- 4. Validate the computed loss ---
+                if self.debug_mode:
+                    if check_tensor(loss, "Calculated Loss", epoch, step):
+                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
+                        if loss_details: logger.error(f"Loss Details: {loss_details}")
+                        return float('inf')  # stop this epoch if the loss is invalid
+
+            except Exception as loss_e:
+                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
+                return float('inf')  # stop this epoch if the computation failed
+            logger.gpu_memory_stats("after loss computation")
+
+            # --- Backward pass and optimization ---
+            try:
+                self.scheduler.optimizer.zero_grad()  # clear gradients
+                loss.backward()                       # backpropagate
+                self.scheduler.optimizer.step()       # update parameters
+                self.scheduler.step(loss, epoch)
+            except Exception as backward_e:
+                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
+                # To find the offending op, enable anomaly detection:
+                # torch.autograd.set_detect_anomaly(True)  # place before training starts
+                return float('inf')  # stop this epoch if the backward pass or optimizer step failed
+
+            # --- Log and accumulate the loss ---
+            current_loss = loss.item()
+            if not np.isfinite(current_loss):  # double-check the loss is a finite number
+                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
+                return float('inf')
+
+            total_loss += current_loss
+            del loss
+            torch.cuda.empty_cache()
+
+            # Log training progress (valid losses only)
+            logger.info(f'Train Epoch: {epoch:4d}]\t'
+                        f'Loss: {current_loss:.6f}')
+            if loss_details: logger.info(f"Loss Details: {loss_details}")
+
+        return total_loss  # for single-batch training this is simply the current loss
+
+    def train_epoch(self, epoch: int, resample: bool = True) -> float:
         # --- 1. Validate the input data ---
         # Note: self.sdf_data is assumed to hold [x, y, z, nx, ny, nz, sdf] (7 columns)
         # or [x, y, z, sdf] (4 columns), with the SDF value always in the last column
@@ -447,16 +745,19 @@ class Trainer:
         best_val_loss = float('inf')
         logger.info("Starting training...")
         start_time = time.time()
+        self.cached_train_data = None

         start_epoch = 1
         if args.resume_checkpoint_path:
             start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
             logger.info(f"Loaded model from {args.resume_checkpoint_path}")

+        self.model.encoder.freeze_stage1()
         for epoch in range(start_epoch, self.config.train.num_epochs + 1):
             # Train for one epoch
-            #train_loss = self.train_epoch_stage1(epoch)
-            train_loss = self.train_epoch(epoch)
+            train_loss = self.train_epoch_stage1(epoch)
+            #train_loss = self.train_epoch_stage2(epoch)
+            #train_loss = self.train_epoch(epoch)

             # Validation
             '''
@@ -476,7 +777,7 @@ class Trainer:
             if epoch % self.config.train.save_freq == 0:
                 self._save_checkpoint(epoch, train_loss)
                 logger.info(f'Checkpoint saved at epoch {epoch}')
-
+        self.model.encoder.unfreeze()
         # Training finished
         total_time = time.time() - start_time
@@ -555,8 +856,6 @@ class Trainer:
             logger.error(f"Failed to load checkpoint: {str(e)}")
             raise

-    # ... existing code ...
-
     def _save_octree(self):
         """
         Save the octree to a file.
         """
         checkpoint_dir = os.path.join(
             self.config.train.checkpoint_dir,
             self.model_name
         )
+        os.makedirs(checkpoint_dir, exist_ok=True)
         octree_path = os.path.join(checkpoint_dir, "octree.pth")

         try:
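
For context, the loader module added next assumes the on-disk layout below; the paths are illustrative, matching the ones used elsewhere in this patch:

    data/
    |-- name_list.txt                # one model name per line, e.g. 00000010
    |-- step/
    |   `-- 00000010/
    |       `-- 00000010_..._step_000.step   # get_step_paths picks the first match per name
    `-- gt_mesh/
        `-- 00000010.obj
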
diff --git a/brep2sdf/utils/load.py b/brep2sdf/utils/load.py
new file mode 100644
index 0000000..7c5e02d
--- /dev/null
+++ b/brep2sdf/utils/load.py
@@ -0,0 +1,133 @@
+
+import os
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from tqdm import tqdm
+import logging
+
+# The logger is assumed to be configured via the logging module
+logger = logging.getLogger(__name__)
+
+# utils
+def get_namelist(path):
+    try:
+        with open(path, 'r') as f:
+            names = [line.strip() for line in f if line.strip()]
+        logger.info(f"Read {len(names)} names from '{path}'.")
+        return names
+    except FileNotFoundError:
+        logger.error(f"Error: file '{path}' not found.")
+        return None
+    except Exception as e:
+        logger.error(f"Error reading file '{path}': {e}")
+        return None
+
+
+def get_step_paths(names, step_root_dir, file_extensions, name_filter=None):
+    """
+    Collect the matching STEP file paths for a list of names.
+
+    Args:
+        names (list): names to process; each name maps to one subdirectory.
+        step_root_dir (str): root directory of the STEP files.
+        file_extensions (list): file extensions to match, e.g. ['.step', '.stp'].
+        name_filter (callable, optional): filename filter taking (filename, name) and returning a bool.
+
+    Returns:
+        list: matching STEP file paths.
+    """
+    if names is None:
+        logger.error("Could not obtain the name list; aborting.")
+        return []
+
+    step_file_paths = []
+    skipped_count = 0
+
+    # Look for matching STEP files under each name's directory
+    for name in names:
+        step_dir = os.path.join(step_root_dir, name)
+        if not os.path.isdir(step_dir):
+            logger.warning(f"Directory '{step_dir}' does not exist. Skipping '{name}'.")
+            skipped_count += 1
+            continue
+
+        step_files = []
+        try:
+            # Find the matching files
+            step_files = [
+                os.path.join(step_dir, f)
+                for f in os.listdir(step_dir)
+                if f.lower().endswith(tuple(file_extensions)) and (not name_filter or name_filter(f, name))
+            ]
+        except OSError as e:
+            logger.warning(f"Cannot access directory '{step_dir}': {e}. Skipping '{name}'.")
+            skipped_count += 1
+            continue
+
+        if len(step_files) == 0:
+            logger.warning(f"No matching files found in directory '{step_dir}'. Skipping '{name}'.")
+            skipped_count += 1
+        else:
+            if len(step_files) > 1:
+                logger.warning(f"Multiple matching files found in '{step_dir}'; using the first: {step_files[0]}.")
+            step_file_paths.append(step_files[0])
+
+    logger.info(f"Collected {len(step_file_paths)} file paths; skipped {skipped_count} names.")
+    return step_file_paths
+
+def run_batch_task(task_function, args, common_args_func, file_extensions, name_filter=None):
+    """
+    Generic batch-task runner.
+
+    Args:
+        task_function: task to execute; takes a file path, a script path and the common arguments.
+        args: the command-line argument object.
+        common_args_func: function that builds the common arguments.
+        file_extensions: file extensions to match.
+        name_filter: optional filename filter.
+
+    Returns:
+        None
+    """
+    # Collect the task file paths (read the name list once and reuse it)
+    names = get_namelist(args.name_list_path)
+    tasks = get_step_paths(names, args.step_root_dir, file_extensions, name_filter)
+    if not tasks:
+        logger.info("No valid files to process were found.")
+        return
+
+    # Build the common arguments
+    common_args = common_args_func(args)
+
+    success_count = 0
+    failure_count = 0
+    skipped_count = len(names or []) - len(tasks)
+
+    # Process in parallel with a ProcessPoolExecutor
+    with ProcessPoolExecutor(max_workers=args.workers) as executor:
+        # Submit all tasks
+        futures = {
+            executor.submit(task_function, task_path, args.train_script, common_args): task_path
+            for task_path in tasks
+        }
+
+        # Track progress with tqdm and collect the results
+        for future in tqdm(as_completed(futures), total=len(tasks), desc="Running tasks"):
+            input_path = futures[future]
+            try:
+                input_file, success, stdout, stderr = future.result()
+                if success:
+                    success_count += 1
+                    # Logging successful stdout/stderr is possible, but usually only failures are useful
+                    # logger.debug(f"Processed '{input_file}' successfully. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
+                else:
+                    failure_count += 1
+                    logger.error(f"Processing '{input_file}' failed. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
+            except Exception as e:
+                failure_count += 1
+                logger.error(f"Failed to fetch the result for task '{input_path}': {e}")
+
+    logger.info(f"Processing finished. Succeeded: {success_count}, failed: {failure_count}, skipped: {skipped_count}")
diff --git a/data/scripts/pre_processing1.py b/data/scripts/pre_processing1.py
index 0bb6ad0..a825731 100644
--- a/data/scripts/pre_processing1.py
+++ b/data/scripts/pre_processing1.py
@@ -5,6 +5,7 @@ import multiprocessing
 from concurrent.futures import ProcessPoolExecutor, as_completed
 import argparse
 import time
+from brep2sdf.utils.load import get_namelist, get_step_paths
 from brep2sdf.utils.logger import logger

 import numpy as np
@@ -335,12 +336,23 @@ def test_single_step(step_path, output_obj_path=None, linear_deflection=0.01):
         print(f"\nProcessing failed: {str(e)}")
         return None

+def process_for_namelist():
+    names = get_namelist("/home/wch/brep2sdf/data/name_list.txt")
+
+    for name in names:
+        # Use glob to collect the matching files
+        step_files = glob.glob(f"/home/wch/brep2sdf/data/step/{name}/{name}*.step")
+        output = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj"
+        test_single_step(step_files[0], output_obj_path=output)
+
+
 # Prepare the tasks (runs in the main process)
 if __name__ == "__main__":
-    main()
+    #main()
+    process_for_namelist()
     '''
     test_single_step(
-        "/home/wch/brep2sdf/data/step/00002736/00002736_82034c87704b46a891e498d6_step_004.step",
-        "/home/wch/brep2sdf/data/gt_mesh/00002736.obj"
+        "/home/wch/brep2sdf/data/step/00000010/00000010_b4b99d35e04b4277931f9a9c_step_000.step",
+        "/home/wch/brep2sdf/data/gt_mesh/00000031.obj"
     )
     '''
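
For reference, a hedged sketch of how `run_batch_task` might be wired up; the argparse fields mirror the attributes the function reads, and `train_one` is a hypothetical worker (it must be a top-level function so the process pool can pickle it, and must return an (input_file, success, stdout, stderr) tuple):

    import argparse
    from brep2sdf.utils.load import run_batch_task

    def train_one(step_path, script, common_args):
        # Hypothetical worker; a real one would launch `script` as a subprocess.
        return (step_path, True, "", "")

    def common_args_func(args):
        return ["--use-normal", "--only-zero-surface"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--name_list_path", default="data/name_list.txt")
    parser.add_argument("--step_root_dir", default="data/step")
    parser.add_argument("--train_script", default="brep2sdf/train.py")
    parser.add_argument("--workers", type=int, default=4)
    args = parser.parse_args()

    run_batch_task(train_one, args, common_args_func,
                   file_extensions=[".step", ".stp"],
                   name_filter=lambda f, name: f.startswith(name))
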