From b2b42918f429b7a0bc971693bee6d0fc039f9eb6 Mon Sep 17 00:00:00 2001
From: mckay
Date: Sat, 23 Nov 2024 20:48:23 +0800
Subject: [PATCH] fix: repair loss parameters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/config/default_config.py | 13 +++++-
 brep2sdf/networks/loss.py         | 68 ++++++++++++++++++++++++++++++
 brep2sdf/train.py                 | 24 ++++++-----
 3 files changed, 92 insertions(+), 13 deletions(-)
 create mode 100644 brep2sdf/networks/loss.py

diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py
index f2a392e..f5e8c78 100644
--- a/brep2sdf/config/default_config.py
+++ b/brep2sdf/config/default_config.py
@@ -23,6 +23,7 @@ class ModelConfig:
     encoder_channels: Tuple[int] = (32, 64, 128)
     encoder_layers_per_block: int = 1
 
+
 @dataclass
 class DataConfig:
     """Data-related configuration"""
@@ -53,14 +54,15 @@ class TrainConfig:
     # Basic training parameters
     batch_size: int = 1
     num_workers: int = 4
-    num_epochs: int = 100
+    num_epochs: int = 10
     learning_rate: float = 1e-4
     min_lr: float = 1e-6
     weight_decay: float = 0.01
 
     # Gradient and loss settings
-    grad_weight: float = 0.1
     max_grad_norm: float = 1.0
+    clamping_distance: float = 0.1
+
 
     # Learning-rate scheduler parameters
     lr_scheduler: str = 'cosine'  # ['cosine', 'linear', 'step']
@@ -70,6 +72,10 @@ class TrainConfig:
     save_freq: int = 10  # save a checkpoint every N epochs
     val_freq: int = 1    # run validation every N epochs
 
+@dataclass
+class TestConfig:
+    vis_freq: int = 100
+
 @dataclass
 class LogConfig:
     """Logging-related configuration"""
@@ -91,6 +97,7 @@ class Config:
     model: ModelConfig = ModelConfig()
     data: DataConfig = DataConfig()
     train: TrainConfig = TrainConfig()
+    test: TestConfig = TestConfig()
     log: LogConfig = LogConfig()
 
     def __post_init__(self):
@@ -104,12 +111,14 @@ class Config:
         model_config = ModelConfig(**config_dict.get('model', {}))
         data_config = DataConfig(**config_dict.get('data', {}))
         train_config = TrainConfig(**config_dict.get('train', {}))
+        test_config = TestConfig(**config_dict.get('test', {}))
         log_config = LogConfig(**config_dict.get('log', {}))
 
         return cls(
             model=model_config,
             data=data_config,
             train=train_config,
+            test=test_config,
             log=log_config
         )
 
diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py
new file mode 100644
index 0000000..adbe2cf
--- /dev/null
+++ b/brep2sdf/networks/loss.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from brep2sdf.utils.logger import logger
+
+class Brep2SDFLoss:
+    """Clamped L1 loss over SDF values for Brep2SDF training."""
+    def __init__(self, enforce_minmax: bool = True, clamping_distance: float = 0.1):
+        self.l1_loss = nn.L1Loss(reduction='sum')
+        self.enforce_minmax = enforce_minmax
+        self.minT = -clamping_distance
+        self.maxT = clamping_distance
+
+    def __call__(self, pred_sdf, gt_sdf):
+        """Make the instance directly callable."""
+        return self.forward(pred_sdf, gt_sdf)
+
+    def forward(self, pred_sdf, gt_sdf):
+        """
+        pred_sdf: predicted SDF values
+        gt_sdf: ground-truth SDF values
+        """
+
+        if self.enforce_minmax:
+            pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT)
+            gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT)
+
+        # Advantages of L1 loss:
+        # - more robust to outliers
+        # - better preserves surface detail
+        base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0]
+
+        return base_loss
+
+
+def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
+    """SDF loss: L1 term plus a gradient (eikonal) constraint."""
+    # Make sure the sample points require gradients
+    if not points.requires_grad:
+        points = points.detach().requires_grad_(True)
+
+    # L1 loss
+    l1_loss = F.l1_loss(pred_sdf, gt_sdf)
+
+    try:
+        # Gradient-constraint loss: push the gradient norm toward 1
+        grad = torch.autograd.grad(
+            pred_sdf.sum(),
+            points,
+            create_graph=True,
+            retain_graph=True,
+            allow_unused=True
+        )[0]
+
+        if grad is not None:
+            grad_constraint = F.mse_loss(
+                torch.norm(grad, dim=-1),
+                torch.ones_like(pred_sdf.squeeze(-1))
+            )
+        else:
+            grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
+
+    except Exception as e:
+        logger.warning(f"Gradient computation failed: {str(e)}")
+        grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
+
+    return l1_loss + grad_weight * grad_constraint
\ No newline at end of file
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index b955840..ee098a1 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -4,7 +4,8 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 from brep2sdf.data.data import BRepSDFDataset
-from brep2sdf.networks.network import BRepToSDF, sdf_loss
+from brep2sdf.networks.network import BRepToSDF
+from brep2sdf.networks.loss import Brep2SDFLoss
 from brep2sdf.utils.logger import logger
 from brep2sdf.config.default_config import get_default_config, load_config
 import wandb
@@ -45,6 +46,11 @@ class Trainer:
     def __init__(self, config):
         self.config = config
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        clamping_distance = self.config.train.clamping_distance
+        self.criterion = Brep2SDFLoss(
+            enforce_minmax=(clamping_distance > 0),
+            clamping_distance=clamping_distance
+        )
         use_pin_memory = self.device.type == 'cuda'  # use pin_memory only on CUDA devices
 
         logger.info(f"Using device: {self.device}")
@@ -126,11 +132,9 @@ class Trainer:
         )
 
         # Compute loss
-        loss = sdf_loss(
-            pred_sdf,
-            gt_sdf,
-            points,
-            grad_weight=self.config.train.grad_weight
+        loss = self.criterion(
+            pred_sdf=pred_sdf,
+            gt_sdf=gt_sdf
         )
 
         # Backpropagation and optimization
@@ -185,11 +189,9 @@ class Trainer:
         )
 
         # Compute loss
-        loss = sdf_loss(
-            pred_sdf,
-            gt_sdf,
-            points,
-            grad_weight=self.config.train.grad_weight
+        loss = self.criterion(
+            pred_sdf=pred_sdf,
+            gt_sdf=gt_sdf
         )
 
         total_loss += loss.item()
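
Note: a minimal usage sketch of the new Brep2SDFLoss follows; the tensors and values are made up for illustration (the real call site is Trainer in brep2sdf/train.py):

    import torch
    from brep2sdf.networks.loss import Brep2SDFLoss

    # Hypothetical batch: 4 query points with scalar SDF values.
    pred_sdf = torch.tensor([[0.05], [0.30], [-0.20], [0.00]])
    gt_sdf = torch.tensor([[0.02], [0.50], [-0.08], [0.00]])

    # With clamping_distance=0.1, both tensors are clamped to [-0.1, 0.1]
    # before the L1 term, so the 0.30 vs 0.50 pair contributes nothing:
    # errors far from the surface are ignored, DeepSDF-style.
    criterion = Brep2SDFLoss(enforce_minmax=True, clamping_distance=0.1)
    loss = criterion(pred_sdf=pred_sdf, gt_sdf=gt_sdf)
    print(loss)  # (0.03 + 0.0 + 0.02 + 0.0) / 4 = 0.0125

Setting clamping_distance to 0 in TrainConfig disables the clamp entirely, since Trainer constructs the loss with enforce_minmax=(clamping_distance > 0).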