Browse Source

refactor: 模型相关函数拆分

main
mckay 7 months ago
parent
commit
5b98c59270
  1. 413
      brep2sdf/networks/decoder.py
  2. 291
      brep2sdf/networks/network.py
  3. 2
      brep2sdf/train.py

413
brep2sdf/networks/decoder.py

@ -1,383 +1,42 @@
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from dataclasses import dataclass import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Union
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm
from typing import Dict, Optional, Tuple, Union
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
class Decoder1D(nn.Module): class SDFHead(nn.Module):
def __init__( """SDF预测头"""
self, def __init__(self, embed_dim: int = 768*2):
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group", # group, spatial
):
'''
这是第一阶段的解码器用于处理B-rep特征
包含三个主要部分
conv_in: 输入卷积层处理初始特征
mid_block: 中间处理块
up_blocks: 上采样块序列
支持梯度检查点功能gradient checkpointing以节省内存
输出维度: [B, C, L]
# NOTE:
1. 移除了分片(slicing)和平铺(tiling)功能
2. 直接使用mode()而不是sample()获取潜在向量
3. 简化了编码过程只保留核心功能
4. 返回确定性的潜在向量而不是分布
'''
super().__init__() super().__init__()
self.layers_per_block = layers_per_block self.mlp = nn.Sequential(
nn.Linear(embed_dim, embed_dim//2),
self.conv_in = nn.Conv1d( nn.LayerNorm(embed_dim//2),
in_channels, nn.ReLU(),
block_out_channels[-1], nn.Linear(embed_dim//2, embed_dim//4),
kernel_size=3, nn.LayerNorm(embed_dim//4),
stride=1, nn.ReLU(),
padding=1, nn.Linear(embed_dim//4, 1),
) nn.Tanh()
)
self.mid_block = None
self.up_blocks = nn.ModuleList([]) def forward(self, x):
return self.mlp(x)
temb_channels = in_channels if norm_type == "spatial" else None
class SDFTransformer(nn.Module):
# mid """SDF Transformer编码器"""
self.mid_block = UNetMidBlock1D( def __init__(self, embed_dim: int = 768, num_layers: int = 6):
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpBlock1D(
in_channels=prev_output_channel,
out_channels=output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, z, latent_embeds=None):
sample = z
sample = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
)
# sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
# sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
# sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class SDFDecoder(nn.Module):
def __init__(
self,
latent_size,
dims,
dropout=None,
dropout_prob=0.0,
norm_layers=(),
latent_in=(),
weight_norm=False,
xyz_in_all=None,
use_tanh=False,
latent_dropout=False,
):
'''
这是第二阶段的解码器用于生成SDF值
使用多层MLP结构
特点
支持在不同层注入latent信息通过latent_in参数
可以在每层添加xyz坐标通过xyz_in_all参数
支持权重归一化和dropout
输入维度: [N, latent_size+3]
输出维度: [N, 1]
'''
super(SDFDecoder, self).__init__()
def make_sequence():
return []
dims = [latent_size + 3] + dims + [1]
self.num_layers = len(dims)
self.norm_layers = norm_layers
self.latent_in = latent_in
self.latent_dropout = latent_dropout
if self.latent_dropout:
self.lat_dp = nn.Dropout(0.2)
self.xyz_in_all = xyz_in_all
self.weight_norm = weight_norm
for layer in range(0, self.num_layers - 1):
if layer + 1 in latent_in:
out_dim = dims[layer + 1] - dims[0]
else:
out_dim = dims[layer + 1]
if self.xyz_in_all and layer != self.num_layers - 2:
out_dim -= 3
if weight_norm and layer in self.norm_layers:
setattr(
self,
"lin" + str(layer),
nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
)
else:
setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))
if (
(not weight_norm)
and self.norm_layers is not None
and layer in self.norm_layers
):
setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))
self.use_tanh = use_tanh
if use_tanh:
self.tanh = nn.Tanh()
self.relu = nn.ReLU()
self.dropout_prob = dropout_prob
self.dropout = dropout
self.th = nn.Tanh()
# input: N x (L+3)
def forward(self, input):
xyz = input[:, -3:]
if input.shape[1] > 3 and self.latent_dropout:
latent_vecs = input[:, :-3]
latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
x = torch.cat([latent_vecs, xyz], 1)
else:
x = input
for layer in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(layer))
if layer in self.latent_in:
x = torch.cat([x, input], 1)
elif layer != 0 and self.xyz_in_all:
x = torch.cat([x, xyz], 1)
x = lin(x)
# last layer Tanh
if layer == self.num_layers - 2 and self.use_tanh:
x = self.tanh(x)
if layer < self.num_layers - 2:
if (
self.norm_layers is not None
and layer in self.norm_layers
and not self.weight_norm
):
bn = getattr(self, "bn" + str(layer))
x = bn(x)
x = self.relu(x)
if self.dropout is not None and layer in self.dropout:
x = F.dropout(x, p=self.dropout_prob, training=self.training)
if hasattr(self, "th"):
x = self.th(x)
return x
class BRep2SdfDecoder(nn.Module):
def __init__(
self,
latent_size=256,
feature_dims=[512, 512, 256, 128], # 特征解码器维度
sdf_dims=[512, 512, 512, 512], # SDF解码器维度
up_block_types=("UpDecoderBlock2D",),
layers_per_block=2,
norm_num_groups=32,
norm_type="group",
dropout=None,
dropout_prob=0.0,
norm_layers=(),
latent_in=(),
weight_norm=False,
xyz_in_all=True,
use_tanh=True,
):
super().__init__() super().__init__()
layer = nn.TransformerEncoderLayer(
# 1. 特征解码器 (使用Decoder1D结构) d_model=embed_dim,
self.feature_decoder = Decoder1D( nhead=8,
in_channels=latent_size, dim_feedforward=1024,
out_channels=feature_dims[-1], dropout=0.1,
up_block_types=up_block_types, batch_first=True,
block_out_channels=feature_dims, norm_first=False # 修改这里:设置为False
layers_per_block=layers_per_block, )
norm_num_groups=norm_num_groups, self.transformer = nn.TransformerEncoder(layer, num_layers)
norm_type=norm_type
) def forward(self, x, mask=None):
return self.transformer(x, src_key_padding_mask=mask)
# 2. SDF解码器 (使用原始Decoder结构)
self.sdf_decoder = SDFDecoder(
latent_size=feature_dims[-1], # 使用特征解码器的输出维度
dims=sdf_dims,
dropout=dropout,
dropout_prob=dropout_prob,
norm_layers=norm_layers,
latent_in=latent_in,
weight_norm=weight_norm,
xyz_in_all=xyz_in_all,
use_tanh=use_tanh,
latent_dropout=False
)
# 3. 特征转换层 (将特征解码器的输出转换为SDF解码器需要的格式)
self.feature_transform = nn.Sequential(
nn.Linear(feature_dims[-1], feature_dims[-1]),
nn.LayerNorm(feature_dims[-1]),
nn.SiLU()
)
def forward(self, latent, query_points, latent_embeds=None):
"""
Args:
latent: [B, C, L] B-rep特征
query_points: [B, N, 3] 查询点
latent_embeds: 可选的条件嵌入
Returns:
sdf: [B, N, 1] SDF值
"""
# 1. 特征解码
features = self.feature_decoder(latent, latent_embeds) # [B, C, L]
# 2. 特征转换
B, C, L = features.shape
features = features.permute(0, 2, 1) # [B, L, C]
features = self.feature_transform(features) # [B, L, C]
# 3. 准备SDF解码器输入
_, N, _ = query_points.shape
features = features.unsqueeze(1).expand(-1, N, -1, -1) # [B, N, L, C]
query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1) # [B, N, L, 3]
# 4. 合并特征和坐标
sdf_input = torch.cat([
features.reshape(B*N*L, -1), # [B*N*L, C]
query_points.reshape(B*N*L, -1) # [B*N*L, 3]
], dim=-1)
# 5. SDF生成
sdf = self.sdf_decoder(sdf_input) # [B*N*L, 1]
sdf = sdf.reshape(B, N, L, 1) # [B, N, L, 1]
# 6. 聚合多尺度SDF
sdf = sdf.mean(dim=2) # [B, N, 1]
return sdf
# 使用示例
if __name__ == "__main__":
# 创建模型
decoder = BRepDecoder(
latent_size=256,
feature_dims=[512, 256, 128, 64],
sdf_dims=[512, 512, 512, 512],
up_block_types=("UpDecoderBlock2D",),
layers_per_block=2,
norm_num_groups=32,
dropout=None,
dropout_prob=0.0,
norm_layers=[0, 1, 2, 3],
latent_in=[4],
weight_norm=True,
xyz_in_all=True,
use_tanh=True
)
# 测试数据
batch_size = 4
seq_len = 32
num_points = 1000
latent = torch.randn(batch_size, 256, seq_len)
query_points = torch.randn(batch_size, num_points, 3)
latent_embeds = torch.randn(batch_size, 256)
# 前向传播
sdf = decoder(latent, query_points, latent_embeds)
print(f"Input latent shape: {latent.shape}")
print(f"Query points shape: {query_points.shape}")
print(f"Output SDF shape: {sdf.shape}")

