
Add loss handling for non-manifold (off-surface) points

final

mckay committed 3 months ago
commit 09986f74fa
Changed files:

1. brep2sdf/config/default_config.py (2 changed lines)
2. brep2sdf/networks/decoder.py (3 changed lines)
3. brep2sdf/networks/encoder.py (1 changed line)
4. brep2sdf/networks/loss.py (27 changed lines)
5. brep2sdf/networks/network.py (1 changed line)
6. brep2sdf/networks/sample.py (22 changed lines)
7. brep2sdf/train.py (29 changed lines)

brep2sdf/config/default_config.py

@@ -49,7 +49,7 @@ class TrainConfig:
     # Basic training parameters
     batch_size: int = 8
     num_workers: int = 4
-    num_epochs: int = 1000
+    num_epochs: int = 100
     learning_rate: float = 0.1
     min_lr: float = 1e-5
     weight_decay: float = 0.01

brep2sdf/networks/decoder.py

@@ -108,19 +108,16 @@ class Decoder(nn.Module):
         '''
         # Use the input feature matrix directly, since its shape is already (S, D)
         x = feature_matrix
-        logger.debug(f"decoder-x:{x}")

         for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
                 x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
             x = lin(x)
-            logger.debug(f"decoder-x-lin:{x}")
             if layer < self.sdf_layers - 2:
                 x = self.activation(x)

         output_value = x  # values of all f
-        logger.debug(f"decoder-output:{output_value}")
         # Reshape the output to (S)
         f = output_value.squeeze(-1)
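The skip connection above concatenates the activation with itself and rescales by 1/sqrt(2), which keeps the activation variance roughly constant across the skip. A minimal, self-contained sketch of this forward pattern, with assumed layer sizes and skip index (the real values come from the Decoder's configuration, not this diff):

import torch
import torch.nn as nn

# Minimal sketch of the skip-connected SDF MLP pattern used above.
# d_in, d_hidden, n_layers and skip_in are illustrative assumptions.
class TinySDFNet(nn.Module):
    def __init__(self, d_in=64, d_hidden=256, n_layers=4, skip_in=(2,)):
        super().__init__()
        self.skip_in = skip_in
        self.sdf_layers = n_layers
        dims = [d_in] + [d_hidden] * (n_layers - 2) + [1]
        modules = []
        for l in range(n_layers - 1):
            # A skip layer sees the concatenated (doubled) width.
            in_dim = dims[l] * 2 if l in skip_in else dims[l]
            modules.append(nn.Linear(in_dim, dims[l + 1]))
        self.sdf_modules = nn.ModuleList(modules)
        self.activation = nn.Softplus(beta=100)

    def forward(self, x):  # x: (S, d_in)
        for layer, lin in enumerate(self.sdf_modules):
            if layer in self.skip_in:
                # Duplicate-and-rescale keeps activation variance stable.
                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))
            x = lin(x)
            if layer < self.sdf_layers - 2:
                x = self.activation(x)
        return x.squeeze(-1)  # (S,)

Many IGR-style decoders concatenate the original network input at the skip layer instead of duplicating the current activation; the sketch mirrors the self-concatenation seen in this diff.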

brep2sdf/networks/encoder.py

@@ -105,7 +105,6 @@ class Encoder(nn.Module):
         for idx, volume in enumerate(self.feature_volumes):
             if idx == patch_id:
                 patch_features = volume.forward(surf_points)
-                break

         # Get the background-field features
         background_features = self.background.forward(surf_points)
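Since idx == patch_id matches at most one volume, removing the break leaves the selected features unchanged and only costs a full scan of the list. Assuming self.feature_volumes is indexable (e.g. an nn.ModuleList), an equivalent direct lookup would be:

# Equivalent direct lookup, assuming self.feature_volumes is indexable
# (e.g. an nn.ModuleList); avoids iterating over every volume.
patch_features = self.feature_volumes[patch_id].forward(surf_points)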

brep2sdf/networks/loss.py

@@ -111,10 +111,13 @@ class LossManager:
         return correction_loss

-    def compute_loss(self, points,
+    def compute_loss(self,
+                     mnfld_pnts,
+                     nonmnfld_pnts,
                      normals,
                      gt_sdfs,
-                     pred_sdfs):
+                     mnfld_pred,
+                     nonmnfld_pred):
         """
         Logic for computing the manifold losses

@@ -123,20 +126,34 @@
         """
         # Force dtype conversion for consistency
         normals = normals.to(torch.float32)
-        pred_sdfs = pred_sdfs.to(torch.float32)
+        mnfld_pred = mnfld_pred.to(torch.float32)
         gt_sdfs = gt_sdfs.to(torch.float32)

         # Compute the manifold loss
-        manifold_loss = self.position_loss(pred_sdfs, gt_sdfs)
+        manifold_loss = self.position_loss(mnfld_pred, gt_sdfs)

         # Compute the normals loss
-        normals_loss = self.normals_loss(normals, points, pred_sdfs)
+        normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
         #logger.gpu_memory_stats("after computing normals loss")

+        # Compute the Eikonal loss
+        eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred)
+
+        # Compute the off-surface loss
+        offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred)
+
+        # Compute the consistency loss
+        #consistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi)
+        # Compute the correction loss
+        #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
+
         # Aggregate the losses
         loss_details = {
             "manifold": self.weights["manifold"] * manifold_loss,
             "normals": self.weights["normals"] * normals_loss,
+            "eikonal": self.weights["eikonal"] * eikonal_loss,
+            "offsurface": self.weights["offsurface"] * offsurface_loss
         }

         # Compute the total loss
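The bodies of eikonal_loss and offsurface_loss are not part of this diff. A plausible sketch of the two new terms, assuming the IGR-style Eikonal penalty and SIREN-style off-surface penalty that the names suggest (the gradient helper and the decay rate alpha are assumptions):

import torch

def gradient(points, pred):
    # d(pred)/d(points) via autograd; points must have requires_grad=True,
    # which is why the training loop enables gradients on both point sets.
    return torch.autograd.grad(
        outputs=pred,
        inputs=points,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
    )[0]

def eikonal_loss(nonmnfld_pnts, nonmnfld_pred):
    # Push the SDF gradient norm toward 1 away from the surface (IGR-style).
    grad = gradient(nonmnfld_pnts, nonmnfld_pred)
    return ((grad.norm(2, dim=-1) - 1.0) ** 2).mean()

def offsurface_loss(nonmnfld_pnts, nonmnfld_pred, alpha=100.0):
    # Penalize near-zero SDF values at off-surface points (SIREN-style),
    # discouraging spurious zero level sets away from the surface.
    return torch.exp(-alpha * nonmnfld_pred.abs()).mean()

For the new loss_details entries to resolve, self.weights must also gain "eikonal" and "offsurface" keys; that change is not visible in this diff.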

brep2sdf/networks/network.py

@@ -136,7 +136,6 @@ class Net(nn.Module):
         surf_points (P, S):
         return (P, S)
         """
-        logger.debug(surf_points)
         feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id)
         f_i = self.decoder.forward_training_volumes(feature_mat)
         return f_i.squeeze()
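Shape-wise, a call to forward_training_volumes looks roughly like the sketch below; the sample count, dimensionality, and the name net are illustrative, and the encoder and decoder internals determine the true shapes:

import torch

# Illustrative only: 1024 surface samples in 3D, routed through patch 0
# of an already-constructed Net instance `net`.
surf_points = torch.randn(1024, 3, requires_grad=True)
f_i = net.forward_training_volumes(surf_points, 0)  # one SDF value per sample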

brep2sdf/networks/sample.py

@@ -0,0 +1,22 @@
+import torch
+
+
+class NormalPerPoint():
+
+    def __init__(self, global_sigma, local_sigma=0.01):
+        self.global_sigma = global_sigma
+        self.local_sigma = local_sigma
+
+    def get_points(self, pc_input, local_sigma=None):
+        batch_size, sample_size, dim = pc_input.shape
+
+        if local_sigma is not None:
+            sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma.unsqueeze(-1))
+        else:
+            sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma)
+
+        sample_global = (torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device) * (self.global_sigma * 2)) - self.global_sigma
+
+        sample = torch.cat([sample_local, sample_global], dim=1)
+
+        return sample
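The sampler follows the point-cloud sampling scheme popularized by IGR: local samples are Gaussian perturbations of the input cloud, and one eighth as many global samples are drawn uniformly from the cube [-global_sigma, global_sigma]^dim, so the output contains more points than the input. A quick usage sketch:

import torch
from brep2sdf.networks.sample import NormalPerPoint

sampler = NormalPerPoint(global_sigma=0.1, local_sigma=0.01)
pc = torch.randn(1, 2048, 3)                  # (batch, samples, dim) surface points
pts = sampler.get_points(pc)                  # local + global off-surface samples
assert pts.shape == (1, 2048 + 2048 // 8, 3)  # 256 extra global points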

brep2sdf/train.py

@@ -12,6 +12,7 @@ from brep2sdf.networks.network import Net
 from brep2sdf.networks.octree import OctreeNode
 from brep2sdf.networks.loss import LossManager
 from brep2sdf.networks.patch_graph import PatchGraph
+from brep2sdf.networks.sample import NormalPerPoint
 from brep2sdf.utils.logger import logger

@@ -142,6 +143,11 @@ class Trainer:
         self.loss_manager = LossManager(ablation="none")
         logger.gpu_memory_stats("after trainer initialization")

+        self.sampler = NormalPerPoint(
+            global_sigma=0.1,  # global sampling standard deviation
+            local_sigma=0.01   # local sampling standard deviation
+        )
+
         logger.info(f"Initialization complete, processing model {self.model_name}")

@@ -200,7 +206,9 @@
         total_loss = 0.0
         total_loss_details = {
             "manifold": 0.0,
-            "normals": 0.0
+            "normals": 0.0,
+            "eikonal": 0.0,
+            "offsurface": 0.0
         }
         accumulated_loss = 0.0  # new: accumulates the loss across several steps

@@ -208,19 +216,21 @@
         self.optimizer.zero_grad()
         for step, surf_points in enumerate(self.data['surf_ncs']):
-            points = torch.tensor(surf_points, device=self.device)
-            gt_sdf = torch.zeros(points.shape[0], device=self.device)
+            mnfld_points = torch.tensor(surf_points, device=self.device)
+            nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)  # generate non-manifold points
+            gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device)
             normals = None
             if args.use_normal:
                 normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)

             # --- Prepare model inputs, enable gradients ---
-            points.requires_grad_(True)  # enable gradients after the checks
+            mnfld_points.requires_grad_(True)   # enable gradients after the checks
+            nonmnfld_pnts.requires_grad_(True)  # enable gradients after the checks

             # --- Forward pass ---
             self.optimizer.zero_grad()
-            pred_sdf = self.model.forward_training_volumes(points, step)
-            logger.debug(f"pred_sdf:{pred_sdf}")
+            mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
+            nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)

             if self.debug_mode:
                 # --- Check the forward-pass outputs ---

@@ -230,10 +240,12 @@
             try:
                 if args.use_normal:
                     loss, loss_details = self.loss_manager.compute_loss(
-                        points,
+                        mnfld_points,
+                        nonmnfld_pnts,
                         normals,
                         gt_sdf,
-                        pred_sdf
+                        mnfld_pred,
+                        nonmnfld_pred
                     )
                 else:
                     loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)

@@ -276,6 +288,7 @@
                     f'Loss: {total_loss:.6f}')
         logger.info(f"Loss Details: {total_loss_details}")
         return total_loss  # return the average loss rather than the accumulated total
+
     def train_epoch(self, epoch: int) -> float:
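Taken together, a training step now follows the standard implicit-surface recipe: sample off-surface points around the batch, enable gradients on both point sets before the forward passes (autograd needs differentiable inputs for the normals and Eikonal terms), then hand all six arguments to compute_loss. A condensed sketch of the per-step flow, mirroring the diff above:

# Condensed per-step flow (illustrative; error handling and logging omitted).
mnfld_points = torch.tensor(surf_points, device=self.device)
nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0)
mnfld_points.requires_grad_(True)    # feeds the normals term
nonmnfld_pnts.requires_grad_(True)   # feeds the Eikonal / off-surface terms

mnfld_pred = self.model.forward_training_volumes(mnfld_points, step)
nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step)

loss, loss_details = self.loss_manager.compute_loss(
    mnfld_points, nonmnfld_pnts, normals, gt_sdf, mnfld_pred, nonmnfld_pred
)
loss.backward()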
