Browse Source

有形状了

final
mckay 1 month ago
parent
commit
2109c4a6f3
  1. 2
      brep2sdf/networks/decoder.py
  2. 31
      brep2sdf/networks/encoder.py
  3. 27
      brep2sdf/networks/feature_volume.py
  4. 13
      brep2sdf/networks/loss.py
  5. 3
      brep2sdf/networks/network.py
  6. 10
      brep2sdf/train.py

2
brep2sdf/networks/decoder.py

@ -65,7 +65,7 @@ class Decoder(nn.Module):
self.activation = nn.ReLU()
else:
#siren
self.activation = nn.SiLU()
self.activation = nn.ReLU()
self.final_activation = nn.Tanh()
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:

31
brep2sdf/networks/encoder.py

@ -2,9 +2,19 @@ import torch
import torch.nn as nn
from .octree import OctreeNode
from .feature_volume import PatchFeatureVolume
from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder
from brep2sdf.utils.logger import logger
from torchviz import make_dot
class Sine(nn.Module):
def __init(self):
super().__init__()
def forward(self, input):
# See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
return torch.sin(30 * input)
class Encoder(nn.Module):
def __init__(self, volume_bboxs:torch.tensor, feature_dim: int = 32):
"""
@ -21,6 +31,7 @@ class Encoder(nn.Module):
resolutions = self._batch_calculate_resolution(volume_bboxs)
# 初始化多个特征体积
'''
self.feature_volumes = nn.ModuleList([
PatchFeatureVolume(
bbox=bbox,
@ -28,17 +39,18 @@ class Encoder(nn.Module):
feature_dim=feature_dim
) for i, bbox in enumerate(volume_bboxs)
])
'''
self.feature_volumes = nn.ModuleList([
SimpleFeatureEncoder(
input_dim=3, feature_dim=feature_dim
) for i, bbox in enumerate(volume_bboxs)
])
self.background = self.simple_encoder = nn.Sequential(
nn.Linear(3, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, feature_dim)
)
print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
@ -91,7 +103,7 @@ class Encoder(nn.Module):
if mask.any():
# 获取对应volume的特征 (M, D)
features = volume.forward(query_points[mask])
all_features[mask, vol_id] = 0.7 * features + 0.3 * background_features[mask]
all_features[mask, vol_id] = 0.9 * background_features[mask] + 0.1 * features
return all_features
@ -120,7 +132,8 @@ class Encoder(nn.Module):
"""
# 获取 patch 特征
patch_features = self.feature_volumes[patch_id].forward(surf_points)
#dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters()))
#dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件
return patch_features
def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor:

27
brep2sdf/networks/feature_volume.py

@ -73,4 +73,29 @@ class PatchFeatureVolume(nn.Module):
features = self.feature_volume[indices[...,0], indices[...,1], indices[...,2]] # (B,8,D)
# 加权求和 (B,D)
return torch.einsum('bnd,bn->bd', features, weights)
return torch.einsum('bnd,bn->bd', features, weights)
class SimpleFeatureEncoder(nn.Module):
def __init__(self, input_dim=3, feature_dim=64):
super(SimpleFeatureEncoder, self).__init__()
# 定义一个多层感知机作为编码器
self.encoder = nn.Sequential(
nn.Linear(input_dim, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 256),
nn.ReLU(inplace=True),
nn.Linear(256, 512),
nn.ReLU(inplace=True),
nn.Linear(512, feature_dim)
)
def forward(self, query_points: torch.Tensor) -> torch.Tensor:
"""
Args:
query_points: 形状为 (B, 3) 的查询点坐标
Returns:
形状为 (B, feature_dim) 的特征向量
"""
return self.encoder(query_points)

13
brep2sdf/networks/loss.py

@ -67,7 +67,14 @@ class LossManager:
# NOTE 源代码 这里还有复杂逻辑
# 计算分支梯度
branch_grad = gradient(mnfld_pnts, pred_sdfs) # 计算分支梯度
'''
logger.info(f"branch_grad:{branch_grad}")
logger.info(f"mnfld_pnts:{mnfld_pnts}, shape:{mnfld_pnts.shape}")
logger.info(f"pred_sdfs:{pred_sdfs}")
logger.print_tensor_stats("mnfld_pnts",mnfld_pnts)
logger.print_tensor_stats("pred_sdfs",pred_sdfs)
logger.print_tensor_stats("mnfld_pnts[2]",mnfld_pnts[:,2])
'''
# 计算法线损失
normals_loss = (((branch_grad - normals).abs()).norm(2, dim=1)).mean() # 计算法线损失
@ -203,7 +210,7 @@ class LossManager:
manifold_loss = self.position_loss(mnfld_pred, gt_sdfs)
# 计算法线损失
normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
#normals_loss = self.normals_loss(normals, mnfld_pnts, mnfld_pred)
#logger.gpu_memory_stats("计算法线损失后")
@ -217,7 +224,7 @@ class LossManager:
# 汇总损失
loss_details = {
"manifold": self.weights["manifold"] * manifold_loss,
"normals": self.weights["normals"] * normals_loss
#"normals": self.weights["normals"] * normals_loss
}
# 计算总损失

3
brep2sdf/networks/network.py

@ -77,6 +77,7 @@ class Net(nn.Module):
self.decoder = Decoder(
d_in=feature_dim,
dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
#skip_in=(3,),
geometric_init=False,
beta=5
)
@ -216,7 +217,7 @@ def gradient(inputs, outputs):
create_graph=True,
retain_graph=True,
only_inputs=True,
allow_unused=True # 新增异常处理
allow_unused=False # 新增异常处理
)[0]
# 修正维度切片方式

10
brep2sdf/train.py

@ -4,7 +4,7 @@ import time
import os
import numpy as np
import argparse
from torchviz import make_dot
from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,prepare_sdf_data, print_data_distribution, check_tensor
@ -324,6 +324,8 @@ class Trainer:
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
dot.render("forward_graph1", format="png") # 这会保存计算图为png格式
return total_loss # 对于单批次训练,直接返回当前损失
@ -478,6 +480,9 @@ class Trainer:
subloss_names = ["manifold", "normals", "eikonal", "offsurface", "psdf"]
logger.info(" ".join([f"{name}: {weighted_avg[i].item():.6f}" for i, name in enumerate(subloss_names)]))
dot = make_dot((mnfld_pred, nonmnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts), ('nonmnfld_pnts', nonmnfld_pnts)]))
dot.render("forward_graph2", format="png") # 这会保存计算图为png格式
avg_loss = sum(losses) / len(losses)
logger.info(f"Total Loss: {total_loss:.6f} | Avg Loss: {avg_loss:.6f}")
@ -659,6 +664,8 @@ class Trainer:
_nonmnfld_face_indices_mask[start_idx:end_idx],
_nonmnfld_operator[start_idx:end_idx]
)
dot = make_dot((mnfld_pred), params=dict(list(self.model.named_parameters()) + [('mnfld_pnts', mnfld_pnts)]))
dot.render("forward_graph3", format="png") # 这会保存计算图为png格式
#logger.print_tensor_stats("psdf",psdf)
#logger.print_tensor_stats("nonmnfld_pnts",nonmnfld_pnts)
@ -731,6 +738,7 @@ class Trainer:
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss # 对于单批次训练,直接返回当前损失
def train_epoch(self, epoch: int,resample:bool=True) -> float:

Loading…
Cancel
Save