Browse Source

增加伪sdf loss

final
mckay 1 month ago
parent
commit
d7227b91a2
  1. 18
      brep2sdf/config/default_config.py
  2. 11
      brep2sdf/networks/loss.py
  3. 23
      brep2sdf/networks/sample.py
  4. 9
      brep2sdf/train.py

18
brep2sdf/config/default_config.py

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple,List
@dataclass @dataclass
class ModelConfig: class ModelConfig:
@ -49,15 +49,21 @@ class TrainConfig:
# 基本训练参数 # 基本训练参数
batch_size: int = 8 batch_size: int = 8
num_workers: int = 4 num_workers: int = 4
num_epochs: int = 100 num_epochs: int = 50
learning_rate: float = 0.1 learning_rate: float = 0.005
learning_rate_schedule: List = field(default_factory=lambda: [{
"Type": "Step", # 学习率调度类型。"Step"表示在指定迭代次数后将学习率乘以因子
"Initial": 0.005,
"Interval": 2000,
"Factor": 0.5
}])
min_lr: float = 1e-5 min_lr: float = 1e-5
weight_decay: float = 0.01 weight_decay: float = 0.01
# 梯度和损失相关 # 梯度和损失相关
max_grad_norm: float = 1.0 max_grad_norm: float = 1.0
clamping_distance: float = 0.1 clamping_distance: float = 0.1
debug_mode: bool = True debug_mode: bool = False
accumulation_steps:int = 50 accumulation_steps:int = 50
# 学习率调度器参数 # 学习率调度器参数
@ -65,7 +71,7 @@ class TrainConfig:
warmup_epochs: int = 5 warmup_epochs: int = 5
# 保存和验证 # 保存和验证
save_freq: int = 100 # 每多少个epoch保存一次 save_freq: int = 10 # 每多少个epoch保存一次
val_freq: int = 1 # 每多少个epoch验证一次 val_freq: int = 1 # 每多少个epoch验证一次
# 保存路径 # 保存路径

11
brep2sdf/networks/loss.py

@ -5,13 +5,14 @@ from brep2sdf.utils.logger import logger
class LossManager: class LossManager:
def __init__(self, ablation, **condition_kwargs): def __init__(self, ablation, **condition_kwargs):
self.weights = { self.weights = {
"manifold": 1, "manifold": 10,
"feature_manifold": 1, # 原文里面和manifold的权重是一样的 "feature_manifold": 1, # 原文里面和manifold的权重是一样的
"normals": 1, "normals": 1,
"eikonal": 1, "eikonal": 1,
"offsurface": 1, "offsurface": 1,
"consistency": 1, "consistency": 1,
"correction": 1, "correction": 1,
"psdf": 10
} }
self.condition_kwargs = condition_kwargs self.condition_kwargs = condition_kwargs
self.ablation = ablation # 消融实验用 self.ablation = ablation # 消融实验用
@ -117,7 +118,8 @@ class LossManager:
normals, normals,
gt_sdfs, gt_sdfs,
mnfld_pred, mnfld_pred,
nonmnfld_pred): nonmnfld_pred,
psdfs):
""" """
计算流型损失的逻辑 计算流型损失的逻辑
@ -148,12 +150,15 @@ class LossManager:
# 计算修正损失 # 计算修正损失
#correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi) #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
psdf_loss = self.position_loss(nonmnfld_pred, psdfs)
# 汇总损失 # 汇总损失
loss_details = { loss_details = {
"manifold": self.weights["manifold"] * manifold_loss, "manifold": self.weights["manifold"] * manifold_loss,
"normals": self.weights["normals"] * normals_loss, "normals": self.weights["normals"] * normals_loss,
"eikonal": self.weights["eikonal"] * eikonal_loss, "eikonal": self.weights["eikonal"] * eikonal_loss,
"offsurface": self.weights["offsurface"] * offsurface_loss "offsurface": self.weights["offsurface"] * offsurface_loss,
"psdf":self.weights["psdf"] * psdf_loss
} }
# 计算总损失 # 计算总损失

23
brep2sdf/networks/sample.py

@ -3,7 +3,7 @@ import torch
class NormalPerPoint(): class NormalPerPoint():
def __init__(self, global_sigma, local_sigma=0.01): def __init__(self, global_sigma, local_sigma=0.1):
self.global_sigma = global_sigma self.global_sigma = global_sigma
self.local_sigma = local_sigma self.local_sigma = local_sigma
@ -20,3 +20,24 @@ class NormalPerPoint():
sample = torch.cat([sample_local, sample_global], dim=1) sample = torch.cat([sample_local, sample_global], dim=1)
return sample return sample
def get_norm_points(self, pc_input, normals, local_sigma=None):
"""
返回沿法线方向偏移的点以及对应的伪 SDF PSDF
:param pc_input: 输入点云形状为 (sample_size, 3)
:param normals: 点云的法线形状为 (sample_size, 3)
:param local_sigma: 局部偏移的标准差
:return: 偏移后的点 (sample_size, dim), SDF
"""
sample_size, dim = pc_input.shape
# 生成随机位移值
if local_sigma is not None:
psdf = torch.randn(sample_size, device=pc_input.device) * local_sigma
else:
psdf = torch.randn(sample_size, device=pc_input.device) * self.local_sigma
# 沿法线方向偏移
sample = pc_input + normals * psdf.unsqueeze(-1)
return sample, psdf

9
brep2sdf/train.py

@ -315,7 +315,7 @@ class Trainer:
start_idx = batch_idx * batch_size start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points) end_idx = min((batch_idx + 1) * batch_size, num_points)
mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 mnfld_pnts = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点
gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值
normals = None normals = None
if args.use_normal: if args.use_normal:
@ -323,7 +323,11 @@ class Trainer:
logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.")
return float('inf') return float('inf')
normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线 normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线
nonmnfld_pnts,psdf = self.sampler.get_norm_points(mnfld_pnts,normals) # 生成非流形点
logger.debug((mnfld_pnts,nonmnfld_pnts,psdf))
else:
nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0)).squeeze(0) # 生成非流形点
# 执行检查 # 执行检查
if self.debug_mode: if self.debug_mode:
if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf') if check_tensor(mnfld_pnts, "Input Points", epoch, step): return float('inf')
@ -372,7 +376,8 @@ class Trainer:
normals, # 传递检查过的 normals normals, # 传递检查过的 normals
gt_sdf, gt_sdf,
mnfld_pred, mnfld_pred,
nonmnfld_pred nonmnfld_pred,
psdf
) )
else: else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

Loading…
Cancel
Save