
fix: fix loss parameters

final
mckay 7 months ago
commit d8b5b545e2
1. brep2sdf/config/default_config.py (13 changes)
2. brep2sdf/networks/loss.py (70 additions)
3. brep2sdf/train.py (24 changes)

brep2sdf/config/default_config.py (13 changes)

@@ -23,6 +23,7 @@ class ModelConfig:
     encoder_channels: Tuple[int] = (32, 64, 128)
     encoder_layers_per_block: int = 1

 @dataclass
 class DataConfig:
     """Data-related configuration"""
@@ -53,14 +54,15 @@ class TrainConfig:
     # Basic training parameters
     batch_size: int = 1
     num_workers: int = 4
-    num_epochs: int = 100
+    num_epochs: int = 10
     learning_rate: float = 1e-4
     min_lr: float = 1e-6
     weight_decay: float = 0.01

     # Gradient- and loss-related
-    grad_weight: float = 0.1
     max_grad_norm: float = 1.0
+    clamping_distance: float = 0.1

     # Learning-rate scheduler parameters
     lr_scheduler: str = 'cosine'  # ['cosine', 'linear', 'step']
@@ -70,6 +72,10 @@ class TrainConfig:
     save_freq: int = 10  # save a checkpoint every N epochs
     val_freq: int = 1    # run validation every N epochs

+@dataclass
+class TestConfig:
+    vis_freq: int = 100
+
 @dataclass
 class LogConfig:
     """Logging-related configuration"""
@@ -91,6 +97,7 @@ class Config:
     model: ModelConfig = ModelConfig()
     data: DataConfig = DataConfig()
     train: TrainConfig = TrainConfig()
+    test: TestConfig = TestConfig()
     log: LogConfig = LogConfig()

     def __post_init__(self):
@ -104,12 +111,14 @@ class Config:
model_config = ModelConfig(**config_dict.get('model', {})) model_config = ModelConfig(**config_dict.get('model', {}))
data_config = DataConfig(**config_dict.get('data', {})) data_config = DataConfig(**config_dict.get('data', {}))
train_config = TrainConfig(**config_dict.get('train', {})) train_config = TrainConfig(**config_dict.get('train', {}))
test_config = TestConfig(**config_dict.get('test', {}))
log_config = LogConfig(**config_dict.get('log', {})) log_config = LogConfig(**config_dict.get('log', {}))
return cls( return cls(
model=model_config, model=model_config,
data=data_config, data=data_config,
train=train_config, train=train_config,
test=test_config,
log=log_config log=log_config
) )
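For context, a minimal sketch of how the new fields are read (assuming only the defaults defined in this file):

from brep2sdf.config.default_config import Config

config = Config()
print(config.train.clamping_distance)  # 0.1, added where grad_weight was removed
print(config.test.vis_freq)            # 100, from the new TestConfig
# Setting clamping_distance <= 0 disables clamping downstream, since the
# trainer constructs the loss with enforce_minmax=(clamping_distance > 0).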

brep2sdf/networks/loss.py (70 additions, new file)

@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from brep2sdf.config.default_config import get_default_config
+from brep2sdf.utils.logger import logger
+
+
+class Brep2SDFLoss:
+    """Clamped L1 loss for Brep2SDF SDF regression."""
+
+    def __init__(self, enforce_minmax: bool = True, clamping_distance: float = 0.1):
+        self.l1_loss = nn.L1Loss(reduction='sum')
+        self.enforce_minmax = enforce_minmax
+        self.minT = -clamping_distance
+        self.maxT = clamping_distance
+
+    def __call__(self, pred_sdf, gt_sdf):
+        """Make instances directly callable."""
+        return self.forward(pred_sdf, gt_sdf)
+
+    def forward(self, pred_sdf, gt_sdf):
+        """
+        pred_sdf: predicted SDF values
+        gt_sdf: ground-truth SDF values
+        """
+        if self.enforce_minmax:
+            pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT)
+            gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT)
+
+        # L1 loss: more robust to outliers and better at preserving
+        # surface detail than L2
+        base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0]
+        return base_loss
+
+
+def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
+    """SDF loss with an eikonal-style gradient constraint."""
+    # Make sure the query points require gradients. Note: if points had
+    # requires_grad=False, pred_sdf was computed from the original tensor,
+    # so grad below comes back as None and the constraint falls back to 0.
+    if not points.requires_grad:
+        points = points.detach().requires_grad_(True)
+
+    # L1 data term
+    l1_loss = F.l1_loss(pred_sdf, gt_sdf)
+
+    try:
+        # Gradient constraint: |∇ pred_sdf| should be 1 for a valid SDF
+        grad = torch.autograd.grad(
+            pred_sdf.sum(),
+            points,
+            create_graph=True,
+            retain_graph=True,
+            allow_unused=True
+        )[0]
+        if grad is not None:
+            grad_constraint = F.mse_loss(
+                torch.norm(grad, dim=-1),
+                torch.ones_like(pred_sdf.squeeze(-1))
+            )
+        else:
+            grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
+    except Exception as e:
+        logger.warning(f"Gradient computation failed: {str(e)}")
+        grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
+
+    return l1_loss + grad_weight * grad_constraint
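A quick usage sketch for the new module. The random tensors are stand-ins, not project data, and the unit-sphere SDF is a toy example chosen because its gradient norm is exactly 1:

import torch
from brep2sdf.networks.loss import Brep2SDFLoss, sdf_loss

pred = torch.randn(8, 1)  # hypothetical predicted SDF batch
gt = torch.randn(8, 1)    # hypothetical ground-truth SDF batch

criterion = Brep2SDFLoss(enforce_minmax=True, clamping_distance=0.1)
print(criterion(pred, gt).item())  # clamped L1, averaged over the batch

# sdf_loss additionally takes the query points so it can penalize
# deviations of |∇ pred_sdf| from 1 (a true SDF satisfies |∇f| = 1).
points = torch.randn(8, 3, requires_grad=True)
pred2 = (points ** 2).sum(dim=-1, keepdim=True).sqrt() - 1.0  # SDF of a unit sphere
print(sdf_loss(pred2, gt, points, grad_weight=0.1).item())    # eikonal term ~ 0 here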

brep2sdf/train.py (24 changes)

@@ -4,7 +4,8 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader
 from brep2sdf.data.data import BRepSDFDataset
-from brep2sdf.networks.network import BRepToSDF, sdf_loss
+from brep2sdf.networks.network import BRepToSDF
+from brep2sdf.networks.loss import Brep2SDFLoss
 from brep2sdf.utils.logger import logger
 from brep2sdf.config.default_config import get_default_config, load_config
 import wandb
@ -45,6 +46,11 @@ class Trainer:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clamping_distance = self.config.train.clamping_distance
self.criterion = Brep2SDFLoss(
enforce_minmax= (clamping_distance > 0),
clamping_distance= clamping_distance
)
use_pin_memory = self.device.type == 'cuda' # 根据设备类型决定是否使用pin_memory use_pin_memory = self.device.type == 'cuda' # 根据设备类型决定是否使用pin_memory
logger.info(f"Using device: {self.device}") logger.info(f"Using device: {self.device}")
@@ -126,11 +132,9 @@ class Trainer:
             )

             # Compute the loss
-            loss = sdf_loss(
-                pred_sdf,
-                gt_sdf,
-                points,
-                grad_weight=self.config.train.grad_weight
+            loss = self.criterion(
+                pred_sdf=pred_sdf,
+                gt_sdf=gt_sdf,
             )

             # Backpropagate and optimize
@@ -185,11 +189,9 @@ class Trainer:
             )

             # Compute the loss
-            loss = sdf_loss(
-                pred_sdf,
-                gt_sdf,
-                points,
-                grad_weight=self.config.train.grad_weight
+            loss = self.criterion(
+                pred_sdf=pred_sdf,
+                gt_sdf=gt_sdf,
             )

             total_loss += loss.item()
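Net effect of this file's changes: both call sites now go through self.criterion with keyword arguments, and the eikonal gradient term (with its grad_weight knob) is no longer part of the training loss. A self-contained sketch of the new call pattern; the Linear model and Adam optimizer below are stand-ins, not the project's BRepToSDF wiring, which this diff does not show:

import torch
from brep2sdf.networks.loss import Brep2SDFLoss

model = torch.nn.Linear(3, 1)  # stand-in for BRepToSDF
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = Brep2SDFLoss(enforce_minmax=True, clamping_distance=0.1)

points = torch.randn(8, 3)
gt_sdf = torch.randn(8, 1)

pred_sdf = model(points)
loss = criterion(pred_sdf=pred_sdf, gt_sdf=gt_sdf)  # keyword args, as in the diff

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # cf. max_grad_norm
optimizer.step()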
