diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index 9dcf142..fc0f4b1 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -49,7 +49,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 1000 + num_epochs: int = 100 learning_rate: float = 0.1 min_lr: float = 1e-5 weight_decay: float = 0.01 diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py index 824c516..0624f7b 100644 --- a/brep2sdf/networks/decoder.py +++ b/brep2sdf/networks/decoder.py @@ -108,19 +108,16 @@ class Decoder(nn.Module): ''' # 直接使用输入的特征矩阵,因为形状已经是 (S, D) x = feature_matrix - logger.debug(f"decoder-x:{x}") for layer, lin in enumerate(self.sdf_modules): if layer in self.skip_in: x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt x = lin(x) - logger.debug(f"decoder-x-lin:{x}") if layer < self.sdf_layers - 2: x = self.activation(x) output_value = x # 所有 f 的值 - logger.debug(f"decoder-output:{output_value}") # 调整输出形状为 (S) f = output_value.squeeze(-1) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 541accb..aedbb90 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -105,7 +105,6 @@ class Encoder(nn.Module): for idx, volume in enumerate(self.feature_volumes): if idx == patch_id: patch_features = volume.forward(surf_points) - break # 获取背景场特征 background_features = self.background.forward(surf_points) diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py index 43b65bb..fe269eb 100644 --- a/brep2sdf/networks/loss.py +++ b/brep2sdf/networks/loss.py @@ -111,10 +111,13 @@ class LossManager: return correction_loss - def compute_loss(self, points, + def compute_loss(self, + mnfld_pnts, + nonmnfld_pnts, normals, gt_sdfs, - pred_sdfs): + mnfld_pred, + nonmnfld_pred): """ 计算流型损失的逻辑 @@ -123,20 +126,34 @@ class LossManager: """ # 强制类型转换确保一致性 normals = normals.to(torch.float32) - pred_sdfs = pred_sdfs.to(torch.float32) + mnfld_pred = mnfld_pred.to(torch.float32) gt_sdfs = gt_sdfs.to(torch.float32) # 计算流形损失 - manifold_loss = self.position_loss(pred_sdfs, gt_sdfs) + manifold_loss = self.position_loss(mnfld_pred, gt_sdfs) # 计算法线损失 - normals_loss = self.normals_loss(normals, points, pred_sdfs) + normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred) #logger.gpu_memory_stats("计算法线损失后") + # 计算Eikonal损失 + eikonal_loss = self.eikonal_loss(nonmnfld_pnts, nonmnfld_pred) + + # 计算离表面损失 + offsurface_loss = self.offsurface_loss(nonmnfld_pnts, nonmnfld_pred) + + # 计算一致性损失 + #onsistency_loss = self.consistency_loss(mnfld_pnts, mnfld_pred, all_fi) + + # 计算修正损失 + #correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi) + # 汇总损失 loss_details = { "manifold": self.weights["manifold"] * manifold_loss, "normals": self.weights["normals"] * normals_loss, + "eikonal": self.weights["eikonal"] * eikonal_loss, + "offsurface": self.weights["offsurface"] * offsurface_loss } # 计算总损失 diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index d07a770..f8e7736 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -136,7 +136,6 @@ class Net(nn.Module): surf_points (P, S): return (P, S) """ - logger.debug(surf_points) feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id) f_i = self.decoder.forward_training_volumes(feature_mat) return f_i.squeeze() diff --git a/brep2sdf/networks/sample.py b/brep2sdf/networks/sample.py new file mode 100644 index 0000000..fa4b02c --- /dev/null +++ b/brep2sdf/networks/sample.py @@ -0,0 +1,22 @@ +import torch + + +class NormalPerPoint(): + + def __init__(self, global_sigma, local_sigma=0.01): + self.global_sigma = global_sigma + self.local_sigma = local_sigma + + def get_points(self, pc_input, local_sigma=None): + batch_size, sample_size, dim = pc_input.shape + + if local_sigma is not None: + sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma.unsqueeze(-1)) + else: + sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma) + + sample_global = (torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device) * (self.global_sigma * 2)) - self.global_sigma + + sample = torch.cat([sample_local, sample_global], dim=1) + + return sample \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 87a1fee..e753e5f 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -12,6 +12,7 @@ from brep2sdf.networks.network import Net from brep2sdf.networks.octree import OctreeNode from brep2sdf.networks.loss import LossManager from brep2sdf.networks.patch_graph import PatchGraph +from brep2sdf.networks.sample import NormalPerPoint from brep2sdf.utils.logger import logger @@ -142,6 +143,11 @@ class Trainer: self.loss_manager = LossManager(ablation="none") logger.gpu_memory_stats("训练器初始化后") + self.sampler = NormalPerPoint( + global_sigma=0.1, # 全局采样标准差 + local_sigma=0.01 # 局部采样标准差 + ) + logger.info(f"初始化完成,正在处理模型 {self.model_name}") @@ -200,7 +206,9 @@ class Trainer: total_loss = 0.0 total_loss_details = { "manifold": 0.0, - "normals": 0.0 + "normals": 0.0, + "eikonal": 0.0, + "offsurface": 0.0 } accumulated_loss = 0.0 # 新增:用于累积多个step的loss @@ -208,19 +216,21 @@ class Trainer: self.optimizer.zero_grad() for step, surf_points in enumerate(self.data['surf_ncs']): - points = torch.tensor(surf_points, device=self.device) - gt_sdf = torch.zeros(points.shape[0], device=self.device) + mnfld_points = torch.tensor(surf_points, device=self.device) + nonmnfld_pnts = self.sampler.get_points(mnfld_points.unsqueeze(0)).squeeze(0) # 生成非流形点 + gt_sdf = torch.zeros(mnfld_points.shape[0], device=self.device) normals = None if args.use_normal: normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device) # --- 准备模型输入,启用梯度 --- - points.requires_grad_(True) # 在检查之后启用梯度 + mnfld_points.requires_grad_(True) # 在检查之后启用梯度 + nonmnfld_pnts.requires_grad_(True) # 在检查之后启用梯度 # --- 前向传播 --- self.optimizer.zero_grad() - pred_sdf = self.model.forward_training_volumes(points, step) - logger.debug(f"pred_sdf:{pred_sdf}") + mnfld_pred = self.model.forward_training_volumes(mnfld_points, step) + nonmnfld_pred = self.model.forward_training_volumes(nonmnfld_pnts, step) if self.debug_mode: # --- 检查前向传播的输出 --- @@ -230,10 +240,12 @@ class Trainer: try: if args.use_normal: loss, loss_details = self.loss_manager.compute_loss( - points, + mnfld_points, + nonmnfld_pnts, normals, gt_sdf, - pred_sdf + mnfld_pred, + nonmnfld_pred ) else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) @@ -276,6 +288,7 @@ class Trainer: f'Loss: {total_loss:.6f}') logger.info(f"Loss Details: {total_loss_details}") return total_loss # 返回平均损失而非累计值 + def train_epoch(self, epoch: int) -> float: