3 changed files with 217 additions and 483 deletions
@ -1,383 +1,42 @@ |
|||||
import math |
|
||||
import torch |
import torch |
||||
import torch.nn as nn |
import torch.nn as nn |
||||
from dataclasses import dataclass |
import torch.nn.functional as F |
||||
from typing import Dict, Optional, Tuple, Union |
|
||||
|
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
||||
from diffusers.utils import BaseOutput, is_torch_version |
|
||||
from diffusers.utils.accelerate_utils import apply_forward_hook |
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor |
|
||||
from diffusers.models.modeling_utils import ModelMixin |
|
||||
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder |
|
||||
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d |
|
||||
from diffusers.models.attention_processor import SpatialNorm |
|
||||
|
|
||||
|
|
||||
|
|
||||
|
from typing import Dict, Optional, Tuple, Union |
||||
|
from brep2sdf.config.default_config import get_default_config |
||||
|
from brep2sdf.utils.logger import logger |
||||
|
|
||||
class Decoder1D(nn.Module): |
class SDFHead(nn.Module): |
||||
def __init__( |
"""SDF预测头""" |
||||
self, |
def __init__(self, embed_dim: int = 768*2): |
||||
in_channels=3, |
|
||||
out_channels=3, |
|
||||
up_block_types=("UpDecoderBlock2D",), |
|
||||
block_out_channels=(64,), |
|
||||
layers_per_block=2, |
|
||||
norm_num_groups=32, |
|
||||
act_fn="silu", |
|
||||
norm_type="group", # group, spatial |
|
||||
): |
|
||||
''' |
|
||||
这是第一阶段的解码器,用于处理B-rep特征 |
|
||||
包含三个主要部分: |
|
||||
conv_in: 输入卷积层,处理初始特征 |
|
||||
mid_block: 中间处理块 |
|
||||
up_blocks: 上采样块序列 |
|
||||
支持梯度检查点功能(gradient checkpointing)以节省内存 |
|
||||
输出维度: [B, C, L] |
|
||||
# NOTE: |
|
||||
1. 移除了分片(slicing)和平铺(tiling)功能 |
|
||||
2. 直接使用mode()而不是sample()获取潜在向量 |
|
||||
3. 简化了编码过程,只保留核心功能 |
|
||||
4. 返回确定性的潜在向量而不是分布 |
|
||||
''' |
|
||||
super().__init__() |
super().__init__() |
||||
self.layers_per_block = layers_per_block |
self.mlp = nn.Sequential( |
||||
|
nn.Linear(embed_dim, embed_dim//2), |
||||
self.conv_in = nn.Conv1d( |
nn.LayerNorm(embed_dim//2), |
||||
in_channels, |
nn.ReLU(), |
||||
block_out_channels[-1], |
nn.Linear(embed_dim//2, embed_dim//4), |
||||
kernel_size=3, |
nn.LayerNorm(embed_dim//4), |
||||
stride=1, |
nn.ReLU(), |
||||
padding=1, |
nn.Linear(embed_dim//4, 1), |
||||
) |
nn.Tanh() |
||||
|
) |
||||
self.mid_block = None |
|
||||
self.up_blocks = nn.ModuleList([]) |
def forward(self, x): |
||||
|
return self.mlp(x) |
||||
temb_channels = in_channels if norm_type == "spatial" else None |
|
||||
|
class SDFTransformer(nn.Module): |
||||
# mid |
"""SDF Transformer编码器""" |
||||
self.mid_block = UNetMidBlock1D( |
def __init__(self, embed_dim: int = 768, num_layers: int = 6): |
||||
in_channels=block_out_channels[-1], |
|
||||
mid_channels=block_out_channels[-1], |
|
||||
) |
|
||||
|
|
||||
# up |
|
||||
reversed_block_out_channels = list(reversed(block_out_channels)) |
|
||||
output_channel = reversed_block_out_channels[0] |
|
||||
for i, up_block_type in enumerate(up_block_types): |
|
||||
prev_output_channel = output_channel |
|
||||
output_channel = reversed_block_out_channels[i] |
|
||||
|
|
||||
is_final_block = i == len(block_out_channels) - 1 |
|
||||
|
|
||||
up_block = UpBlock1D( |
|
||||
in_channels=prev_output_channel, |
|
||||
out_channels=output_channel, |
|
||||
) |
|
||||
self.up_blocks.append(up_block) |
|
||||
prev_output_channel = output_channel |
|
||||
|
|
||||
# out |
|
||||
if norm_type == "spatial": |
|
||||
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) |
|
||||
else: |
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) |
|
||||
self.conv_act = nn.SiLU() |
|
||||
self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1) |
|
||||
|
|
||||
self.gradient_checkpointing = False |
|
||||
|
|
||||
|
|
||||
def forward(self, z, latent_embeds=None): |
|
||||
sample = z |
|
||||
sample = self.conv_in(sample) |
|
||||
|
|
||||
if self.training and self.gradient_checkpointing: |
|
||||
|
|
||||
def create_custom_forward(module): |
|
||||
def custom_forward(*inputs): |
|
||||
return module(*inputs) |
|
||||
|
|
||||
return custom_forward |
|
||||
|
|
||||
if is_torch_version(">=", "1.11.0"): |
|
||||
# middle |
|
||||
sample = torch.utils.checkpoint.checkpoint( |
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False |
|
||||
) |
|
||||
# sample = sample.to(upscale_dtype) |
|
||||
|
|
||||
# up |
|
||||
for up_block in self.up_blocks: |
|
||||
sample = torch.utils.checkpoint.checkpoint( |
|
||||
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False |
|
||||
) |
|
||||
else: |
|
||||
# middle |
|
||||
sample = torch.utils.checkpoint.checkpoint( |
|
||||
create_custom_forward(self.mid_block), sample, latent_embeds |
|
||||
) |
|
||||
# sample = sample.to(upscale_dtype) |
|
||||
|
|
||||
# up |
|
||||
for up_block in self.up_blocks: |
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) |
|
||||
else: |
|
||||
# middle |
|
||||
sample = self.mid_block(sample, latent_embeds) |
|
||||
# sample = sample.to(upscale_dtype) |
|
||||
# up |
|
||||
for up_block in self.up_blocks: |
|
||||
sample = up_block(sample, latent_embeds) |
|
||||
|
|
||||
# post-process |
|
||||
if latent_embeds is None: |
|
||||
sample = self.conv_norm_out(sample) |
|
||||
else: |
|
||||
sample = self.conv_norm_out(sample, latent_embeds) |
|
||||
sample = self.conv_act(sample) |
|
||||
sample = self.conv_out(sample) |
|
||||
|
|
||||
return sample |
|
||||
|
|
||||
|
|
||||
|
|
||||
class SDFDecoder(nn.Module): |
|
||||
def __init__( |
|
||||
self, |
|
||||
latent_size, |
|
||||
dims, |
|
||||
dropout=None, |
|
||||
dropout_prob=0.0, |
|
||||
norm_layers=(), |
|
||||
latent_in=(), |
|
||||
weight_norm=False, |
|
||||
xyz_in_all=None, |
|
||||
use_tanh=False, |
|
||||
latent_dropout=False, |
|
||||
): |
|
||||
''' |
|
||||
这是第二阶段的解码器,用于生成SDF值 |
|
||||
使用多层MLP结构 |
|
||||
特点: |
|
||||
支持在不同层注入latent信息(通过latent_in参数) |
|
||||
可以在每层添加xyz坐标(通过xyz_in_all参数) |
|
||||
支持权重归一化和dropout |
|
||||
输入维度: [N, latent_size+3] |
|
||||
输出维度: [N, 1] |
|
||||
|
|
||||
''' |
|
||||
super(SDFDecoder, self).__init__() |
|
||||
|
|
||||
def make_sequence(): |
|
||||
return [] |
|
||||
|
|
||||
dims = [latent_size + 3] + dims + [1] |
|
||||
|
|
||||
self.num_layers = len(dims) |
|
||||
self.norm_layers = norm_layers |
|
||||
self.latent_in = latent_in |
|
||||
self.latent_dropout = latent_dropout |
|
||||
if self.latent_dropout: |
|
||||
self.lat_dp = nn.Dropout(0.2) |
|
||||
|
|
||||
self.xyz_in_all = xyz_in_all |
|
||||
self.weight_norm = weight_norm |
|
||||
|
|
||||
for layer in range(0, self.num_layers - 1): |
|
||||
if layer + 1 in latent_in: |
|
||||
out_dim = dims[layer + 1] - dims[0] |
|
||||
else: |
|
||||
out_dim = dims[layer + 1] |
|
||||
if self.xyz_in_all and layer != self.num_layers - 2: |
|
||||
out_dim -= 3 |
|
||||
|
|
||||
if weight_norm and layer in self.norm_layers: |
|
||||
setattr( |
|
||||
self, |
|
||||
"lin" + str(layer), |
|
||||
nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)), |
|
||||
) |
|
||||
else: |
|
||||
setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim)) |
|
||||
|
|
||||
if ( |
|
||||
(not weight_norm) |
|
||||
and self.norm_layers is not None |
|
||||
and layer in self.norm_layers |
|
||||
): |
|
||||
setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim)) |
|
||||
|
|
||||
self.use_tanh = use_tanh |
|
||||
if use_tanh: |
|
||||
self.tanh = nn.Tanh() |
|
||||
self.relu = nn.ReLU() |
|
||||
|
|
||||
self.dropout_prob = dropout_prob |
|
||||
self.dropout = dropout |
|
||||
self.th = nn.Tanh() |
|
||||
|
|
||||
# input: N x (L+3) |
|
||||
def forward(self, input): |
|
||||
xyz = input[:, -3:] |
|
||||
|
|
||||
if input.shape[1] > 3 and self.latent_dropout: |
|
||||
latent_vecs = input[:, :-3] |
|
||||
latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training) |
|
||||
x = torch.cat([latent_vecs, xyz], 1) |
|
||||
else: |
|
||||
x = input |
|
||||
|
|
||||
for layer in range(0, self.num_layers - 1): |
|
||||
lin = getattr(self, "lin" + str(layer)) |
|
||||
if layer in self.latent_in: |
|
||||
x = torch.cat([x, input], 1) |
|
||||
elif layer != 0 and self.xyz_in_all: |
|
||||
x = torch.cat([x, xyz], 1) |
|
||||
x = lin(x) |
|
||||
# last layer Tanh |
|
||||
if layer == self.num_layers - 2 and self.use_tanh: |
|
||||
x = self.tanh(x) |
|
||||
if layer < self.num_layers - 2: |
|
||||
if ( |
|
||||
self.norm_layers is not None |
|
||||
and layer in self.norm_layers |
|
||||
and not self.weight_norm |
|
||||
): |
|
||||
bn = getattr(self, "bn" + str(layer)) |
|
||||
x = bn(x) |
|
||||
x = self.relu(x) |
|
||||
if self.dropout is not None and layer in self.dropout: |
|
||||
x = F.dropout(x, p=self.dropout_prob, training=self.training) |
|
||||
|
|
||||
if hasattr(self, "th"): |
|
||||
x = self.th(x) |
|
||||
|
|
||||
return x |
|
||||
|
|
||||
|
|
||||
class BRep2SdfDecoder(nn.Module): |
|
||||
def __init__( |
|
||||
self, |
|
||||
latent_size=256, |
|
||||
feature_dims=[512, 512, 256, 128], # 特征解码器维度 |
|
||||
sdf_dims=[512, 512, 512, 512], # SDF解码器维度 |
|
||||
up_block_types=("UpDecoderBlock2D",), |
|
||||
layers_per_block=2, |
|
||||
norm_num_groups=32, |
|
||||
norm_type="group", |
|
||||
dropout=None, |
|
||||
dropout_prob=0.0, |
|
||||
norm_layers=(), |
|
||||
latent_in=(), |
|
||||
weight_norm=False, |
|
||||
xyz_in_all=True, |
|
||||
use_tanh=True, |
|
||||
): |
|
||||
super().__init__() |
super().__init__() |
||||
|
layer = nn.TransformerEncoderLayer( |
||||
# 1. 特征解码器 (使用Decoder1D结构) |
d_model=embed_dim, |
||||
self.feature_decoder = Decoder1D( |
nhead=8, |
||||
in_channels=latent_size, |
dim_feedforward=1024, |
||||
out_channels=feature_dims[-1], |
dropout=0.1, |
||||
up_block_types=up_block_types, |
batch_first=True, |
||||
block_out_channels=feature_dims, |
norm_first=False # 修改这里:设置为False |
||||
layers_per_block=layers_per_block, |
) |
||||
norm_num_groups=norm_num_groups, |
self.transformer = nn.TransformerEncoder(layer, num_layers) |
||||
norm_type=norm_type |
|
||||
) |
def forward(self, x, mask=None): |
||||
|
return self.transformer(x, src_key_padding_mask=mask) |
||||
# 2. SDF解码器 (使用原始Decoder结构) |
|
||||
self.sdf_decoder = SDFDecoder( |
|
||||
latent_size=feature_dims[-1], # 使用特征解码器的输出维度 |
|
||||
dims=sdf_dims, |
|
||||
dropout=dropout, |
|
||||
dropout_prob=dropout_prob, |
|
||||
norm_layers=norm_layers, |
|
||||
latent_in=latent_in, |
|
||||
weight_norm=weight_norm, |
|
||||
xyz_in_all=xyz_in_all, |
|
||||
use_tanh=use_tanh, |
|
||||
latent_dropout=False |
|
||||
) |
|
||||
|
|
||||
# 3. 特征转换层 (将特征解码器的输出转换为SDF解码器需要的格式) |
|
||||
self.feature_transform = nn.Sequential( |
|
||||
nn.Linear(feature_dims[-1], feature_dims[-1]), |
|
||||
nn.LayerNorm(feature_dims[-1]), |
|
||||
nn.SiLU() |
|
||||
) |
|
||||
|
|
||||
def forward(self, latent, query_points, latent_embeds=None): |
|
||||
""" |
|
||||
Args: |
|
||||
latent: [B, C, L] B-rep特征 |
|
||||
query_points: [B, N, 3] 查询点 |
|
||||
latent_embeds: 可选的条件嵌入 |
|
||||
Returns: |
|
||||
sdf: [B, N, 1] SDF值 |
|
||||
""" |
|
||||
# 1. 特征解码 |
|
||||
features = self.feature_decoder(latent, latent_embeds) # [B, C, L] |
|
||||
|
|
||||
# 2. 特征转换 |
|
||||
B, C, L = features.shape |
|
||||
features = features.permute(0, 2, 1) # [B, L, C] |
|
||||
features = self.feature_transform(features) # [B, L, C] |
|
||||
|
|
||||
# 3. 准备SDF解码器输入 |
|
||||
_, N, _ = query_points.shape |
|
||||
features = features.unsqueeze(1).expand(-1, N, -1, -1) # [B, N, L, C] |
|
||||
query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1) # [B, N, L, 3] |
|
||||
|
|
||||
# 4. 合并特征和坐标 |
|
||||
sdf_input = torch.cat([ |
|
||||
features.reshape(B*N*L, -1), # [B*N*L, C] |
|
||||
query_points.reshape(B*N*L, -1) # [B*N*L, 3] |
|
||||
], dim=-1) |
|
||||
|
|
||||
# 5. SDF生成 |
|
||||
sdf = self.sdf_decoder(sdf_input) # [B*N*L, 1] |
|
||||
sdf = sdf.reshape(B, N, L, 1) # [B, N, L, 1] |
|
||||
|
|
||||
# 6. 聚合多尺度SDF |
|
||||
sdf = sdf.mean(dim=2) # [B, N, 1] |
|
||||
|
|
||||
return sdf |
|
||||
|
|
||||
# 使用示例 |
|
||||
if __name__ == "__main__": |
|
||||
# 创建模型 |
|
||||
decoder = BRepDecoder( |
|
||||
latent_size=256, |
|
||||
feature_dims=[512, 256, 128, 64], |
|
||||
sdf_dims=[512, 512, 512, 512], |
|
||||
up_block_types=("UpDecoderBlock2D",), |
|
||||
layers_per_block=2, |
|
||||
norm_num_groups=32, |
|
||||
dropout=None, |
|
||||
dropout_prob=0.0, |
|
||||
norm_layers=[0, 1, 2, 3], |
|
||||
latent_in=[4], |
|
||||
weight_norm=True, |
|
||||
xyz_in_all=True, |
|
||||
use_tanh=True |
|
||||
) |
|
||||
|
|
||||
# 测试数据 |
|
||||
batch_size = 4 |
|
||||
seq_len = 32 |
|
||||
num_points = 1000 |
|
||||
|
|
||||
latent = torch.randn(batch_size, 256, seq_len) |
|
||||
query_points = torch.randn(batch_size, num_points, 3) |
|
||||
latent_embeds = torch.randn(batch_size, 256) |
|
||||
|
|
||||
# 前向传播 |
|
||||
sdf = decoder(latent, query_points, latent_embeds) |
|
||||
print(f"Input latent shape: {latent.shape}") |
|
||||
print(f"Query points shape: {query_points.shape}") |
|
||||
print(f"Output SDF shape: {sdf.shape}") |
|
||||
|
@ -1,127 +1,202 @@ |
|||||
import torch |
import torch |
||||
import torch.nn as nn |
import torch.nn as nn |
||||
from encoder import BRepEncoder |
import torch.nn.functional as F |
||||
from decoder import BRep2SdfDecoder |
from typing import Dict, Optional, Tuple, Union |
||||
|
from brep2sdf.config.default_config import get_default_config |
||||
class BRep2SDF(nn.Module): |
from brep2sdf.utils.logger import logger |
||||
def __init__( |
|
||||
self, |
from brep2sdf.networks.encoder import BRepFeatureEmbedder |
||||
# 编码器参数 |
from brep2sdf.networks.decoder import SDFHead, SDFTransformer |
||||
in_channels=3, |
|
||||
latent_size=256, |
|
||||
encoder_block_out_channels=(512, 256, 128, 64), |
class BRepToSDF(nn.Module): |
||||
# 解码器参数 |
def __init__(self, config=None): |
||||
decoder_feature_dims=(512, 256, 128, 64), |
|
||||
sdf_dims=(512, 512, 512, 512), |
|
||||
# 共享参数 |
|
||||
layers_per_block=2, |
|
||||
norm_num_groups=32, |
|
||||
# SDF特定参数 |
|
||||
dropout=None, |
|
||||
dropout_prob=0.0, |
|
||||
norm_layers=(0, 1, 2, 3), |
|
||||
latent_in=(4,), |
|
||||
weight_norm=True, |
|
||||
xyz_in_all=True, |
|
||||
use_tanh=True, |
|
||||
): |
|
||||
super().__init__() |
super().__init__() |
||||
|
# 获取配置 |
||||
|
if config is None: |
||||
|
self.config = get_default_config() |
||||
|
else: |
||||
|
self.config = config |
||||
|
|
||||
|
# 从配置中读取参数 |
||||
|
self.embed_dim = self.config.model.embed_dim |
||||
|
self.brep_feature_dim = self.config.model.brep_feature_dim |
||||
|
self.latent_dim = self.config.model.latent_dim |
||||
|
self.use_cf = self.config.model.use_cf |
||||
|
|
||||
|
# 1. 查询点编码器 |
||||
|
self.query_encoder = nn.Sequential( |
||||
|
nn.Linear(3, self.embed_dim//4), |
||||
|
nn.LayerNorm(self.embed_dim//4), |
||||
|
nn.ReLU(), |
||||
|
nn.Linear(self.embed_dim//4, self.embed_dim//2), |
||||
|
nn.LayerNorm(self.embed_dim//2), |
||||
|
nn.ReLU(), |
||||
|
nn.Linear(self.embed_dim//2, self.embed_dim) |
||||
|
) |
||||
|
|
||||
|
# 2. B-rep特征编码器 |
||||
|
self.brep_embedder = BRepFeatureEmbedder() |
||||
|
|
||||
|
# 3. 特征融合Transformer |
||||
|
self.transformer = SDFTransformer( |
||||
|
embed_dim=self.embed_dim, |
||||
|
num_layers=6 # 这个参数也可以移到配置文件中 |
||||
|
) |
||||
|
|
||||
|
# 4. SDF预测头 |
||||
|
self.sdf_head = SDFHead(embed_dim=self.embed_dim*2) |
||||
|
|
||||
|
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): |
||||
|
"""B-rep到SDF的前向传播 |
||||
|
|
||||
|
Args: |
||||
|
edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] |
||||
|
edge_pos: 边位置 [B, max_face, max_edge, 6] |
||||
|
edge_mask: 边掩码 [B, max_face, max_edge] |
||||
|
surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] |
||||
|
surf_pos: 面位置 [B, max_face, 6] |
||||
|
vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] |
||||
|
query_points: 查询点 [B, num_queries, 3] |
||||
|
data_class: (可选) 类别标签 |
||||
|
|
||||
|
Returns: |
||||
|
sdf: 预测的SDF值 [B, num_queries, 1] |
||||
|
""" |
||||
|
B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries |
||||
|
|
||||
|
try: |
||||
|
# 确保query_points需要梯度 |
||||
|
if not query_points.requires_grad: |
||||
|
query_points = query_points.detach().requires_grad_(True) |
||||
|
|
||||
# 1. 编码器配置 |
|
||||
encoder_config = type('Config', (), { |
|
||||
'in_channels': in_channels, |
|
||||
'out_channels': latent_size, |
|
||||
'block_out_channels': encoder_block_out_channels, |
|
||||
'layers_per_block': layers_per_block, |
|
||||
'norm_num_groups': norm_num_groups, |
|
||||
'encoder_params': { |
|
||||
'in_channels': in_channels, |
|
||||
'out_channels': latent_size, |
|
||||
'block_out_channels': encoder_block_out_channels, |
|
||||
'layers_per_block': layers_per_block, |
|
||||
'norm_num_groups': norm_num_groups, |
|
||||
} |
|
||||
})() |
|
||||
|
|
||||
# 2. 解码器配置 |
|
||||
decoder_config = { |
|
||||
'latent_size': latent_size, |
|
||||
'feature_dims': decoder_feature_dims, |
|
||||
'sdf_dims': sdf_dims, |
|
||||
'layers_per_block': layers_per_block, |
|
||||
'norm_num_groups': norm_num_groups, |
|
||||
'dropout': dropout, |
|
||||
'dropout_prob': dropout_prob, |
|
||||
'norm_layers': norm_layers, |
|
||||
'latent_in': latent_in, |
|
||||
'weight_norm': weight_norm, |
|
||||
'xyz_in_all': xyz_in_all, |
|
||||
'use_tanh': use_tanh, |
|
||||
} |
|
||||
|
|
||||
# 3. 创建编码器和解码器 |
# 1. B-rep特征编码 |
||||
self.encoder = BRepEncoder(encoder_config) |
brep_features = self.brep_embedder( |
||||
self.decoder = BRep2SdfDecoder(**decoder_config) |
edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3] |
||||
|
edge_pos=edge_pos, # [B, max_face, max_edge, 6] |
||||
|
edge_mask=edge_mask, # [B, max_face, max_edge] |
||||
|
surf_ncs=surf_ncs, # [B, max_face, num_surf_points, 3] |
||||
|
surf_pos=surf_pos, # [B, max_face, 6] |
||||
|
vertex_pos=vertex_pos, # [B, max_face, max_edge, 2, 3] |
||||
|
data_class=data_class |
||||
|
) # [B, max_face*(max_edge+1), embed_dim] |
||||
|
|
||||
def encode(self, brep_model): |
# 2. 查询点编码 |
||||
"""编码B-rep模型为潜在特征""" |
query_features = self.query_encoder(query_points) # [B, Q, embed_dim] |
||||
return self.encoder.encode(brep_model) |
|
||||
|
|
||||
def decode(self, latent, query_points, latent_embeds=None): |
# 3. 提取全局特征 |
||||
"""从潜在特征解码SDF值""" |
global_features = brep_features.mean(dim=1) # [B, embed_dim] |
||||
return self.decoder(latent, query_points, latent_embeds) |
|
||||
|
# 4. 为每个查询点准备特征 |
||||
|
expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim] |
||||
|
|
||||
|
# 5. 连接查询点特征和全局特征 |
||||
|
combined_features = torch.cat([ |
||||
|
expanded_features, # [B, Q, embed_dim] |
||||
|
query_features # [B, Q, embed_dim] |
||||
|
], dim=-1) # [B, Q, embed_dim*2] |
||||
|
|
||||
|
# 6. SDF预测 |
||||
|
sdf = self.sdf_head(combined_features) # [B, Q, 1] |
||||
|
|
||||
|
if not sdf.requires_grad: |
||||
|
logger.warning("SDF output does not require grad!") |
||||
|
|
||||
def forward(self, brep_model, query_points): |
|
||||
"""完整的前向传播过程""" |
|
||||
# 1. 编码B-rep模型 |
|
||||
latent = self.encode(brep_model) |
|
||||
if latent is None: |
|
||||
return None |
|
||||
|
|
||||
# 2. 解码SDF值 |
|
||||
sdf = self.decode(latent, query_points) |
|
||||
return sdf |
return sdf |
||||
|
|
||||
# 使用示例 |
except Exception as e: |
||||
if __name__ == "__main__": |
logger.error(f"Error in BRepToSDF forward pass:") |
||||
# 创建模型 |
logger.error(f" Error message: {str(e)}") |
||||
model = BRep2SDF( |
logger.error(f" Input shapes:") |
||||
in_channels=3, |
logger.error(f" edge_ncs: {edge_ncs.shape}") |
||||
latent_size=256, |
logger.error(f" edge_pos: {edge_pos.shape}") |
||||
encoder_block_out_channels=(512, 256, 128, 64), |
logger.error(f" edge_mask: {edge_mask.shape}") |
||||
decoder_feature_dims=(512, 256, 128, 64), |
logger.error(f" surf_ncs: {surf_ncs.shape}") |
||||
sdf_dims=(512, 512, 512, 512), |
logger.error(f" surf_pos: {surf_pos.shape}") |
||||
layers_per_block=2, |
logger.error(f" vertex_pos: {vertex_pos.shape}") |
||||
norm_num_groups=32, |
logger.error(f" query_points: {query_points.shape}") |
||||
|
raise |
||||
|
|
||||
|
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): |
||||
|
"""SDF损失函数""" |
||||
|
# 确保points需要梯度 |
||||
|
if not points.requires_grad: |
||||
|
points = points.detach().requires_grad_(True) |
||||
|
|
||||
|
# L1损失 |
||||
|
l1_loss = F.l1_loss(pred_sdf, gt_sdf) |
||||
|
|
||||
|
try: |
||||
|
# 梯度约束损失 |
||||
|
grad = torch.autograd.grad( |
||||
|
pred_sdf.sum(), |
||||
|
points, |
||||
|
create_graph=True, |
||||
|
retain_graph=True, |
||||
|
allow_unused=True |
||||
|
)[0] |
||||
|
|
||||
|
if grad is not None: |
||||
|
grad_constraint = F.mse_loss( |
||||
|
torch.norm(grad, dim=-1), |
||||
|
torch.ones_like(pred_sdf.squeeze(-1)) |
||||
) |
) |
||||
|
else: |
||||
|
grad_constraint = torch.tensor(0.0, device=pred_sdf.device) |
||||
|
|
||||
|
except Exception as e: |
||||
|
logger.warning(f"Gradient computation failed: {str(e)}") |
||||
|
grad_constraint = torch.tensor(0.0, device=pred_sdf.device) |
||||
|
|
||||
# 测试数据 |
return l1_loss + grad_weight * grad_constraint |
||||
batch_size = 4 |
|
||||
seq_len = 32 |
|
||||
num_points = 1000 |
|
||||
|
|
||||
# 模拟B-rep模型数据 |
def main(): |
||||
class MockBRep: |
# 获取配置 |
||||
def __init__(self): |
config = get_default_config() |
||||
self.faces = [MockFace() for _ in range(10)] |
|
||||
self.edges = [MockEdge() for _ in range(20)] |
|
||||
|
|
||||
class MockFace: |
# 初始化模型 |
||||
def __init__(self): |
model = BRepToSDF(config=config) |
||||
self.center_point = torch.randn(3) |
|
||||
self.normal_vector = torch.randn(3) |
|
||||
self.surface_type = 0 |
|
||||
self.edges = [] |
|
||||
|
|
||||
class MockEdge: |
# 从配置获取参数 |
||||
def __init__(self): |
batch_size = config.train.batch_size |
||||
self.length = lambda: 1.0 |
max_face = config.data.max_face |
||||
self.point_at = lambda t: torch.randn(3) |
max_edge = config.data.max_edge |
||||
|
num_surf_points = config.model.num_surf_points |
||||
|
num_edge_points = config.model.num_edge_points |
||||
|
|
||||
|
# 生成测试数据 |
||||
|
test_data = { |
||||
|
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), |
||||
|
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), |
||||
|
'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), |
||||
|
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), |
||||
|
'surf_pos': torch.randn(batch_size, max_face, 6), |
||||
|
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), |
||||
|
'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点 |
||||
|
} |
||||
|
|
||||
brep_model = MockBRep() |
# 打印输入数据形状 |
||||
query_points = torch.randn(batch_size, num_points, 3) |
logger.info("Input shapes:") |
||||
|
for name, tensor in test_data.items(): |
||||
|
logger.info(f" {name}: {tensor.shape}") |
||||
|
|
||||
# 前向传播 |
# 前向传播 |
||||
sdf = model(brep_model, query_points) |
try: |
||||
if sdf is not None: |
sdf = model(**test_data) |
||||
print(f"Output SDF shape: {sdf.shape}") |
logger.info(f"\nOutput SDF shape: {sdf.shape}") |
||||
|
|
||||
|
# 计算模型参数量 |
||||
|
total_params = sum(p.numel() for p in model.parameters()) |
||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
||||
|
logger.info(f"\nModel statistics:") |
||||
|
logger.info(f" Total parameters: {total_params:,}") |
||||
|
logger.info(f" Trainable parameters: {trainable_params:,}") |
||||
|
|
||||
|
except Exception as e: |
||||
|
logger.error(f"Error during forward pass: {str(e)}") |
||||
|
raise |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
main() |
Loading…
Reference in new issue