import math

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.dropout in SDFDecoder.forward

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, SpatialNorm
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
# NOTE: UNetMidBlock1D and UpBlock1D are referenced below but were not imported in the
# original file; they are assumed to come from diffusers' 1D UNet blocks (or a local
# variant with the same names).
from diffusers.models.unets.unet_1d_blocks import (
    ResConvBlock,
    SelfAttention1d,
    UNetMidBlock1D,
    UpBlock1D,
    Upsample1d,
    get_down_block,
    get_up_block,
)


class Decoder1D(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        """
        First-stage decoder for B-rep features.

        Three main parts:
            conv_in:   input convolution that processes the initial features
            mid_block: middle processing block
            up_blocks: sequence of upsampling blocks

        Supports gradient checkpointing to save memory.
        Output shape: [B, C, L]

        # NOTE:
        1. Slicing and tiling have been removed.
        2. mode() is used instead of sample() to obtain the latent vector.
        3. The encoding path is simplified to its core functionality.
        4. A deterministic latent vector is returned instead of a distribution.
        """
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv1d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = UpBlock1D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z, latent_embeds=None):
        sample = z
        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
                )
                # sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                # sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            # sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


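# Hedged, standalone sketch (not part of the original module): the gradient
# checkpointing pattern used in Decoder1D.forward, shown on a toy block. The
# wrapped sub-module's activations are recomputed during the backward pass
# instead of being stored, trading compute for memory. Call manually if needed.
def _gradient_checkpoint_demo():
    block = nn.Conv1d(8, 8, kernel_size=3, padding=1)
    x = torch.randn(2, 8, 16, requires_grad=True)

    def run_block(*inputs):
        return block(*inputs)

    # use_reentrant=False matches the torch >= 1.11 branch above
    y = torch.utils.checkpoint.checkpoint(run_block, x, use_reentrant=False)
    y.sum().backward()
    return x.grad.shape  # torch.Size([2, 8, 16])

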
class SDFDecoder(nn.Module):
    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        """
        Second-stage decoder that predicts SDF values with a multi-layer MLP.

        Features:
            - latent information can be re-injected at selected layers (latent_in)
            - the xyz coordinates can be appended at every layer (xyz_in_all)
            - supports weight normalization and dropout

        Input shape:  [N, latent_size + 3]
        Output shape: [N, 1]
        """
        super(SDFDecoder, self).__init__()

        def make_sequence():
            return []

        dims = [latent_size + 3] + dims + [1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        for layer in range(0, self.num_layers - 1):
            if layer + 1 in latent_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != self.num_layers - 2:
                    out_dim -= 3

            if weight_norm and layer in self.norm_layers:
                setattr(
                    self,
                    "lin" + str(layer),
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
                )
            else:
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))

            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()

    # input: N x (L+3)
    def forward(self, input):
        xyz = input[:, -3:]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, :-3]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input

        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))
            if layer in self.latent_in:
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)
            # last layer Tanh
            if layer == self.num_layers - 2 and self.use_tanh:
                x = self.tanh(x)
            if layer < self.num_layers - 2:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    bn = getattr(self, "bn" + str(layer))
                    x = bn(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)

        if hasattr(self, "th"):
            x = self.th(x)

        return x


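# Hedged, standalone sketch (assumed hyperparameters, not from the original
# module): pack each query row as [latent_code | xyz] and check that SDFDecoder
# maps [N, latent_size + 3] inputs to [N, 1] SDF values. Call manually if needed.
def _sdf_decoder_shape_check():
    decoder = SDFDecoder(latent_size=64, dims=[128, 128])
    latent = torch.randn(1024, 64)                  # one latent code per query row
    xyz = torch.randn(1024, 3)                      # query coordinates
    sdf = decoder(torch.cat([latent, xyz], dim=1))  # input: [N, latent_size + 3]
    assert sdf.shape == (1024, 1)
    return sdf

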
class BRep2SdfDecoder(nn.Module):
    def __init__(
        self,
        latent_size=256,
        feature_dims=[512, 512, 256, 128],  # feature decoder channel dims
        sdf_dims=[512, 512, 512, 512],      # SDF decoder layer dims
        up_block_types=("UpDecoderBlock2D",),
        layers_per_block=2,
        norm_num_groups=32,
        norm_type="group",
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=True,
        use_tanh=True,
    ):
        super().__init__()

        # 1. Feature decoder (uses the Decoder1D structure)
        self.feature_decoder = Decoder1D(
            in_channels=latent_size,
            out_channels=feature_dims[-1],
            up_block_types=up_block_types,
            block_out_channels=feature_dims,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
        )

        # 2. SDF decoder (uses the original Decoder structure)
        self.sdf_decoder = SDFDecoder(
            latent_size=feature_dims[-1],  # use the feature decoder's output dim
            dims=sdf_dims,
            dropout=dropout,
            dropout_prob=dropout_prob,
            norm_layers=norm_layers,
            latent_in=latent_in,
            weight_norm=weight_norm,
            xyz_in_all=xyz_in_all,
            use_tanh=use_tanh,
            latent_dropout=False,
        )

        # 3. Feature transform (maps feature decoder outputs to the format the SDF decoder expects)
        self.feature_transform = nn.Sequential(
            nn.Linear(feature_dims[-1], feature_dims[-1]),
            nn.LayerNorm(feature_dims[-1]),
            nn.SiLU(),
        )

    def forward(self, latent, query_points, latent_embeds=None):
        """
        Args:
            latent: [B, C, L] B-rep features
            query_points: [B, N, 3] query points
            latent_embeds: optional conditioning embeddings
        Returns:
            sdf: [B, N, 1] SDF values
        """
        # 1. Feature decoding
        features = self.feature_decoder(latent, latent_embeds)  # [B, C, L]

        # 2. Feature transform
        B, C, L = features.shape
        features = features.permute(0, 2, 1)  # [B, L, C]
        features = self.feature_transform(features)  # [B, L, C]

        # 3. Prepare the SDF decoder input
        _, N, _ = query_points.shape
        features = features.unsqueeze(1).expand(-1, N, -1, -1)  # [B, N, L, C]
        query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1)  # [B, N, L, 3]

        # 4. Concatenate features and coordinates
        sdf_input = torch.cat([
            features.reshape(B * N * L, -1),      # [B*N*L, C]
            query_points.reshape(B * N * L, -1),  # [B*N*L, 3]
        ], dim=-1)

        # 5. SDF prediction
        sdf = self.sdf_decoder(sdf_input)  # [B*N*L, 1]
        sdf = sdf.reshape(B, N, L, 1)  # [B, N, L, 1]

        # 6. Aggregate the multi-scale SDF values
        sdf = sdf.mean(dim=2)  # [B, N, 1]

        return sdf


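# Hedged, standalone sketch (illustrative shapes only, not from the original
# module) of the pairing done in steps 3-4 of BRep2SdfDecoder.forward: every
# query point is paired with every sequence position, then flattened so the SDF
# MLP sees one [feature | xyz] row per (point, position) pair.
def _pairing_shape_check():
    B, N, L, C = 2, 5, 4, 8
    features = torch.randn(B, L, C)                   # decoded per-position features
    points = torch.randn(B, N, 3)                     # query points
    f = features.unsqueeze(1).expand(-1, N, -1, -1)   # [B, N, L, C]
    p = points.unsqueeze(2).expand(-1, -1, L, -1)     # [B, N, L, 3]
    rows = torch.cat([f.reshape(B * N * L, -1), p.reshape(B * N * L, -1)], dim=-1)
    assert rows.shape == (B * N * L, C + 3)
    return rows

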
# Usage example
if __name__ == "__main__":
    # Build the model
    decoder = BRep2SdfDecoder(
        latent_size=256,
        feature_dims=[512, 256, 128, 64],
        sdf_dims=[512, 512, 512, 512],
        up_block_types=("UpDecoderBlock2D",),
        layers_per_block=2,
        norm_num_groups=32,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=[0, 1, 2, 3],
        latent_in=[4],
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True,
    )

    # Test data
    batch_size = 4
    seq_len = 32
    num_points = 1000

    latent = torch.randn(batch_size, 256, seq_len)
    query_points = torch.randn(batch_size, num_points, 3)
    latent_embeds = torch.randn(batch_size, 256)

    # Forward pass
    sdf = decoder(latent, query_points, latent_embeds)
    print(f"Input latent shape: {latent.shape}")
    print(f"Query points shape: {query_points.shape}")
    print(f"Output SDF shape: {sdf.shape}")