import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.utils import is_torch_version
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, UNetMidBlock1D, Upsample1d

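# NOTE: `UpBlock1D` is referenced by Decoder1D below but was neither imported
# nor defined in this file, and the diffusers `UpBlock1D` expects UNet-style
# skip connections that this decoder never provides. The variant sketched here
# is therefore an assumption about the intended block: a minimal skip-free up
# block built from the ResConvBlock / Upsample1d primitives this file already
# pulls from diffusers.
class UpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels
        self.resnets = nn.ModuleList(
            [
                ResConvBlock(in_channels, mid_channels, mid_channels),
                ResConvBlock(mid_channels, mid_channels, mid_channels),
                ResConvBlock(mid_channels, mid_channels, out_channels),
            ]
        )
        # cubic-kernel transposed convolution that doubles the sequence length
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, temb=None):
        # temb is accepted for call-site compatibility but unused
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        return self.up(hidden_states)
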
class Decoder1D(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        '''
        First-stage decoder for B-rep features.

        Three main parts:
            conv_in: input convolution for the initial features
            mid_block: middle processing block
            up_blocks: sequence of upsampling blocks
        Supports gradient checkpointing to save memory.
        Output shape: [B, C, L]

        # NOTE:
        1. Slicing and tiling have been removed.
        2. mode() is used directly instead of sample() to obtain the latent vector.
        3. The encoding path is simplified down to its core functionality.
        4. A deterministic latent vector is returned rather than a distribution.
        '''
        super().__init__()
        self.layers_per_block = layers_per_block
        self.conv_in = nn.Conv1d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.up_blocks = nn.ModuleList([])
        temb_channels = in_channels if norm_type == "spatial" else None
        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )
        # up: one UpBlock1D per entry in block_out_channels. up_block_types is
        # kept for API compatibility but UpBlock1D is used unconditionally;
        # iterating over the channel list keeps the final channel count in
        # sync with conv_norm_out below.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            up_block = UpBlock1D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
            )
            self.up_blocks.append(up_block)
        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1)
        self.gradient_checkpointing = False
    def forward(self, z, latent_embeds=None):
        sample = z
        sample = self.conv_in(sample)
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
                )
                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)
        # post-process; only SpatialNorm accepts the conditioning embedding,
        # so latent_embeds must be None when norm_type="group"
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample
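
# A minimal shape check for Decoder1D (hypothetical sizes, not part of the
# original file): the mid block preserves the sequence length and every
# UpBlock1D doubles it, so K entries in block_out_channels give an output of
# length L * 2**K.
def _decoder1d_shape_demo():
    dec = Decoder1D(in_channels=16, out_channels=8, block_out_channels=(64, 128))
    z = torch.randn(2, 16, 12)
    out = dec(z)  # two up blocks double L=12 twice
    assert out.shape == (2, 8, 48)
    return out
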
class SDFDecoder(nn.Module):
    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        '''
        Second-stage decoder that produces SDF values from a multi-layer MLP.

        Features:
            latent information can be re-injected at selected layers (latent_in)
            the xyz coordinates can be appended at every layer (xyz_in_all)
            optional weight normalization and dropout
        Input shape:  [N, latent_size + 3]
        Output shape: [N, 1]
        '''
        super().__init__()
        dims = [latent_size + 3] + dims + [1]
        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)
        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm
        for layer in range(0, self.num_layers - 1):
            if layer + 1 in latent_in:
                # leave room for the latent+xyz re-injection at the next layer
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != self.num_layers - 2:
                    # leave room for the xyz coordinates appended at the next layer
                    out_dim -= 3
            if weight_norm and layer in self.norm_layers:
                setattr(
                    self,
                    "lin" + str(layer),
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
                )
            else:
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))
            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))
        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.dropout_prob = dropout_prob
        self.dropout = dropout
        # final tanh keeps the predicted SDF in [-1, 1]
        self.th = nn.Tanh()

    # input: N x (L+3)
    def forward(self, input):
        xyz = input[:, -3:]
        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, :-3]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input
        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))
            if layer in self.latent_in:
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)
            # last layer Tanh
            if layer == self.num_layers - 2 and self.use_tanh:
                x = self.tanh(x)
            if layer < self.num_layers - 2:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    bn = getattr(self, "bn" + str(layer))
                    x = bn(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)
        if hasattr(self, "th"):
            x = self.th(x)
        return x
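
# A minimal usage sketch for SDFDecoder (hypothetical sizes, not part of the
# original file): each input row packs a latent code followed by the xyz
# query coordinate, i.e. [N, latent_size + 3], and each output row is one
# SDF value squashed into [-1, 1] by the final tanh.
def _sdf_decoder_demo():
    dec = SDFDecoder(latent_size=64, dims=[128, 128], norm_layers=(0, 1))
    z = torch.randn(5, 64)   # one latent code per query
    xyz = torch.randn(5, 3)  # query coordinates go in the last three columns
    sdf = dec(torch.cat([z, xyz], dim=1))
    assert sdf.shape == (5, 1)
    return sdf
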
class BRep2SdfDecoder(nn.Module):
    def __init__(
        self,
        latent_size=256,
        feature_dims=[512, 512, 256, 128],  # feature-decoder channel widths
        sdf_dims=[512, 512, 512, 512],      # SDF-decoder layer widths
        up_block_types=("UpDecoderBlock2D",),
        layers_per_block=2,
        norm_num_groups=32,
        norm_type="group",
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=True,
        use_tanh=True,
    ):
        super().__init__()
        # 1. feature decoder (Decoder1D backbone)
        self.feature_decoder = Decoder1D(
            in_channels=latent_size,
            out_channels=feature_dims[-1],
            up_block_types=up_block_types,
            block_out_channels=feature_dims,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
        )
        # 2. SDF decoder (the original DeepSDF-style decoder structure)
        self.sdf_decoder = SDFDecoder(
            latent_size=feature_dims[-1],  # matches the feature decoder's output width
            dims=sdf_dims,
            dropout=dropout,
            dropout_prob=dropout_prob,
            norm_layers=norm_layers,
            latent_in=latent_in,
            weight_norm=weight_norm,
            xyz_in_all=xyz_in_all,
            use_tanh=use_tanh,
            latent_dropout=False,
        )
        # 3. feature transform (maps the feature-decoder output into the format the SDF decoder expects)
        self.feature_transform = nn.Sequential(
            nn.Linear(feature_dims[-1], feature_dims[-1]),
            nn.LayerNorm(feature_dims[-1]),
            nn.SiLU(),
        )
    def forward(self, latent, query_points, latent_embeds=None):
        """
        Args:
            latent: [B, C, L] B-rep features
            query_points: [B, N, 3] query points
            latent_embeds: optional conditioning embedding (only valid with norm_type="spatial")
        Returns:
            sdf: [B, N, 1] SDF values
        """
        # 1. feature decoding
        features = self.feature_decoder(latent, latent_embeds)  # [B, C, L]
        # 2. feature transform
        B, C, L = features.shape
        features = features.permute(0, 2, 1)  # [B, L, C]
        features = self.feature_transform(features)  # [B, L, C]
        # 3. prepare the SDF-decoder input: pair every query point with every feature position
        _, N, _ = query_points.shape
        features = features.unsqueeze(1).expand(-1, N, -1, -1)  # [B, N, L, C]
        query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1)  # [B, N, L, 3]
        # 4. concatenate features and coordinates
        sdf_input = torch.cat([
            features.reshape(B * N * L, -1),      # [B*N*L, C]
            query_points.reshape(B * N * L, -1),  # [B*N*L, 3]
        ], dim=-1)
        # 5. SDF prediction
        sdf = self.sdf_decoder(sdf_input)  # [B*N*L, 1]
        sdf = sdf.reshape(B, N, L, 1)  # [B, N, L, 1]
        # 6. aggregate the multi-scale SDF values by averaging over L
        sdf = sdf.mean(dim=2)  # [B, N, 1]
        return sdf
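
# The dense [B, N, L, C] expansion in BRep2SdfDecoder.forward makes memory
# scale with batch_size * num_points * L. A hedged sketch (not part of the
# original file) that bounds the cost by chunking over query points; it
# re-runs the feature decoder per chunk, so caching those features would need
# a small refactor of forward:
def query_sdf_in_chunks(decoder, latent, query_points, chunk_size=256):
    """Evaluate `decoder` on chunks of query points and concatenate."""
    chunks = []
    with torch.no_grad():
        for start in range(0, query_points.shape[1], chunk_size):
            chunks.append(decoder(latent, query_points[:, start:start + chunk_size]))
    return torch.cat(chunks, dim=1)  # [B, N, 1]
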
# Usage example
if __name__ == "__main__":
    # build the model
    decoder = BRep2SdfDecoder(
        latent_size=256,
        feature_dims=[512, 256, 128, 64],
        sdf_dims=[512, 512, 512, 512],
        up_block_types=("UpDecoderBlock2D",),
        layers_per_block=2,
        norm_num_groups=32,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=[0, 1, 2, 3],
        latent_in=[4],
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True,
    )
    # test data; kept small because memory scales with batch_size * num_points * L,
    # and the four up blocks stretch seq_len by a factor of 16
    batch_size = 4
    seq_len = 8
    num_points = 256
    latent = torch.randn(batch_size, 256, seq_len)
    query_points = torch.randn(batch_size, num_points, 3)
    # forward pass (latent_embeds is omitted: it is only meaningful with norm_type="spatial")
    with torch.no_grad():
        sdf = decoder(latent, query_points)
    print(f"Input latent shape: {latent.shape}")
    print(f"Query points shape: {query_points.shape}")
    print(f"Output SDF shape: {sdf.shape}")