Browse Source

fix:修复loss参数

main
mckay 4 months ago
parent
commit
b2b42918f4
  1. 13
      brep2sdf/config/default_config.py
  2. 70
      brep2sdf/networks/loss.py
  3. 24
      brep2sdf/train.py

13
brep2sdf/config/default_config.py

@ -23,6 +23,7 @@ class ModelConfig:
encoder_channels: Tuple[int] = (32, 64, 128)
encoder_layers_per_block: int = 1
@dataclass
class DataConfig:
"""数据相关配置"""
@ -53,14 +54,15 @@ class TrainConfig:
# 基本训练参数
batch_size: int = 1
num_workers: int = 4
num_epochs: int = 100
num_epochs: int = 10
learning_rate: float = 1e-4
min_lr: float = 1e-6
weight_decay: float = 0.01
# 梯度和损失相关
grad_weight: float = 0.1
max_grad_norm: float = 1.0
clamping_distance: float = 0.1
# 学习率调度器参数
lr_scheduler: str = 'cosine' # ['cosine', 'linear', 'step']
@ -70,6 +72,10 @@ class TrainConfig:
save_freq: int = 10 # 每多少个epoch保存一次
val_freq: int = 1 # 每多少个epoch验证一次
@dataclass
class TestConfig:
vis_freq: int = 100
@dataclass
class LogConfig:
"""日志相关配置"""
@ -91,6 +97,7 @@ class Config:
model: ModelConfig = ModelConfig()
data: DataConfig = DataConfig()
train: TrainConfig = TrainConfig()
test: TestConfig = TestConfig()
log: LogConfig = LogConfig()
def __post_init__(self):
@ -104,12 +111,14 @@ class Config:
model_config = ModelConfig(**config_dict.get('model', {}))
data_config = DataConfig(**config_dict.get('data', {}))
train_config = TrainConfig(**config_dict.get('train', {}))
test_config = TestConfig(**config_dict.get('test', {}))
log_config = LogConfig(**config_dict.get('log', {}))
return cls(
model=model_config,
data=data_config,
train=train_config,
test=test_config,
log=log_config
)

70
brep2sdf/networks/loss.py

@ -0,0 +1,70 @@
import torch
import torch.nn as nn
from brep2sdf.config.default_config import get_default_config
class Brep2SDFLoss:
"""解释Brep2SDF的loss设计原理"""
def __init__(self, enforce_minmax: bool=True, clamping_distance: float = 0.1):
self.l1_loss = nn.L1Loss(reduction='sum')
self.enforce_minmax = enforce_minmax
self.minT = -clamping_distance
self.maxT = clamping_distance
def __call__(self, pred_sdf, gt_sdf):
"""使类可直接调用"""
return self.forward(pred_sdf, gt_sdf)
def forward(self, pred_sdf, gt_sdf):
"""
pred_sdf: 预测的SDF值
gt_sdf: 真实的SDF值
latent_vecs: 形状编码, 来自 brep
epoch: 当前训练轮次
"""
if self.enforce_minmax:
pred_sdf = torch.clamp(pred_sdf, self.minT, self.maxT)
gt_sdf = torch.clamp(gt_sdf, self.minT, self.maxT)
# 1. L1 Loss的优势
# - 对异常值更鲁棒
# - 能更好地保持表面细节
base_loss = self.l1_loss(pred_sdf, gt_sdf) / pred_sdf.shape[0]
return base_loss
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数"""
# 确保points需要梯度
if not points.requires_grad:
points = points.detach().requires_grad_(True)
# L1损失
l1_loss = F.l1_loss(pred_sdf, gt_sdf)
try:
# 梯度约束损失
grad = torch.autograd.grad(
pred_sdf.sum(),
points,
create_graph=True,
retain_graph=True,
allow_unused=True
)[0]
if grad is not None:
grad_constraint = F.mse_loss(
torch.norm(grad, dim=-1),
torch.ones_like(pred_sdf.squeeze(-1))
)
else:
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
except Exception as e:
logger.warning(f"Gradient computation failed: {str(e)}")
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
return l1_loss + grad_weight * grad_constraint

24
brep2sdf/train.py

@ -4,7 +4,8 @@ import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.network import BRepToSDF, sdf_loss
from brep2sdf.networks.network import BRepToSDF
from brep2sdf.networks.loss import Brep2SDFLoss
from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config, load_config
import wandb
@ -45,6 +46,11 @@ class Trainer:
def __init__(self, config):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clamping_distance = self.config.train.clamping_distance
self.criterion = Brep2SDFLoss(
enforce_minmax= (clamping_distance > 0),
clamping_distance= clamping_distance
)
use_pin_memory = self.device.type == 'cuda' # 根据设备类型决定是否使用pin_memory
logger.info(f"Using device: {self.device}")
@ -126,11 +132,9 @@ class Trainer:
)
# 计算损失
loss = sdf_loss(
pred_sdf,
gt_sdf,
points,
grad_weight=self.config.train.grad_weight
loss = self.criterion(
pred_sdf=pred_sdf,
gt_sdf=gt_sdf,
)
# 反向传播和优化
@ -185,11 +189,9 @@ class Trainer:
)
# 计算损失
loss = sdf_loss(
pred_sdf,
gt_sdf,
points,
grad_weight=self.config.train.grad_weight
loss = self.criterion(
pred_sdf=pred_sdf,
gt_sdf=gt_sdf,
)
total_loss += loss.item()

Loading…
Cancel
Save