Browse Source

增加非流型点loss处理

final
mckay 3 months ago
parent
commit
09986f74fa
  1. 2
      brep2sdf/config/default_config.py
  2. 3
      brep2sdf/networks/decoder.py
  3. 1
      brep2sdf/networks/encoder.py
  4. 27
      brep2sdf/networks/loss.py
  5. 1
      brep2sdf/networks/network.py
  6. 22
      brep2sdf/networks/sample.py
  7. 29
      brep2sdf/train.py

2
brep2sdf/config/default_config.py

@ -49,7 +49,7 @@ class TrainConfig:
# 基本训练参数
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 1000
num_epochs: int = 100
learning_rate: float = 0.1
min_lr: float = 1e-5
weight_decay: float = 0.01

3
brep2sdf/networks/decoder.py

@ -108,19 +108,16 @@ class Decoder(nn.Module):
'''
# 直接使用输入的特征矩阵,因为形状已经是 (S, D)
x = feature_matrix
logger.debug(f"decoder-x:{x}")
for layer, lin in enumerate(self.sdf_modules):
if layer in self.skip_in:
x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt
x = lin(x)
logger.debug(f"decoder-x-lin:{x}")
if layer < self.sdf_layers - 2:
x = self.activation(x)
output_value = x # 所有 f 的值
logger.debug(f"decoder-output:{output_value}")
# 调整输出形状为 (S)
f = output_value.squeeze(-1)

1
brep2sdf/networks/encoder.py

@ -105,7 +105,6 @@ class Encoder(nn.Module):
for idx, volume in enumerate(self.feature_volumes):
if idx == patch_id:
patch_features = volume.forward(surf_points)
break
# 获取背景场特征
background_features = self.background.forward(surf_points)

27
brep2sdf/networks/loss.py

@ -111,10 +111,13 @@ class LossManager:
return correction_loss
def compute_loss(self, points,
def compute_loss(self,
mnfld_pnts,
nonmnfld_pnts,
normals,
gt_sdfs,
pred_sdfs):
mnfld_pred,
nonmnfld_pred):
"""
计算流型损失的逻辑
@ -123,20 +126,34 @@ class LossManager:
"""
# 强制类型转换确保一致性
normals = normals.to(torch.float32)
pred_sdfs = pred_sdfs.to(torch.float32)
mnfld_pred = mnfld_pred.to(torch.float32)
gt_sdfs = gt_sdfs.to(torch.float32)
# 计算流形损失
manifold_loss = self.position_loss(pred_sdfs, gt_sdfs)
manifold_loss = self.position_loss(mnfld_pred, gt_sdfs)
# 计算法线损失
normals_loss = self.normals_loss(normals, points, pred_sdfs)
normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
#logger.gpu_memory_stats("计算法线损失后")
# 计算Eikonal损失
eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred)
# 计算离表面损失
offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred)
# 计算一致性损失
#onsistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
# 计算修正损失
#correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
# 汇总损失
loss_details = {
"manifold": self.weights["manifold"] * manifold_loss,
"normals": self.weights["normals"] * normals_loss,
"eikonal": self.weights["eikonal"] * eikonal_loss,
"offsurface": self.weights["offsurface"] * offsurface_loss
}
# 计算总损失

1
brep2sdf/networks/network.py

@ -136,7 +136,6 @@ class Net(nn.Module):
surf_points (P, S):
return (P, S)
"""
logger.debug(surf_points)
feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id)
f_i = self.decoder.forward_training_volumes(feature_mat)
return f_i.squeeze()

22
brep2sdf/networks/sample.py

@ -0,0 +1,22 @@
import torch
class NormalPerPoint():
def __init__(self, global_sigma, local_sigma=0.01):
self.global_sigma = global_sigma
self.local_sigma = local_sigma
def get_points(self, pc_input, local_sigma=None):
batch_size, sample_size, dim = pc_input.shape
if local_sigma is not None:
sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma.unsqueeze(-1))
else:
sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma)
sample_global = (torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device) * (self.global_sigma * 2)) - self.global_sigma
sample = torch.cat([sample_local, sample_global], dim=1)
return sample

29
brep2sdf/train.py

@ -12,6 +12,7 @@ from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.loss import LossManager
from brep2sdf.networks.patch_graph import PatchGraph
from brep2sdf.networks.sample import NormalPerPoint
from brep2sdf.utils.logger import logger
@ -142,6 +143,11 @@ class Trainer:
self.loss_manager = LossManager(ablation="none")
logger.gpu_memory_stats("训练器初始化后")
self.sampler = NormalPerPoint(
global_sigma=0.1, # 全局采样标准差
local_sigma=0.01 # 局部采样标准差
)
logger.info(f"初始化完成,正在处理模型 {self.model_name}")
@ -200,7 +206,9 @@ class Trainer:
total_loss = 0.0
total_loss_details = {
"manifold": 0.0,
"normals": 0.0
"normals": 0.0,
"eikonal": 0.0,
"offsurface": 0.0
}
accumulated_loss = 0.0 # 新增:用于累积多个step的loss
@ -208,19 +216,21 @@ class Trainer:
self.optimizer.zero_grad()
for step, surf_points in enumerate(self.data['surf_ncs']):
points = torch.tensor(surf_points, device=self.device)
gt_sdf = torch.zeros(points.shape[0], device=self.device)
mnfld_points = torch.tensor(surf_points, device=self.device)
nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0) # 生成非流形点
gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
normals = None
if args.use_normal:
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
# --- 准备模型输入,启用梯度 ---
points.requires_grad_(True) # 在检查之后启用梯度
mnfld_points.requires_grad_(True) # 在检查之后启用梯度
nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度
# --- 前向传播 ---
self.optimizer.zero_grad()
pred_sdf = self.model.forward_training_volumes(points, step)
logger.debug(f"pred_sdf:{pred_sdf}")
mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)
if self.debug_mode:
# --- 检查前向传播的输出 ---
@ -230,10 +240,12 @@ class Trainer:
try:
if args.use_normal:
loss, loss_details = self.loss_manager.compute_loss(
points,
mnfld_points,
nonmnfld_pnts,
normals,
gt_sdf,
pred_sdf
mnfld_pred,
nonmnfld_pred
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
@ -278,6 +290,7 @@ class Trainer:
return total_loss # 返回平均损失而非累计值
def train_epoch(self, epoch: int) -> float:
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)

Loading…
Cancel
Save