From dd838ea22f4c295aaa078f3ede3b2dd74d92dfaf Mon Sep 17 00:00:00 2001
From: mckay
Date: Fri, 25 Apr 2025 23:55:05 +0800
Subject: [PATCH] Manifold loss can now be trained; the normal loss still has
 problems
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/config/default_config.py |  5 +-
 brep2sdf/networks/decoder.py      | 13 ++++-
 brep2sdf/networks/encoder.py      | 37 +++++++++-----
 brep2sdf/networks/loss.py         | 36 +++++---------
 brep2sdf/networks/network.py      |  5 +-
 brep2sdf/train.py                 | 82 +++++++++++++++++--------------
 6 files changed, 99 insertions(+), 79 deletions(-)

diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py
index 8dd4537..9dcf142 100644
--- a/brep2sdf/config/default_config.py
+++ b/brep2sdf/config/default_config.py
@@ -49,8 +49,8 @@ class TrainConfig:
     # Basic training parameters
     batch_size: int = 8
     num_workers: int = 4
-    num_epochs: int = 1
-    learning_rate: float = 0.001
+    num_epochs: int = 1000
+    learning_rate: float = 0.1
     min_lr: float = 1e-5
     weight_decay: float = 0.01
 
@@ -58,6 +58,7 @@ class TrainConfig:
     max_grad_norm: float = 1.0
     clamping_distance: float = 0.1
     debug_mode: bool = True
+    accumulation_steps: int = 50
 
     # Learning-rate scheduler parameters
     lr_scheduler: str = 'cosine'  # ['cosine', 'linear', 'step']
diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py
index 373dab9..824c516 100644
--- a/brep2sdf/networks/decoder.py
+++ b/brep2sdf/networks/decoder.py
@@ -48,7 +48,16 @@ class Decoder(nn.Module):
                 torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
             self.sdf_modules.append(lin)
 
+        # add layer normalization
+        self.norm_layers = nn.ModuleList()
+        for dim in dims_sdf[1:-1]:
+            self.norm_layers.append(nn.LayerNorm(dim))
+
         if geometric_init:
+            self.activation = nn.Sequential(
+                nn.LayerNorm(out_dim),  # add layer normalization
+                nn.Softplus(beta=beta)
+            )
             if beta > 0:
                 self.activation = nn.Softplus(beta=beta)
             # vanilla relu
@@ -99,17 +108,19 @@ class Decoder(nn.Module):
         '''
         # Use the input feature matrix directly; its shape is already (S, D)
         x = feature_matrix
+        logger.debug(f"decoder-x:{x}")
 
         for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
                 x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
 
             x = lin(x)
+            logger.debug(f"decoder-x-lin:{x}")
 
             if layer < self.sdf_layers - 2:
                 x = self.activation(x)
 
         output_value = x  # all f values
-
+        logger.debug(f"decoder-output:{output_value}")
         # reshape the output to (S)
         f = output_value.squeeze(-1)
diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py
index 653513a..541accb 100644
--- a/brep2sdf/networks/encoder.py
+++ b/brep2sdf/networks/encoder.py
@@ -28,6 +28,12 @@ class Encoder(nn.Module):
                 feature_dim=feature_dim
             ) for i, bbox in enumerate(volume_bboxs)
         ])
+
+        self.background = PatchFeatureVolume(
+            bbox=torch.Tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]),  # corrected normalized bbox
+            resolution=int(resolutions.max()) * 2,
+            feature_dim=feature_dim
+        )
         print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
         print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB")
 
@@ -60,19 +66,19 @@ class Encoder(nn.Module):
 
         Args:
             query_points: query point coordinates (B, 3)
-            volume_indices: associated volume index matrix (B, K)
+            volume_indices: associated volume index matrix (B, P)
         Returns:
-            feature tensor (B, K, D)
+            feature tensor (B, P, D)
         """
         batch_size, num_volumes = volume_indices.shape
         all_features = torch.zeros(batch_size, num_volumes, self.feature_dim,
                                    device=query_points.device)
-
+        background_features = self.background.forward(query_points)  # (B, D)
         # iterate over each volume index
-        for k in range(num_volumes):
+        for p in range(num_volumes):
             # get the indices for the current volume (B,)
-            current_indices = volume_indices[:, k]
+            current_indices = volume_indices[:, p]
 
             # iterate over all existing volumes
             for vol_id, volume in enumerate(self.feature_volumes):
@@ -81,8 +87,8 @@ class Encoder(nn.Module):
                 if mask.any():
                     # get the features from the corresponding volume (M, D)
                     features = volume.forward(query_points[mask])
-                    all_features[mask, k] = features
-
+                    all_features[mask, p] = 0.7 * features + 0.3 * background_features[mask]
+
         return all_features
 
     @torch.jit.export
@@ -94,13 +100,20 @@ class Encoder(nn.Module):
         Returns:
             feature tensor (S, D)
         """
-        # enumerate over feature_volumes to avoid direct indexing
+        # get the patch features
+        patch_features = torch.zeros(surf_points.shape[0], self.feature_dim, device=surf_points.device)
         for idx, volume in enumerate(self.feature_volumes):
             if idx == patch_id:
-                return volume.forward(surf_points)
-        return torch.zeros(surf_points.shape[0], self.feature_dim, device=surf_points.device)
-
-        return features
+                patch_features = volume.forward(surf_points)
+                break
+
+        # get the background field features
+        background_features = self.background.forward(surf_points)
+
+        # blend the patch and background field features
+        combined_features = 0.7 * patch_features + 0.3 * background_features
+
+        return combined_features
 
     def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor:
         """Optimized vectorized trilinear interpolation."""
diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py
index d6fcdd9..43b65bb 100644
--- a/brep2sdf/networks/loss.py
+++ b/brep2sdf/networks/loss.py
@@ -47,26 +47,11 @@ class LossManager:
         return mnfld_pred_h, nonmnfld_pred_h, mnfld_grad, all_fi
 
     def position_loss(self, pred_sdfs: torch.Tensor, gt_sdfs: torch.Tensor) -> torch.Tensor:
-        """
-        Compute the manifold (position) loss.
-
-        :param pred_sdfs: predicted SDF values, shape (N, 1)
-        :param gt_sdfs: ground-truth SDF values, shape (N, 1)
-        :return: the computed manifold loss, a scalar
-        """
-        with torch.no_grad():  # current context
-            # explicitly detach the tensors
-            pred_sdfs = pred_sdfs.detach()
-            gt_sdfs = gt_sdfs.detach()
-            squared_diff = torch.pow(pred_sdfs - gt_sdfs, 2)
-            manifold_loss = torch.mean(squared_diff)
-
-            # explicitly free intermediates
-            del squared_diff
-            torch.cuda.empty_cache()  # release cached memory immediately
-
-        return manifold_loss
-
+        """Position (manifold) loss."""
+        # keep the gradient flow
+        squared_diff = torch.pow(pred_sdfs - gt_sdfs, 2)
+        return torch.mean(squared_diff)
+
     def normals_loss(self, normals: torch.Tensor, mnfld_pnts: torch.Tensor, pred_sdfs) -> torch.Tensor:
         """
         Compute the normal loss
@@ -136,10 +121,13 @@ class LossManager:
         :param outputs: model outputs
        :return: the computed manifold loss value
         """
-        # compute the manifold loss
-        #logger.gpu_memory_stats("before computing manifold loss")
-        manifold_loss = self.position_loss(pred_sdfs,gt_sdfs)
-        #logger.gpu_memory_stats("after computing manifold loss")
+        # force dtype casts to keep everything consistent
+        normals = normals.to(torch.float32)
+        pred_sdfs = pred_sdfs.to(torch.float32)
+        gt_sdfs = gt_sdfs.to(torch.float32)
+
+        # compute the manifold loss
+        manifold_loss = self.position_loss(pred_sdfs, gt_sdfs)
 
         # compute the normal loss
         normals_loss = self.normals_loss(normals, points, pred_sdfs)
diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index cba4484..d07a770 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -57,7 +57,6 @@ class Net(nn.Module):
         octree,
         volume_bboxs,
         feature_dim=64,
-        decoder_input_dim=64,
         decoder_output_dim=1,
         decoder_hidden_dim=256,
         decoder_num_layers=4,
@@ -76,10 +75,10 @@ class Net(nn.Module):
 
         # initialize the Decoder
         self.decoder = Decoder(
-            d_in=decoder_input_dim,
+            d_in=feature_dim,
             dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
             geometric_init=True,
-            beta=100
+            beta=5
         )
 
         #self.csg_combiner = CSGCombiner(flag_convex=True)
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index d0caa33..87a1fee 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -197,8 +197,17 @@ class Trainer:
         #     return torch.cat([global_min, global_max])
         # return [-0.5,]  # this is wrong
     def train_epoch_stage1(self, epoch: int):
-        total_loss = 0.0  # initialize the total loss
-        for step, surf_points in enumerate(self.data['surf_ncs']):  # define the step variable
+        total_loss = 0.0
+        total_loss_details = {
+            "manifold": 0.0,
+            "normals": 0.0
+        }
+        accumulated_loss = 0.0  # new: accumulates the loss over several steps
+
+        # new: zero the gradients once at the start of each epoch
+        self.optimizer.zero_grad()
+
+        for step, surf_points in enumerate(self.data['surf_ncs']):
            points = torch.tensor(surf_points, device=self.device)
            gt_sdf = torch.zeros(points.shape[0], device=self.device)
            normals = None
@@ -211,14 +220,13 @@ class Trainer:
 
             # --- forward pass ---
             self.optimizer.zero_grad()
             pred_sdf = self.model.forward_training_volumes(points, step)
+            logger.debug(f"pred_sdf:{pred_sdf}")
 
             if self.debug_mode:
                 # --- check the forward pass output ---
                 logger.gpu_memory_stats("after forward pass")
 
             # --- compute the loss ---
-            loss = torch.tensor(float('nan'), device=self.device)  # initialize to NaN in case the computation fails
-            loss_details = {}
             try:
                 if args.use_normal:
                     loss, loss_details = self.loss_manager.compute_loss(
@@ -229,45 +237,45 @@ class Trainer:
                     )
                 else:
                     loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
-
-                if self.debug_mode:
-                    if check_tensor(loss, "Calculated Loss", epoch, step):
-                        logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
-                        if loss_details:
-                            logger.error(f"Loss Details: {loss_details}")
-                        return float('inf')  # stop this epoch if the loss is invalid
+                    loss_details = {}  # make sure the variable is initialized
+
+                # changed: accumulate the loss instead of calling backward immediately
+                accumulated_loss += loss / self.config.train.accumulation_steps  # assumes the config provides accumulation_steps
+                current_loss = loss.item()
+                total_loss += current_loss
+                for key in total_loss_details:
+                    if key in loss_details:
+                        total_loss_details[key] += loss_details[key].item()
+
+                # new: run backward once the accumulation step count is reached
+                if (step + 1) % self.config.train.accumulation_steps == 0:
+                    accumulated_loss.backward()
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    accumulated_loss = 0.0  # reset the accumulated loss
+
+                # logging left unchanged ...
+
 
             except Exception as loss_e:
-                logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
-                return float('inf')  # stop this epoch if the loss computation fails
+                logger.error(f"Error in step {step}: {loss_e}")
+                continue
 
-            # --- backward pass and optimization ---
-            try:
-                loss.backward()
-                # --- (recommended) gradient clipping ---
-                # prevents exploding gradients, one possible cause of inf/nan
-                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # norm clipping
-
-                self.optimizer.step()
-            except Exception as backward_e:
-                logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
-                return float('inf')  # stop this epoch if backward or the optimizer step fails
-
-            # --- record and accumulate the loss ---
-            current_loss = loss.item()
-            if not np.isfinite(current_loss):  # double-check that the loss is a finite number
-                logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
-                return float('inf')
-
-            total_loss += current_loss
+
+            # --- memory management ---
             del loss
             torch.cuda.empty_cache()
 
-        # log training progress (valid losses only)
+        # new: flush any remaining accumulated loss from an incomplete accumulation window
+        if accumulated_loss != 0:
+            accumulated_loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            self.optimizer.step()
+
+        # compute and log the epoch loss
         logger.info(f'Train Epoch: {epoch:4d}]\t'
-                    f'Loss: {current_loss:.6f}')
-        if loss_details: logger.info(f"Loss Details: {loss_details}")
-
-        return total_loss
+                    f'Loss: {total_loss:.6f}')
+        logger.info(f"Loss Details: {total_loss_details}")
+        return total_loss  # returns the accumulated epoch loss (not averaged)
 
 
     def train_epoch(self, epoch: int) -> float:
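Note on the train_epoch_stage1 change above: it switches the loop to gradient accumulation, where each step's loss is divided by accumulation_steps, the optimizer is stepped every accumulation_steps steps, and any remainder is flushed at the end of the epoch. The sketch below shows the same pattern in isolation. It is a minimal illustration with assumed names (run_epoch, model, optimizer, batches, max_grad_norm), not the repository's API, and it calls backward() on each scaled step loss rather than summing loss tensors as the patch does; the gradients are equivalent but the per-step variant keeps less of the graph in memory.

    import torch

    def run_epoch(model, optimizer, batches, accumulation_steps=50, max_grad_norm=1.0):
        """One epoch with gradient accumulation: update weights every `accumulation_steps` steps."""
        optimizer.zero_grad()
        total_loss = 0.0
        step = -1  # stays -1 if `batches` is empty
        for step, (points, gt_sdf) in enumerate(batches):
            pred_sdf = model(points)
            loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
            # scale so the accumulated gradient matches one large-batch update
            (loss / accumulation_steps).backward()
            total_loss += loss.item()
            if (step + 1) % accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
        # flush gradients left over from an incomplete accumulation window
        if step >= 0 and (step + 1) % accumulation_steps != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
        return total_loss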