291
brep2sdf/networks/network.py

@ -1,127 +1,202 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder import BRepEncoder import torch.nn.functional as F
from decoder import BRep2SdfDecoder from typing import Dict, Optional, Tuple, Union
from brep2sdf.config.default_config import get_default_config
class BRep2SDF(nn.Module): from brep2sdf.utils.logger import logger
def __init__(
self, from brep2sdf.networks.encoder import BRepFeatureEmbedder
# 编码器参数 from brep2sdf.networks.decoder import SDFHead, SDFTransformer
in_channels=3,
latent_size=256,
encoder_block_out_channels=(512, 256, 128, 64), class BRepToSDF(nn.Module):
# 解码器参数 def __init__(self, config=None):
decoder_feature_dims=(512, 256, 128, 64),
sdf_dims=(512, 512, 512, 512),
# 共享参数
layers_per_block=2,
norm_num_groups=32,
# SDF特定参数
dropout=None,
dropout_prob=0.0,
norm_layers=(0, 1, 2, 3),
latent_in=(4,),
weight_norm=True,
xyz_in_all=True,
use_tanh=True,
):
super().__init__() super().__init__()
# 获取配置
if config is None:
self.config = get_default_config()
else:
self.config = config
# 从配置中读取参数
self.embed_dim = self.config.model.embed_dim
self.brep_feature_dim = self.config.model.brep_feature_dim
self.latent_dim = self.config.model.latent_dim
self.use_cf = self.config.model.use_cf
# 1. 查询点编码器
self.query_encoder = nn.Sequential(
nn.Linear(3, self.embed_dim//4),
nn.LayerNorm(self.embed_dim//4),
nn.ReLU(),
nn.Linear(self.embed_dim//4, self.embed_dim//2),
nn.LayerNorm(self.embed_dim//2),
nn.ReLU(),
nn.Linear(self.embed_dim//2, self.embed_dim)
)
# 2. B-rep特征编码器
self.brep_embedder = BRepFeatureEmbedder()
# 3. 特征融合Transformer
self.transformer = SDFTransformer(
embed_dim=self.embed_dim,
num_layers=6 # 这个参数也可以移到配置文件中
)
# 4. SDF预测头
self.sdf_head = SDFHead(embed_dim=self.embed_dim*2)
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None):
"""B-rep到SDF的前向传播
Args:
edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3]
edge_pos: 边位置 [B, max_face, max_edge, 6]
edge_mask: 边掩码 [B, max_face, max_edge]
surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3]
surf_pos: 面位置 [B, max_face, 6]
vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3]
query_points: 查询点 [B, num_queries, 3]
data_class: (可选) 类别标签
Returns:
sdf: 预测的SDF值 [B, num_queries, 1]
"""
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries
try:
# 确保query_points需要梯度
if not query_points.requires_grad:
query_points = query_points.detach().requires_grad_(True)
# 1. 编码器配置
encoder_config = type('Config', (), {
'in_channels': in_channels,
'out_channels': latent_size,
'block_out_channels': encoder_block_out_channels,
'layers_per_block': layers_per_block,
'norm_num_groups': norm_num_groups,
'encoder_params': {
'in_channels': in_channels,
'out_channels': latent_size,
'block_out_channels': encoder_block_out_channels,
'layers_per_block': layers_per_block,
'norm_num_groups': norm_num_groups,
}
})()
# 2. 解码器配置
decoder_config = {
'latent_size': latent_size,
'feature_dims': decoder_feature_dims,
'sdf_dims': sdf_dims,
'layers_per_block': layers_per_block,
'norm_num_groups': norm_num_groups,
'dropout': dropout,
'dropout_prob': dropout_prob,
'norm_layers': norm_layers,
'latent_in': latent_in,
'weight_norm': weight_norm,
'xyz_in_all': xyz_in_all,
'use_tanh': use_tanh,
}
# 3. 创建编码器和解码器 # 1. B-rep特征编码
self.encoder = BRepEncoder(encoder_config) brep_features = self.brep_embedder(
self.decoder = BRep2SdfDecoder(**decoder_config) edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3]
edge_pos=edge_pos, # [B, max_face, max_edge, 6]
edge_mask=edge_mask, # [B, max_face, max_edge]
surf_ncs=surf_ncs, # [B, max_face, num_surf_points, 3]
surf_pos=surf_pos, # [B, max_face, 6]
vertex_pos=vertex_pos, # [B, max_face, max_edge, 2, 3]
data_class=data_class
) # [B, max_face*(max_edge+1), embed_dim]
def encode(self, brep_model): # 2. 查询点编码
"""编码B-rep模型为潜在特征""" query_features = self.query_encoder(query_points) # [B, Q, embed_dim]
return self.encoder.encode(brep_model)
def decode(self, latent, query_points, latent_embeds=None): # 3. 提取全局特征
"""从潜在特征解码SDF值""" global_features = brep_features.mean(dim=1) # [B, embed_dim]
return self.decoder(latent, query_points, latent_embeds)
# 4. 为每个查询点准备特征
expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim]
# 5. 连接查询点特征和全局特征
combined_features = torch.cat([
expanded_features, # [B, Q, embed_dim]
query_features # [B, Q, embed_dim]
], dim=-1) # [B, Q, embed_dim*2]
# 6. SDF预测
sdf = self.sdf_head(combined_features) # [B, Q, 1]
if not sdf.requires_grad:
logger.warning("SDF output does not require grad!")
def forward(self, brep_model, query_points):
"""完整的前向传播过程"""
# 1. 编码B-rep模型
latent = self.encode(brep_model)
if latent is None:
return None
# 2. 解码SDF值
sdf = self.decode(latent, query_points)
return sdf return sdf
# 使用示例 except Exception as e:
if __name__ == "__main__": logger.error(f"Error in BRepToSDF forward pass:")
# 创建模型 logger.error(f" Error message: {str(e)}")
model = BRep2SDF( logger.error(f" Input shapes:")
in_channels=3, logger.error(f" edge_ncs: {edge_ncs.shape}")
latent_size=256, logger.error(f" edge_pos: {edge_pos.shape}")
encoder_block_out_channels=(512, 256, 128, 64), logger.error(f" edge_mask: {edge_mask.shape}")
decoder_feature_dims=(512, 256, 128, 64), logger.error(f" surf_ncs: {surf_ncs.shape}")
sdf_dims=(512, 512, 512, 512), logger.error(f" surf_pos: {surf_pos.shape}")
layers_per_block=2, logger.error(f" vertex_pos: {vertex_pos.shape}")
norm_num_groups=32, logger.error(f" query_points: {query_points.shape}")
raise
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
"""SDF损失函数"""
# 确保points需要梯度
if not points.requires_grad:
points = points.detach().requires_grad_(True)
# L1损失
l1_loss = F.l1_loss(pred_sdf, gt_sdf)
try:
# 梯度约束损失
grad = torch.autograd.grad(
pred_sdf.sum(),
points,
create_graph=True,
retain_graph=True,
allow_unused=True
)[0]
if grad is not None:
grad_constraint = F.mse_loss(
torch.norm(grad, dim=-1),
torch.ones_like(pred_sdf.squeeze(-1))
) )
else:
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
except Exception as e:
logger.warning(f"Gradient computation failed: {str(e)}")
grad_constraint = torch.tensor(0.0, device=pred_sdf.device)
# 测试数据 return l1_loss + grad_weight * grad_constraint
batch_size = 4
seq_len = 32
num_points = 1000
# 模拟B-rep模型数据 def main():
class MockBRep: # 获取配置
def __init__(self): config = get_default_config()
self.faces = [MockFace() for _ in range(10)]
self.edges = [MockEdge() for _ in range(20)]
class MockFace: # 初始化模型
def __init__(self): model = BRepToSDF(config=config)
self.center_point = torch.randn(3)
self.normal_vector = torch.randn(3)
self.surface_type = 0
self.edges = []
class MockEdge: # 从配置获取参数
def __init__(self): batch_size = config.train.batch_size
self.length = lambda: 1.0 max_face = config.data.max_face
self.point_at = lambda t: torch.randn(3) max_edge = config.data.max_edge
num_surf_points = config.model.num_surf_points
num_edge_points = config.model.num_edge_points
# 生成测试数据
test_data = {
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3),
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6),
'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool),
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3),
'surf_pos': torch.randn(batch_size, max_face, 6),
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3),
'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点
}
brep_model = MockBRep() # 打印输入数据形状
query_points = torch.randn(batch_size, num_points, 3) logger.info("Input shapes:")
for name, tensor in test_data.items():
logger.info(f" {name}: {tensor.shape}")
# 前向传播 # 前向传播
sdf = model(brep_model, query_points) try:
if sdf is not None: sdf = model(**test_data)
print(f"Output SDF shape: {sdf.shape}") logger.info(f"\nOutput SDF shape: {sdf.shape}")
# 计算模型参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"\nModel statistics:")
logger.info(f" Total parameters: {total_params:,}")
logger.info(f" Trainable parameters: {trainable_params:,}")
except Exception as e:
logger.error(f"Error during forward pass: {str(e)}")
raise
if __name__ == "__main__":
main()

2
brep2sdf/train.py

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from brep2sdf.data.data import BRepSDFDataset from brep2sdf.data.data import BRepSDFDataset
from brep2sdf.networks.encoder import BRepToSDF, sdf_loss from brep2sdf.networks.network import BRepToSDF, sdf_loss
from brep2sdf.utils.logger import logger from brep2sdf.utils.logger import logger
from brep2sdf.config.default_config import get_default_config, load_config from brep2sdf.config.default_config import get_default_config, load_config
import wandb import wandb

Loading…
Cancel
Save