5 changed files with 866 additions and 109 deletions
@ -0,0 +1,383 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm

# UNetMidBlock1D and UpBlock1D are defined in the companion encoder module
from encoder import UNetMidBlock1D, UpBlock1D

class Decoder1D(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        '''
        First-stage decoder for processing B-rep features.

        Three main parts:
            conv_in:   input convolution that processes the initial features
            mid_block: middle processing block
            up_blocks: sequence of upsampling blocks

        Supports gradient checkpointing to save memory.
        Output shape: [B, C, L]

        # NOTE:
        1. Slicing and tiling have been removed.
        2. The latent vector is obtained with mode() instead of sample().
        3. The encoding path is simplified to its core functionality.
        4. A deterministic latent vector is returned instead of a distribution.
        '''
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv1d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = UpBlock1D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False
    def forward(self, z, latent_embeds=None):
        sample = z
        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
                )
                # sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                # sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            # sample = sample.to(upscale_dtype)
            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample

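# A minimal shape check for Decoder1D (illustrative sketch; the channel and
# length values are made up, and the 2x length increase assumes UpBlock1D's
# Upsample1d upsamples the sequence by a factor of two):
#
#   decoder = Decoder1D(in_channels=256, out_channels=128, block_out_channels=(128,))
#   z = torch.randn(2, 256, 32)   # [B, C_latent, L]
#   out = decoder(z)              # latent_embeds=None -> plain GroupNorm output head
#   print(out.shape)              # expected [2, 128, 64] if Upsample1d doubles L
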
class SDFDecoder(nn.Module):
    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        '''
        Second-stage decoder that produces SDF values.

        A multi-layer MLP with the following features:
            - latent information can be re-injected at selected layers (latent_in)
            - the xyz coordinates can be appended at every layer (xyz_in_all)
            - optional weight normalization and dropout

        Input shape:  [N, latent_size + 3]
        Output shape: [N, 1]
        '''
        super(SDFDecoder, self).__init__()

        def make_sequence():
            return []

        dims = [latent_size + 3] + dims + [1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        for layer in range(0, self.num_layers - 1):
            if layer + 1 in latent_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != self.num_layers - 2:
                    out_dim -= 3

            if weight_norm and layer in self.norm_layers:
                setattr(
                    self,
                    "lin" + str(layer),
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
                )
            else:
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))

            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()

    # input: N x (L+3)
    def forward(self, input):
        xyz = input[:, -3:]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, :-3]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input

        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))
            if layer in self.latent_in:
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)
            # last layer Tanh
            if layer == self.num_layers - 2 and self.use_tanh:
                x = self.tanh(x)
            if layer < self.num_layers - 2:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    bn = getattr(self, "bn" + str(layer))
                    x = bn(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)

        if hasattr(self, "th"):
            x = self.th(x)

        return x

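# A minimal smoke test for SDFDecoder (sketch with illustrative sizes; no
# latent re-injection, xyz appended only at the input):
#
#   sdf_dec = SDFDecoder(
#       latent_size=64,
#       dims=[512, 512, 512, 512],   # full dims become [67, 512, 512, 512, 512, 1]
#       norm_layers=(0, 1, 2, 3),
#       use_tanh=True,
#   )
#   inp = torch.cat([torch.randn(10, 64), torch.randn(10, 3)], dim=1)  # [N, latent+3]
#   print(sdf_dec(inp).shape)  # [10, 1], values in (-1, 1) due to the final Tanh
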
class BRep2SdfDecoder(nn.Module):
    def __init__(
        self,
        latent_size=256,
        feature_dims=[512, 512, 256, 128],   # feature-decoder channel widths
        sdf_dims=[512, 512, 512, 512],       # SDF-decoder hidden widths
        up_block_types=("UpDecoderBlock2D",),
        layers_per_block=2,
        norm_num_groups=32,
        norm_type="group",
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=True,
        use_tanh=True,
    ):
        super().__init__()

        # 1. feature decoder (Decoder1D structure)
        self.feature_decoder = Decoder1D(
            in_channels=latent_size,
            out_channels=feature_dims[-1],
            up_block_types=up_block_types,
            block_out_channels=feature_dims,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type
        )

        # 2. SDF decoder (original DeepSDF-style Decoder structure)
        self.sdf_decoder = SDFDecoder(
            latent_size=feature_dims[-1],   # output width of the feature decoder
            dims=sdf_dims,
            dropout=dropout,
            dropout_prob=dropout_prob,
            norm_layers=norm_layers,
            latent_in=latent_in,
            weight_norm=weight_norm,
            xyz_in_all=xyz_in_all,
            use_tanh=use_tanh,
            latent_dropout=False
        )

        # 3. feature transform (maps the feature-decoder output into the format the SDF decoder expects)
        self.feature_transform = nn.Sequential(
            nn.Linear(feature_dims[-1], feature_dims[-1]),
            nn.LayerNorm(feature_dims[-1]),
            nn.SiLU()
        )

    def forward(self, latent, query_points, latent_embeds=None):
        """
        Args:
            latent: [B, C, L] B-rep features
            query_points: [B, N, 3] query points
            latent_embeds: optional conditioning embedding
        Returns:
            sdf: [B, N, 1] SDF values
        """
        # 1. feature decoding
        features = self.feature_decoder(latent, latent_embeds)  # [B, C, L]

        # 2. feature transform
        B, C, L = features.shape
        features = features.permute(0, 2, 1)          # [B, L, C]
        features = self.feature_transform(features)   # [B, L, C]

        # 3. prepare the SDF-decoder input
        _, N, _ = query_points.shape
        features = features.unsqueeze(1).expand(-1, N, -1, -1)          # [B, N, L, C]
        query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1)  # [B, N, L, 3]

        # 4. concatenate features and coordinates
        sdf_input = torch.cat([
            features.reshape(B*N*L, -1),      # [B*N*L, C]
            query_points.reshape(B*N*L, -1)   # [B*N*L, 3]
        ], dim=-1)

        # 5. SDF generation
        sdf = self.sdf_decoder(sdf_input)   # [B*N*L, 1]
        sdf = sdf.reshape(B, N, L, 1)       # [B, N, L, 1]

        # 6. aggregate the multi-scale SDF
        sdf = sdf.mean(dim=2)               # [B, N, 1]

        return sdf

# usage example
if __name__ == "__main__":
    # build the model
    decoder = BRep2SdfDecoder(
        latent_size=256,
        feature_dims=[512, 256, 128, 64],
        sdf_dims=[512, 512, 512, 512],
        # one up block per feature level, so the final channel count matches
        # the GroupNorm in Decoder1D's output head
        up_block_types=("UpDecoderBlock2D",) * 4,
        layers_per_block=2,
        norm_num_groups=32,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=[0, 1, 2, 3],
        latent_in=[4],
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True
    )

    # test data
    batch_size = 4
    seq_len = 32
    num_points = 1000

    latent = torch.randn(batch_size, 256, seq_len)
    query_points = torch.randn(batch_size, num_points, 3)

    # forward pass (latent_embeds is left as None: with norm_type="group" the
    # output GroupNorm does not accept a conditioning embedding)
    sdf = decoder(latent, query_points)
    print(f"Input latent shape: {latent.shape}")
    print(f"Query points shape: {query_points.shape}")
    print(f"Output SDF shape: {sdf.shape}")
@ -1,109 +0,0 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import torch.nn as nn
import torch
import torch.nn.functional as F


class Decoder(nn.Module):
    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        super(Decoder, self).__init__()

        def make_sequence():
            return []

        dims = [latent_size + 3] + dims + [1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        for layer in range(0, self.num_layers - 1):
            if layer + 1 in latent_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != self.num_layers - 2:
                    out_dim -= 3

            if weight_norm and layer in self.norm_layers:
                setattr(
                    self,
                    "lin" + str(layer),
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
                )
            else:
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))

            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()

    # input: N x (L+3)
    def forward(self, input):
        xyz = input[:, -3:]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, :-3]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input

        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))
            if layer in self.latent_in:
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)
            # last layer Tanh
            if layer == self.num_layers - 2 and self.use_tanh:
                x = self.tanh(x)
            if layer < self.num_layers - 2:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    bn = getattr(self, "bn" + str(layer))
                    x = bn(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)

        if hasattr(self, "th"):
            x = self.th(x)

        return x
@ -0,0 +1,356 @@
import math
import numpy as np
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm

'''
# NOTE:
Slicing and tiling have been removed.
The latent vector is obtained with mode() instead of sample().
The encoding path is simplified to its core functionality.
A deterministic latent vector is returned instead of a distribution.
'''

# 1. basic network components
class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")

    def forward(self, x):
        return self.embed(x)


class UpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, temb=None):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        hidden_states = self.up(hidden_states)
        return hidden_states

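# A quick shape check for UpBlock1D (sketch; the 2x factor assumes diffusers'
# Upsample1d upsamples the sequence by two):
#
#   block = UpBlock1D(in_channels=64, out_channels=32)
#   x = torch.randn(2, 64, 16)
#   print(block(x).shape)  # expected [2, 32, 32]: channels remapped, length doubled
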
class UNetMidBlock1D(nn.Module):
    def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        # there is always at least one resnet
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        for attn, resnet in zip(self.attentions, self.resnets):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        return hidden_states

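# Sizing note: SelfAttention1d above is built with mid_channels // 32 heads,
# so channel widths passed to UNetMidBlock1D should be multiples of 32
# (64, 128, 512, ...) for the head split to be well-formed.
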
class Encoder1D(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock1D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv1d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        sample = x
        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    # 1D down blocks return (hidden_states, skip_states); keep the first,
                    # matching the non-checkpointed path below
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, use_reentrant=False
                    )[0]

                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, use_reentrant=False
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)[0]
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)[0]

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample

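# A minimal shape sketch for Encoder1D (illustrative; "DownBlock1D" is assumed
# here because the default "DownEncoderBlock1D" does not appear to be a block
# type that diffusers' 1D get_down_block knows, and the output length depends
# on the block's internal downsampling):
#
#   enc = Encoder1D(
#       in_channels=7,
#       out_channels=32,
#       down_block_types=("DownBlock1D",),  # assumption, see note above
#       block_out_channels=(64,),
#   )
#   x = torch.randn(2, 7, 64)   # [B, C_in, L]
#   z = enc(x)                  # double_z=True -> 2 * out_channels for mean/logvar
#   print(z.shape)              # roughly [2, 64, L'], with L' set by the down block
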
# 2. B-rep feature processing
class BRepFeatureExtractor:
    def __init__(self, config):
        self.encoder = Encoder1D(
            in_channels=config.in_channels,   # set according to the feature dimension
            out_channels=config.out_channels,
            block_out_channels=config.block_out_channels,
            layers_per_block=config.layers_per_block
        )

    def extract_face_features(self, face):
        """Extract features of a face."""
        features = []
        try:
            # basic geometric features
            center = face.center_point
            normal = face.normal_vector

            # boundary features
            bounds = face.bounds()

            # surface-type feature
            surface_type = face.surface_type

            # combined feature
            feature = np.concatenate([
                center,            # [3]
                normal,            # [3]
                bounds.flatten(),
                [surface_type]     # could be one-hot encoded
            ])
            features.append(feature)

        except Exception as e:
            print(f"Error extracting face features: {e}")

        return np.array(features)

    def extract_edge_features(self, edge):
        """Extract features of an edge."""
        features = []
        try:
            # sample points
            points = self.sample_points_on_edge(edge)

            for point in points:
                # position
                pos = point.coordinates
                # tangent vector
                tangent = point.tangent
                # curvature
                curvature = point.curvature

                point_feature = np.concatenate([
                    pos,           # [3]
                    tangent,       # [3]
                    [curvature]    # [1]
                ])
                features.append(point_feature)

        except Exception as e:
            print(f"Error extracting edge features: {e}")

        return np.array(features)

    @staticmethod
    def sample_points_on_edge(edge, num_points=32):
        """Sample points uniformly along an edge."""
        points = []
        try:
            length = edge.length()
            for i in range(num_points):
                t = i / (num_points - 1)
                point = edge.point_at(t * length)
                points.append(point)
        except Exception as e:
            print(f"Error sampling points: {e}")
        return points

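# Shape note (derived from the extractors above): each sampled edge point
# contributes pos(3) + tangent(3) + curvature(1) = 7 values, so the default
# num_points=32 yields a [32, 7] array per edge; a face yields a single
# vector of length 3 + 3 + len(bounds.flatten()) + 1.
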
class BRepDataProcessor:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def process_brep(self, brep_model):
        """Process a single B-rep model."""
        try:
            # 1. extract face features
            face_features = []
            for face in brep_model.faces:
                feat = self.feature_extractor.extract_face_features(face)
                face_features.append(feat)

            # 2. extract edge features
            edge_features = []
            for edge in brep_model.edges:
                feat = self.feature_extractor.extract_edge_features(edge)
                edge_features.append(feat)

            # 3. organize the data structure
            return {
                'face_features': torch.tensor(face_features),
                'edge_features': torch.tensor(edge_features),
                'topology': self.extract_topology(brep_model)
            }

        except Exception as e:
            print(f"Error processing B-rep: {e}")
            return None

    def extract_topology(self, brep_model):
        """Extract topological relations."""
        # face-edge adjacency matrix
        face_edge_adj = np.zeros((len(brep_model.faces), len(brep_model.edges)))
        # fill in the adjacency relations
        for i, face in enumerate(brep_model.faces):
            for j, edge in enumerate(brep_model.edges):
                if edge in face.edges:
                    face_edge_adj[i, j] = 1
        return face_edge_adj

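# A toy illustration of what extract_topology produces (hypothetical counts):
#
#   # 2 faces, 3 edges: adj[i, j] == 1 iff edge j bounds face i
#   adj = np.zeros((2, 3))
#   adj[0, 0] = adj[0, 1] = 1   # face 0 is bounded by edges 0 and 1
#   adj[1, 1] = adj[1, 2] = 1   # face 1 shares edge 1 with face 0
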
# 3. main encoder
class BRepEncoder:
    def __init__(self, config):
        self.processor = BRepDataProcessor(
            BRepFeatureExtractor(config)
        )
        self.encoder = Encoder1D(**config.encoder_params)

    def encode(self, brep_model):
        """Encode a B-rep model."""
        try:
            # 1. process the raw data
            processed_data = self.processor.process_brep(brep_model)
            if processed_data is None:
                return None

            # 2. feature encoding
            face_features = self.encoder(processed_data['face_features'])
            edge_features = self.encoder(processed_data['edge_features'])

            # 3. combine features
            combined_features = self.combine_features(
                face_features,
                edge_features,
                processed_data['topology']
            )

            return combined_features

        except Exception as e:
            print(f"Error encoding B-rep: {e}")
            return None

    def combine_features(self, face_features, edge_features, topology):
        """Combine the different feature types."""
        # a graph neural network or attention mechanism could be used here instead
        topology = torch.as_tensor(topology, dtype=face_features.dtype)  # numpy -> tensor
        combined = torch.cat([
            face_features.mean(dim=1),   # global face features
            edge_features.mean(dim=1),   # global edge features
            topology.flatten()           # topology information
        ], dim=-1)
        return combined
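
# Dimension note: with face/edge encodings of width C and a topology matrix of
# shape [F, E], the concatenation above yields a vector of length roughly
# C_face + C_edge + F * E; downstream modules must be sized to match.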
@ -0,0 +1,127 @@
import torch
import torch.nn as nn
from encoder import BRepEncoder
from decoder import BRep2SdfDecoder


class BRep2SDF(nn.Module):
    def __init__(
        self,
        # encoder parameters
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        # decoder parameters
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        # shared parameters
        layers_per_block=2,
        norm_num_groups=32,
        # SDF-specific parameters
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(0, 1, 2, 3),
        latent_in=(4,),
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True,
    ):
        super().__init__()

        # 1. encoder configuration
        encoder_config = type('Config', (), {
            'in_channels': in_channels,
            'out_channels': latent_size,
            'block_out_channels': encoder_block_out_channels,
            'layers_per_block': layers_per_block,
            'norm_num_groups': norm_num_groups,
            'encoder_params': {
                'in_channels': in_channels,
                'out_channels': latent_size,
                'block_out_channels': encoder_block_out_channels,
                'layers_per_block': layers_per_block,
                'norm_num_groups': norm_num_groups,
            }
        })()

        # 2. decoder configuration
        decoder_config = {
            'latent_size': latent_size,
            'feature_dims': decoder_feature_dims,
            'sdf_dims': sdf_dims,
            'layers_per_block': layers_per_block,
            'norm_num_groups': norm_num_groups,
            'dropout': dropout,
            'dropout_prob': dropout_prob,
            'norm_layers': norm_layers,
            'latent_in': latent_in,
            'weight_norm': weight_norm,
            'xyz_in_all': xyz_in_all,
            'use_tanh': use_tanh,
        }

        # 3. create encoder and decoder
        self.encoder = BRepEncoder(encoder_config)
        self.decoder = BRep2SdfDecoder(**decoder_config)

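    # The type('Config', (), {...})() call above builds a throwaway attribute
    # namespace for BRepEncoder; an equivalent, more explicit sketch (field
    # names taken from the code above) would be a dataclass:
    #
    #   from dataclasses import dataclass, field
    #
    #   @dataclass
    #   class EncoderConfig:
    #       in_channels: int
    #       out_channels: int
    #       block_out_channels: tuple
    #       layers_per_block: int
    #       norm_num_groups: int
    #       encoder_params: dict = field(default_factory=dict)
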
    def encode(self, brep_model):
        """Encode a B-rep model into latent features."""
        return self.encoder.encode(brep_model)

    def decode(self, latent, query_points, latent_embeds=None):
        """Decode SDF values from latent features."""
        return self.decoder(latent, query_points, latent_embeds)

    def forward(self, brep_model, query_points):
        """Full forward pass."""
        # 1. encode the B-rep model
        latent = self.encode(brep_model)
        if latent is None:
            return None

        # 2. decode SDF values
        sdf = self.decode(latent, query_points)
        return sdf

# usage example
if __name__ == "__main__":
    # build the model
    model = BRep2SDF(
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        layers_per_block=2,
        norm_num_groups=32,
    )

    # test data
    batch_size = 4
    seq_len = 32
    num_points = 1000

    # mock B-rep data, shaped to satisfy the feature extractor's attribute reads
    class MockPoint:
        def __init__(self):
            self.coordinates = torch.randn(3)
            self.tangent = torch.randn(3)
            self.curvature = 0.0

    class MockEdge:
        def __init__(self):
            self.length = lambda: 1.0
            self.point_at = lambda t: MockPoint()

    class MockFace:
        def __init__(self):
            self.center_point = torch.randn(3)
            self.normal_vector = torch.randn(3)
            self.surface_type = 0
            self.bounds = lambda: torch.zeros(2, 3)
            self.edges = []

    class MockBRep:
        def __init__(self):
            self.faces = [MockFace() for _ in range(10)]
            self.edges = [MockEdge() for _ in range(20)]

    brep_model = MockBRep()
    query_points = torch.randn(batch_size, num_points, 3)

    # forward pass
    sdf = model(brep_model, query_points)
    if sdf is not None:
        print(f"Output SDF shape: {sdf.shape}")