import math
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, SpatialNorm
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_1d_blocks import (
    ResConvBlock,
    SelfAttention1d,
    Upsample1d,
    get_down_block,
    get_up_block,
)
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook

# NOTE(review): `UNetMidBlock1D` and `UpBlock1D` are used by `Decoder1D` below
# but are neither imported nor defined in this file, so constructing
# `Decoder1D` raises NameError as the file stands. diffusers'
# `unet_1d_blocks` module appears to export classes with these names, but its
# `UpBlock1D.forward` takes `(hidden_states, res_hidden_states_tuple, temb)`,
# which does not match the `up_block(sample, latent_embeds)` call made here.
# Confirm which implementation is intended before wiring one in.


class Decoder1D(nn.Module):
    """First-stage decoder that processes B-rep features as 1D sequences.

    The pipeline is ``conv_in`` -> ``mid_block`` -> ``up_blocks`` ->
    normalisation -> SiLU -> ``conv_out``. Gradient checkpointing can be
    enabled by setting ``gradient_checkpointing = True`` to save memory
    during training. Input and output are laid out as ``[B, C, L]``.

    Notes carried over from the original documentation (they appear to refer
    to the companion encoder rather than to this class -- TODO confirm):

    1. Slicing and tiling support was removed.
    2. ``mode()`` is used directly instead of ``sample()`` to obtain the
       latent vector.
    3. The encoding path was simplified to its core functionality.
    4. A deterministic latent vector is returned instead of a distribution.

    Args:
        in_channels: Channel count of the input latent.
        out_channels: Channel count produced by ``conv_out``.
        up_block_types: One entry per up block. Only the number of entries is
            used; the names themselves are not inspected.
        block_out_channels: Channel widths, listed shallow-to-deep. The last
            entry is the width of ``conv_in`` and ``mid_block``; the up blocks
            walk the sequence in reverse.
        layers_per_block: Stored on the module; not otherwise used here.
        norm_num_groups: Number of groups for the output ``GroupNorm``.
        act_fn: Accepted for signature compatibility; the output activation
            is always SiLU.
        norm_type: ``"group"`` or ``"spatial"``; selects the output norm.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv1d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, _ in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            self.up_blocks.append(
                UpBlock1D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                )
            )

        # out
        # Size the output head from the channel count the last up block
        # actually produces. This equals `block_out_channels[0]` when there
        # is one up block per entry of `block_out_channels`, and stays
        # consistent when there are fewer up blocks than channel entries.
        if norm_type == "spatial":
            # NOTE(review): SpatialNorm is imported from the 2D attention
            # processors; confirm it supports [B, C, L] inputs.
            self.conv_norm_out = SpatialNorm(output_channel, temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv1d(output_channel, out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def _run_block(self, block, sample, latent_embeds):
        """Call ``block(sample, latent_embeds)``, checkpointed when enabled."""
        if not (self.training and self.gradient_checkpointing):
            return block(sample, latent_embeds)
        if is_torch_version(">=", "1.11.0"):
            return torch.utils.checkpoint.checkpoint(
                block, sample, latent_embeds, use_reentrant=False
            )
        return torch.utils.checkpoint.checkpoint(block, sample, latent_embeds)

    def forward(self, z, latent_embeds=None):
        """Decode latent ``z`` into features.

        Args:
            z: Latent tensor fed to ``conv_in`` (``[B, in_channels, L]``).
            latent_embeds: Optional conditioning forwarded to the mid/up
                blocks and, for non-GroupNorm output norms, to
                ``conv_norm_out``.

        Returns:
            The output of ``conv_out``.
        """
        sample = self.conv_in(z)

        # middle
        sample = self._run_block(self.mid_block, sample, latent_embeds)

        # up
        for up_block in self.up_blocks:
            sample = self._run_block(up_block, sample, latent_embeds)

        # post-process
        # GroupNorm takes a single input, so the conditioning is only passed
        # on when the output norm is a conditional one.
        if latent_embeds is None or isinstance(self.conv_norm_out, nn.GroupNorm):
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class SDFDecoder(nn.Module):
    """Second-stage decoder: an MLP that maps (latent, xyz) rows to SDF values.

    Features:
        * The full input can be re-injected at chosen layers (``latent_in``).
        * The xyz coordinates can be appended before every layer after the
          first (``xyz_in_all``).
        * Optional weight normalisation, LayerNorm and dropout per layer.

    Input is ``[N, latent_size + 3]`` with xyz in the last three columns;
    output is ``[N, 1]``.

    Args:
        latent_size: Width of the latent part of each input row.
        dims: Hidden layer widths (any sequence of ints).
        dropout: Layer indices after which dropout is applied, or ``None``.
        dropout_prob: Dropout probability used for those layers.
        norm_layers: Layer indices that get weight norm (``weight_norm=True``)
            or a LayerNorm (``weight_norm=False``).
        latent_in: Layer indices whose input is concatenated with the
            original network input.
        weight_norm: Use weight normalisation instead of LayerNorm.
        xyz_in_all: Append xyz to the input of every layer after the first
            (except layers listed in ``latent_in``).
        use_tanh: Apply tanh to the last linear layer's output.
        latent_dropout: Apply dropout (p=0.2) to the latent part of the input.
    """

    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        super().__init__()

        # Unpacking accepts lists and tuples alike.
        dims = [latent_size + 3, *dims, 1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        last_layer = self.num_layers - 2
        for layer in range(self.num_layers - 1):
            # Reserve room in this layer's output for whatever `forward`
            # concatenates before the next layer: the whole input when the
            # next layer is in `latent_in`, otherwise xyz if `xyz_in_all`.
            if layer + 1 in latent_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != last_layer:
                    out_dim -= 3

            linear = nn.Linear(dims[layer], out_dim)
            if weight_norm and layer in self.norm_layers:
                linear = nn.utils.weight_norm(linear)
            setattr(self, f"lin{layer}", linear)

            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, f"bn{layer}", nn.LayerNorm(out_dim))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()

    # input: N x (L+3)
    def forward(self, input):
        """Evaluate the MLP on rows of ``[latent | xyz]``.

        Args:
            input: Tensor of shape ``[N, latent_size + 3]``.

        Returns:
            Tensor of shape ``[N, 1]`` after the final tanh.
        """
        xyz = input[:, -3:]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, :-3]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input

        last_layer = self.num_layers - 2
        for layer in range(self.num_layers - 1):
            lin = getattr(self, f"lin{layer}")
            if layer in self.latent_in:
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)

            # last layer Tanh
            if layer == last_layer and self.use_tanh:
                x = self.tanh(x)

            if layer < last_layer:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    x = getattr(self, f"bn{layer}")(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)

        # `th` is always created in __init__, so with use_tanh=True the
        # output passes through tanh twice (kept as in the original).
        if hasattr(self, "th"):
            x = self.th(x)

        return x


class BRep2SdfDecoder(nn.Module):
    """Two-stage decoder from a B-rep latent to SDF values at query points.

    Stage 1 (``feature_decoder``, a ``Decoder1D``) decodes the latent into a
    feature sequence; ``feature_transform`` (Linear -> LayerNorm -> SiLU)
    reshapes it for stage 2 (``sdf_decoder``, an ``SDFDecoder``), which is
    evaluated for every (query point, sequence position) pair. The results
    are averaged over the sequence axis.

    Args:
        latent_size: Channel count of the input latent.
        feature_dims: ``block_out_channels`` for the feature decoder; the
            last entry is also the feature width fed to the SDF decoder.
        sdf_dims: Hidden widths of the SDF decoder MLP.
        up_block_types, layers_per_block, norm_num_groups, norm_type:
            Forwarded to ``Decoder1D``.
        dropout, dropout_prob, norm_layers, latent_in, weight_norm,
        xyz_in_all, use_tanh: Forwarded to ``SDFDecoder``.
    """

    def __init__(
        self,
        latent_size=256,
        feature_dims=(512, 512, 256, 128),  # feature decoder widths
        sdf_dims=(512, 512, 512, 512),  # SDF decoder widths
        up_block_types=("UpDecoderBlock2D",),
        layers_per_block=2,
        norm_num_groups=32,
        norm_type="group",
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=True,
        use_tanh=True,
    ):
        super().__init__()
        feature_width = feature_dims[-1]

        # 1. Feature decoder (Decoder1D structure)
        self.feature_decoder = Decoder1D(
            in_channels=latent_size,
            out_channels=feature_width,
            up_block_types=up_block_types,
            block_out_channels=feature_dims,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
        )

        # 2. SDF decoder (original MLP decoder structure)
        self.sdf_decoder = SDFDecoder(
            latent_size=feature_width,  # matches the feature decoder output
            dims=sdf_dims,
            dropout=dropout,
            dropout_prob=dropout_prob,
            norm_layers=norm_layers,
            latent_in=latent_in,
            weight_norm=weight_norm,
            xyz_in_all=xyz_in_all,
            use_tanh=use_tanh,
            latent_dropout=False,
        )

        # 3. Feature transform (adapts decoder output for the SDF decoder)
        self.feature_transform = nn.Sequential(
            nn.Linear(feature_width, feature_width),
            nn.LayerNorm(feature_width),
            nn.SiLU(),
        )

    def forward(self, latent, query_points, latent_embeds=None):
        """
        Args:
            latent: [B, C, L] B-rep features
            query_points: [B, N, 3] query points
            latent_embeds: optional conditioning embedding
        Returns:
            sdf: [B, N, 1] SDF values
        """
        # 1. Decode features
        features = self.feature_decoder(latent, latent_embeds)  # [B, C, L]

        # 2. Transform features (channels last for Linear/LayerNorm)
        B, C, L = features.shape
        features = features.permute(0, 2, 1)  # [B, L, C]
        features = self.feature_transform(features)  # [B, L, C]

        # 3. Pair every query point with every sequence position
        _, N, _ = query_points.shape
        features = features.unsqueeze(1).expand(-1, N, -1, -1)  # [B, N, L, C]
        query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1)  # [B, N, L, 3]

        # 4. Concatenate features and coordinates row-wise
        sdf_input = torch.cat(
            [
                features.reshape(B * N * L, -1),  # [B*N*L, C]
                query_points.reshape(B * N * L, -1),  # [B*N*L, 3]
            ],
            dim=-1,
        )

        # 5. Predict SDF
        sdf = self.sdf_decoder(sdf_input)  # [B*N*L, 1]
        sdf = sdf.reshape(B, N, L, 1)  # [B, N, L, 1]

        # 6. Aggregate over the sequence axis
        sdf = sdf.mean(dim=2)  # [B, N, 1]

        return sdf


def _demo() -> None:
    """Usage example: build a model and print shapes for a random batch."""
    # Create the model
    decoder = BRep2SdfDecoder(
        latent_size=256,
        feature_dims=[512, 256, 128, 64],
        sdf_dims=[512, 512, 512, 512],
        up_block_types=("UpDecoderBlock2D",),
        layers_per_block=2,
        norm_num_groups=32,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=[0, 1, 2, 3],
        latent_in=[4],
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True,
    )

    # Test data
    batch_size = 4
    seq_len = 32
    num_points = 1000

    latent = torch.randn(batch_size, 256, seq_len)
    query_points = torch.randn(batch_size, num_points, 3)
    latent_embeds = torch.randn(batch_size, 256)

    # Forward pass
    sdf = decoder(latent, query_points, latent_embeds)

    print(f"Input latent shape: {latent.shape}")
    print(f"Query points shape: {query_points.shape}")
    print(f"Output SDF shape: {sdf.shape}")


if __name__ == "__main__":
    _demo()