diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py
index 5397115..395b116 100644
--- a/brep2sdf/networks/decoder.py
+++ b/brep2sdf/networks/decoder.py
@@ -65,7 +65,7 @@ class Decoder(nn.Module):
             self.activation = nn.ReLU()
         else: #siren
-            self.activation = nn.SiLU()
+            self.activation = nn.ReLU()
         self.final_activation = nn.Tanh()
 
     def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py
index 0ede43f..bc26caf 100644
--- a/brep2sdf/networks/encoder.py
+++ b/brep2sdf/networks/encoder.py
@@ -2,9 +2,19 @@ import torch
 import torch.nn as nn
 
 from .octree import OctreeNode
-from .feature_volume import PatchFeatureVolume
+from .feature_volume import PatchFeatureVolume, SimpleFeatureEncoder
 from brep2sdf.utils.logger import logger
 
+from torchviz import make_dot
+
+class Sine(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
+        return torch.sin(30 * input)
+
 class Encoder(nn.Module):
     def __init__(self, volume_bboxs:torch.tensor, feature_dim: int = 32):
         """
@@ -21,6 +31,7 @@ class Encoder(nn.Module):
         resolutions = self._batch_calculate_resolution(volume_bboxs)
 
         # Initialize the per-patch feature volumes
+        '''
         self.feature_volumes = nn.ModuleList([
             PatchFeatureVolume(
                 bbox=bbox,
@@ -28,17 +39,18 @@ class Encoder(nn.Module):
                 feature_dim=feature_dim
             ) for i, bbox in enumerate(volume_bboxs)
         ])
+        '''
+        self.feature_volumes = nn.ModuleList([
+            SimpleFeatureEncoder(
+                input_dim=3, feature_dim=feature_dim
+            ) for i, bbox in enumerate(volume_bboxs)
+        ])
+
         self.background = self.simple_encoder = nn.Sequential(
             nn.Linear(3, 256),
             nn.BatchNorm1d(256),
             nn.ReLU(),
-            nn.Linear(256, 512),
-            nn.BatchNorm1d(512),
-            nn.ReLU(),
-            nn.Linear(512, 256),
-            nn.BatchNorm1d(256),
-            nn.ReLU(),
             nn.Linear(256, feature_dim)
         )
         print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
@@ -91,7 +103,7 @@ class Encoder(nn.Module):
             if mask.any():
                 # Fetch this volume's features (M, D)
                 features = volume.forward(query_points[mask])
-                all_features[mask, vol_id] = 0.7 * features + 0.3 * background_features[mask]
+                all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features
 
         return all_features
@@ -120,7 +132,8 @@ class Encoder(nn.Module):
         """
         # Fetch the patch features
         patch_features = self.feature_volumes[patch_id].forward(surf_points)
-
+        #dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters()))
+        #dot.render("feature_extraction", format="png")  # save the computation graph as a PNG file
         return patch_features
 
     def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor:
diff --git a/brep2sdf/networks/feature_volume.py b/brep2sdf/networks/feature_volume.py
index 8b13df2..0e5e94b 100644
--- a/brep2sdf/networks/feature_volume.py
+++ b/brep2sdf/networks/feature_volume.py
@@ -73,4 +73,29 @@ class PatchFeatureVolume(nn.Module):
         features = self.feature_volume[indices[...,0], indices[...,1], indices[...,2]] # (B,8,D)
 
         # Weighted sum (B,D)
-        return torch.einsum('bnd,bn->bd', features, weights)
\ No newline at end of file
+        return torch.einsum('bnd,bn->bd', features, weights)
+
+
+class SimpleFeatureEncoder(nn.Module):
+    def __init__(self, input_dim=3, feature_dim=64):
+        super(SimpleFeatureEncoder, self).__init__()
+        # A plain multilayer perceptron serving as the encoder
+        self.encoder = nn.Sequential(
+            nn.Linear(input_dim, 128),
+            nn.ReLU(inplace=True),
+            nn.Linear(128, 256),
+            nn.ReLU(inplace=True),
+            nn.Linear(256, 512),
+            nn.ReLU(inplace=True),
+            nn.Linear(512, feature_dim)
+        )
+
+    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            query_points: query point coordinates of shape (B, 3)
+
+        Returns:
+            feature vectors of shape (B, feature_dim)
+        """
+        return self.encoder(query_points)
\ No newline at end of file
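Note on the encoder changes above: the spatially-localized PatchFeatureVolume lookup is swapped for SimpleFeatureEncoder, a plain coordinate MLP that ignores each patch's bbox, and the background network is shrunk to a single hidden layer. The new Sine module (the SIREN-style sin(30x) activation) is defined but not yet wired into any network, and with the decoder change the "siren" branch now also falls back to nn.ReLU. A minimal smoke test for the new encoder path, assuming only the definitions in this diff (batch size and feature_dim are illustrative):

    import torch
    from brep2sdf.networks.feature_volume import SimpleFeatureEncoder

    encoder = SimpleFeatureEncoder(input_dim=3, feature_dim=32)
    query_points = torch.rand(1024, 3) * 2.0 - 1.0   # toy (B, 3) points in [-1, 1]^3
    features = encoder(query_points)                 # (B, feature_dim), no bbox involved
    assert features.shape == (1024, 32)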
diff --git a/brep2sdf/networks/loss.py b/brep2sdf/networks/loss.py
index 5f07718..8543024 100644
--- a/brep2sdf/networks/loss.py
+++ b/brep2sdf/networks/loss.py
@@ -67,7 +67,14 @@ class LossManager:
         # NOTE: the original code has more complex logic here
         # Compute the branch gradient
         branch_grad = gradient(mnfld_pnts, pred_sdfs) # compute the branch gradient
-
+        '''
+        logger.info(f"branch_grad:{branch_grad}")
+        logger.info(f"mnfld_pnts:{mnfld_pnts}, shape:{mnfld_pnts.shape}")
+        logger.info(f"pred_sdfs:{pred_sdfs}")
+        logger.print_tensor_stats("mnfld_pnts", mnfld_pnts)
+        logger.print_tensor_stats("pred_sdfs", pred_sdfs)
+        logger.print_tensor_stats("mnfld_pnts[2]", mnfld_pnts[:, 2])
+        '''
         # Compute the normals loss
         normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # compute the normals loss
 
@@ -203,7 +210,7 @@ class LossManager:
         manifold_loss = self.position_loss(mnfld_pred, gt_sdfs)
 
         # Compute the normals loss
-        normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
+        #normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
 
         #logger.gpu_memory_stats("after computing the normals loss")
 
@@ -217,7 +224,7 @@ class LossManager:
         # Aggregate the losses
         loss_details = {
             "manifold": self.weights["manifold"] * manifold_loss,
-            "normals": self.weights["normals"] * normals_loss
+            #"normals": self.weights["normals"] * normals_loss
         }
 
         # Compute the total loss
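The loss.py changes disable the normals term end to end: the debug logging is fenced off, the normals_loss call is commented out, and its entry is dropped from loss_details, leaving only the manifold term weighted into the total. For reference, a self-contained sketch of the term being disabled, on a toy sphere SDF whose analytic gradient equals the normals so the loss is near zero by construction; torch.autograd.grad stands in for the repo's gradient() helper:

    import torch

    mnfld_pnts = torch.rand(128, 3, requires_grad=True)             # toy surface samples
    pred_sdfs = mnfld_pnts.norm(dim=1, keepdim=True) - 0.5          # SDF of a sphere, radius 0.5
    normals = torch.nn.functional.normalize(mnfld_pnts, dim=1)      # exact sphere normals

    branch_grad = torch.autograd.grad(
        outputs=pred_sdfs, inputs=mnfld_pnts,
        grad_outputs=torch.ones_like(pred_sdfs),
        create_graph=True, retain_graph=True,
    )[0]
    normals_loss = ((branch_grad - normals).abs()).norm(2, dim=1).mean()
    print(normals_loss.item())  # ~0 for this toy case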
diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index bfbe240..033e07a 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -77,6 +77,7 @@ class Net(nn.Module):
         self.decoder = Decoder(
             d_in=feature_dim,
             dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
+            #skip_in=(3,),
             geometric_init=False,
             beta=5
         )
@@ -216,7 +217,7 @@ def gradient(inputs, outputs):
         create_graph=True,
         retain_graph=True,
         only_inputs=True,
-        allow_unused=True # added for error handling
+        allow_unused=False # added for error handling
     )[0]
 
     # Corrected dimension slicing
diff --git a/brep2sdf/train.py b/brep2sdf/train.py
index 8b71d7c..e600da5 100644
--- a/brep2sdf/train.py
+++ b/brep2sdf/train.py
@@ -4,7 +4,7 @@ import time
 import os
 import numpy as np
 import argparse
-
+from torchviz import make_dot
 from brep2sdf.config.default_config import get_default_config
 from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
 
@@ -324,6 +324,8 @@ class Trainer:
             logger.info(f'Train Epoch: {epoch:4d}]\t'
                         f'Loss: {current_loss:.6f}')
             if loss_details: logger.info(f"Loss Details: {loss_details}")
+        dot = make_dot(mnfld_pred, params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
+        dot.render("forward_graph1", format="png")  # saves the computation graph as a PNG
 
         return total_loss  # for single-batch training, return the current loss directly
 
@@ -478,6 +480,9 @@ class Trainer:
             subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
             logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
 
+        dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)]))
+        dot.render("forward_graph2", format="png")  # saves the computation graph as a PNG
+
         avg_loss = sum(losses) / len(losses)
         logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}")
 
@@ -659,6 +664,8 @@ class Trainer:
                 _nonmnfld_face_indices_mask[start_idx:end_idx],
                 _nonmnfld_operator[start_idx:end_idx]
             )
+            dot = make_dot(mnfld_pred, params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
+            dot.render("forward_graph3", format="png")  # saves the computation graph as a PNG
 
             #logger.print_tensor_stats("psdf",psdf)
             #logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts)
 
@@ -731,6 +738,7 @@ class Trainer:
                 f'Loss: {current_loss:.6f}')
             if loss_details: logger.info(f"Loss Details: {loss_details}")
 
+        return total_loss  # for single-batch training, return the current loss directly
 
     def train_epoch(self, epoch: int,resample:bool=True) -> float:
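On the network.py change: with allow_unused=True, torch.autograd.grad returns None for an input the output does not depend on, and the failure only surfaces later when the None result is indexed or sliced; allow_unused=False makes autograd raise immediately, which is the more useful behavior when debugging a graph suspected of being disconnected. A toy illustration of the two modes (tensor names are illustrative):

    import torch

    x = torch.rand(4, 3, requires_grad=True)
    y = torch.rand(4, 1, requires_grad=True).sum()          # y does NOT depend on x

    print(torch.autograd.grad(y, x, allow_unused=True)[0])  # None: the silent failure mode

    try:
        torch.autograd.grad(y, x, allow_unused=False)
    except RuntimeError as e:
        print(f"fails fast: {e}")                           # explicit error instead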
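The train.py additions render the forward computation graph with torchviz at three points, unconditionally on every pass; make_dot plus render is expensive and also requires the graphviz binary to be installed, so these calls read as temporary debugging aids that should be guarded or removed before longer runs. A standalone sketch of the same pattern, with a toy linear model standing in for Net (model and file names here are illustrative):

    import torch
    from torchviz import make_dot

    model = torch.nn.Linear(3, 1)
    mnfld_pnts = torch.rand(16, 3, requires_grad=True)
    mnfld_pred = model(mnfld_pnts)

    dot = make_dot(mnfld_pred, params=dict(list(model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
    dot.render("forward_graph_demo", format="png")  # writes forward_graph_demo.png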