From d7227b91a25368c098dc725b7e43c2a1fddf792e Mon Sep 17 00:00:00 2001
From: mckay
Date: Mon, 28 Apr 2025 19:33:35 +0800
Subject: [PATCH] Add pseudo-SDF loss
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/config/default_config.py | 18 ++++++++++++------
 brep2sdf/networks/loss.py         | 11 ++++++++---
 brep2sdf/networks/sample.py       | 25 +++++++++++++++++++++++--
 brep2sdf/train.py                 |  9 +++++++--
 4 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py
index fc0f4b1..e7cfb4e 100644
--- a/brep2sdf/config/default_config.py
+++ b/brep2sdf/config/default_config.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Tuple
+from dataclasses import dataclass, field
+from typing import Tuple,List
 
 @dataclass
 class ModelConfig:
@@ -49,15 +49,21 @@ class TrainConfig:
     # Basic training parameters
     batch_size: int = 8
     num_workers: int = 4
-    num_epochs: int = 100
-    learning_rate: float = 0.1
+    num_epochs: int = 50
+    learning_rate: float = 0.005
+    learning_rate_schedule: List = field(default_factory=lambda: [{
+        "Type": "Step",  # schedule type; "Step" multiplies the learning rate by Factor after every Interval iterations
+        "Initial": 0.005,
+        "Interval": 2000,
+        "Factor": 0.5
+    }])
     min_lr: float = 1e-5
     weight_decay: float = 0.01
 
     # Gradient and loss settings
     max_grad_norm: float = 1.0
     clamping_distance: float = 0.1
-    debug_mode: bool = True
+    debug_mode: bool = False
     accumulation_steps:int = 50
 
     # Learning-rate scheduler parameters
@@ -65,7 +71,7 @@ class TrainConfig:
     warmup_epochs: int = 5
 
     # Saving and validation
-    save_freq: int = 100  # save a checkpoint every N epochs
+    save_freq: int = 10  # save a checkpoint every N epochs
     val_freq: int = 1  # validate every N epochs
 
     # Save paths
diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py
index fe269eb..9c48abc 100644
--- a/brep2sdf/networks/loss.py
+++ b/brep2sdf/networks/loss.py
@@ -5,13 +5,14 @@ from brep2sdf.utils.logger import logger
 class LossManager:
     def __init__(self, ablation, **condition_kwargs):
         self.weights = {
-            "manifold": 1,
+            "manifold": 10,
             "feature_manifold": 1,  # same weight as "manifold" in the original paper
             "normals": 1,
             "eikonal": 1,
             "offsurface": 1,
             "consistency": 1,
             "correction": 1,
+            "psdf": 10
         }
         self.condition_kwargs = condition_kwargs
         self.ablation = ablation  # for ablation studies
@@ -117,7 +118,8 @@ class LossManager:
                      normals,
                      gt_sdfs,
                      mnfld_pred,
-                     nonmnfld_pred):
+                     nonmnfld_pred,
+                     psdfs):
         """
         Compute the manifold loss.
 
@@ -148,12 +150,15 @@ class LossManager:
         # Correction loss
         #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
 
+        psdf_loss = self.position_loss(nonmnfld_pred, psdfs)
+
         # Collect the individual losses
         loss_details = {
             "manifold": self.weights["manifold"] * manifold_loss,
             "normals": self.weights["normals"] * normals_loss,
             "eikonal": self.weights["eikonal"] * eikonal_loss,
-            "offsurface": self.weights["offsurface"] * offsurface_loss
+            "offsurface": self.weights["offsurface"] * offsurface_loss,
+            "psdf": self.weights["psdf"] * psdf_loss
         }
 
         # Total loss
diff --git a/brep2sdf/networks/sample.py b/brep2sdf/networks/sample.py
index fa4b02c..bcb16ae 100644
--- a/brep2sdf/networks/sample.py
+++ b/brep2sdf/networks/sample.py
@@ -3,7 +3,7 @@
 import torch
 
 class NormalPerPoint():
 
-    def __init__(self, global_sigma, local_sigma=0.01):
+    def __init__(self, global_sigma, local_sigma=0.1):
         self.global_sigma = global_sigma
         self.local_sigma = local_sigma
 
@@ -19,4 +19,25 @@ class NormalPerPoint():
 
         sample = torch.cat([sample_local, sample_global], dim=1)
 
-        return sample
\ No newline at end of file
+        return sample
+
+    def get_norm_points(self, pc_input, normals, local_sigma=None):
+        """
+        Return points offset along the surface normals together with their pseudo-SDF values (PSDF).
+        :param pc_input: input point cloud, shape (sample_size, 3)
+        :param normals: point normals, shape (sample_size, 3)
+        :param local_sigma: standard deviation of the local offsets
+        :return: offset points (sample_size, dim), pseudo-SDF values
+        """
+        sample_size, dim = pc_input.shape
+
+        # Draw random signed offsets
+        if local_sigma is not None:
+            psdf = torch.randn(sample_size, device=pc_input.device) * local_sigma
+        else:
+            psdf = torch.randn(sample_size, device=pc_input.device) * self.local_sigma
+
+        # Offset the points along their normals
+        sample = pc_input + normals * psdf.unsqueeze(-1)
+
+        return sample, psdf
\ No newline at end of file
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index daa0a62..5464004 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -315,7 +315,7 @@ class Trainer:
                 start_idx = batch_idx * batch_size
                 end_idx = min((batch_idx + 1) * batch_size, num_points)
                 mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach()  # first 3 columns: points
-                nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0)  # sample non-manifold points
+
                 gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach()  # last column: ground-truth SDF
                 normals = None
                 if args.use_normal:
@@ -323,7 +323,11 @@
                         logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
                         return float('inf')
                     normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach()  # middle 3 columns: normals
+                    nonmnfld_pnts,psdf = self.sampler.get_norm_points(mnfld_pnts,normals)  # sample non-manifold points
+                    logger.debug((mnfld_pnts,nonmnfld_pnts,psdf))
+                else:
+                    nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0)  # sample non-manifold points
 
                 # Run sanity checks
                 if self.debug_mode:
                     if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf')
@@ -372,7 +376,8 @@
                         normals,  # pass the validated normals
                         gt_sdf,
                         mnfld_pred,
-                        nonmnfld_pred
+                        nonmnfld_pred,
+                        psdf
                     )
                 else:
                     loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
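
Note: the loss.py hunk calls self.position_loss(nonmnfld_pred, psdfs), a helper that is not part of this diff. A minimal sketch of what such a pseudo-SDF term could look like, assuming it is a plain mean absolute error between the SDF predicted at the normal-offset samples and the pseudo-SDF targets returned by get_norm_points (the signature and reduction are assumptions, not taken from the repository; in the repo it would be a method on LossManager):

import torch

def position_loss(pred_sdfs: torch.Tensor, psdfs: torch.Tensor) -> torch.Tensor:
    # pred_sdfs: network SDF predictions at the offset points, shape (N,) or (N, 1)
    # psdfs:     signed offsets along the normals from get_norm_points, shape (N,)
    # Flatten both so the shapes match, then take the mean absolute error.
    return torch.abs(pred_sdfs.view(-1) - psdfs.view(-1)).mean()

The rationale for the pseudo-SDF target is that a point displaced by a signed distance d along the surface normal has an SDF of approximately d while the offset stays small relative to the local feature size, so these samples give cheap off-surface supervision without evaluating an exact SDF.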
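
Note: the new learning_rate_schedule entry describes a step decay (per its inline comment), but the code that consumes it is not shown in this patch. One plausible reading of a single "Step" entry, assuming Interval counts optimizer iterations and the decay is applied multiplicatively (both assumptions):

def step_lr(iteration: int, initial: float = 0.005, interval: int = 2000, factor: float = 0.5) -> float:
    # Multiply the initial learning rate by `factor` once per completed `interval` iterations.
    return initial * factor ** (iteration // interval)

# e.g. step_lr(0) == 0.005, step_lr(2000) == 0.0025, step_lr(4000) == 0.00125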