You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
356 lines
12 KiB
356 lines
12 KiB
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm
|
|
|
|
'''
|
|
# NOTE:
|
|
移除了分片(slicing)和平铺(tiling)功能
|
|
直接使用mode()而不是sample()获取潜在向量
|
|
简化了编码过程,只保留核心功能
|
|
返回确定性的潜在向量而不是分布
|
|
'''
|
|
|
|
# 1. 基础网络组件
|
|
class Embedder(nn.Module):
    """Lookup table mapping integer ids to learned vectors.

    Args:
        vocab_size: number of rows in the table (ids must be in ``[0, vocab_size)``).
        d_model: length of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        """Re-initialise the table in place with Kaiming-normal, ``mode="fan_in"``."""
        table = self.embed.weight
        nn.init.kaiming_normal_(table, mode="fan_in")

    def forward(self, x):
        """Return the embedding of every id in ``x`` (adds a trailing ``d_model`` axis)."""
        vectors = self.embed(x)
        return vectors
|
|
|
|
|
|
class UpBlock1D(nn.Module):
    """Three residual 1-D conv blocks followed by cubic upsampling.

    Args:
        in_channels: channels of the incoming hidden states.
        out_channels: channels produced by the last residual block.
        mid_channels: width of the intermediate blocks; defaults to ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = in_channels

        self.resnets = nn.ModuleList(
            [
                ResConvBlock(in_channels, mid_channels, mid_channels),
                ResConvBlock(mid_channels, mid_channels, mid_channels),
                ResConvBlock(mid_channels, mid_channels, out_channels),
            ]
        )
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, temb=None):
        """Run the residual blocks in order, then upsample.

        ``temb`` is accepted for interface compatibility but is not used.
        """
        out = hidden_states
        for block in self.resnets:
            out = block(out)
        return self.up(out)
|
|
|
|
|
|
class UNetMidBlock1D(nn.Module):
    """Middle block: six (ResConvBlock -> SelfAttention1d) stages in sequence.

    Args:
        mid_channels: working channel width of the inner stages.
        in_channels: channels of the input hidden states.
        out_channels: channels of the output; defaults to ``in_channels``.
    """

    def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        # One attention head per 32 channels, but never zero heads: with fewer
        # than 32 channels ``c // 32`` is 0. SelfAttention1d reshapes its
        # projections by the head count, and 0 heads is not a valid split.
        mid_heads = max(1, mid_channels // 32)
        out_heads = max(1, out_channels // 32)

        # there is always at least one resnet
        resnets = [ResConvBlock(in_channels, mid_channels, mid_channels)]
        resnets += [ResConvBlock(mid_channels, mid_channels, mid_channels) for _ in range(4)]
        resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

        # The last attention runs on the last resnet's output, hence out_channels.
        attentions = [SelfAttention1d(mid_channels, mid_heads) for _ in range(5)]
        attentions.append(SelfAttention1d(out_channels, out_heads))

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        """Apply each resnet followed by its paired attention.

        ``temb`` is accepted for interface compatibility but is not used.
        """
        for attn, resnet in zip(self.attentions, self.resnets):
            hidden_states = attn(resnet(hidden_states))
        return hidden_states
|
|
|
|
|
|
class Encoder1D(nn.Module):
    """1-D convolutional encoder: conv stem -> down blocks -> mid block -> norm/act/conv head.

    Args:
        in_channels: channels of the input signal (dim 1 of a ``Conv1d`` input).
        out_channels: latent channels; the head emits twice as many when ``double_z``.
        down_block_types: block type names passed to diffusers' 1-D ``get_down_block``.
            NOTE(review): confirm the default ``"DownEncoderBlock1D"`` is a name that
            ``get_down_block`` accepts in the installed diffusers version.
        block_out_channels: output channels of each down block.
        layers_per_block: forwarded as ``num_layers`` to ``get_down_block``.
        norm_num_groups: group count of the output ``GroupNorm``.
        act_fn: accepted but currently ignored; the output activation is always ``nn.SiLU``.
        double_z: if True, ``conv_out`` produces ``2 * out_channels`` channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock1D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv1d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        """Encode ``x`` (a ``Conv1d`` input, e.g. ``(batch, in_channels, length)``).

        Returns the output of ``conv_out``. When training with
        ``gradient_checkpointing`` enabled, the down and mid blocks run under
        ``torch.utils.checkpoint``. Both code paths compute the same values.
        """
        sample = self.conv_in(x)

        if self.training and self.gradient_checkpointing:
            # Non-reentrant checkpointing is only requested on torch >= 1.11.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down: take element [0] of each block's output, exactly like the
            # non-checkpointed branch. Previously the whole return value was
            # passed on to the next block here.
            for down_block in self.down_blocks:
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), sample, **ckpt_kwargs
                )[0]

            # middle
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), sample, **ckpt_kwargs
            )
        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)[0]

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample
|
|
|
|
# 2. B-rep特征处理
|
|
class BRepFeatureExtractor:
    """Turn B-rep faces and edges into NumPy feature arrays.

    Args:
        config: object with ``in_channels``, ``out_channels``,
            ``block_out_channels`` and ``layers_per_block`` attributes, used to
            build an ``Encoder1D``.
    """

    def __init__(self, config):
        # NOTE(review): no method of this class uses self.encoder in the visible code.
        self.encoder = Encoder1D(
            in_channels=config.in_channels,  # set according to the feature dimension
            out_channels=config.out_channels,
            block_out_channels=config.block_out_channels,
            layers_per_block=config.layers_per_block
        )

    def extract_face_features(self, face):
        """Extract the feature row of a single face.

        The row is the concatenation of ``face.center_point`` (3 values),
        ``face.normal_vector`` (3 values), the flattened ``face.bounds()`` and
        ``face.surface_type``.

        Returns:
            Array of shape ``(1, F)`` on success. On any exception the error is
            printed and an empty array (shape ``(0,)``) is returned.
        """
        features = []
        try:
            center = face.center_point       # [3]
            normal = face.normal_vector      # [3]
            bounds = face.bounds()           # boundary features
            surface_type = face.surface_type
            feature = np.concatenate([
                center,
                normal,
                # asarray: also accept list/tuple bounds, not just ndarrays
                np.asarray(bounds).flatten(),
                [surface_type],  # could be one-hot encoded instead
            ])
            features.append(feature)
        except Exception as e:
            print(f"Error extracting face features: {e}")
        return np.array(features)

    def extract_edge_features(self, edge):
        """Extract one feature row per point sampled on ``edge``.

        Each row is ``[coordinates (3), tangent (3), curvature (1)]``, read from
        the points returned by ``sample_points_on_edge``. With the default
        32 samples the result has shape ``(32, 7)``. On an exception the error
        is printed and the rows built so far are returned.
        """
        features = []
        try:
            for point in self.sample_points_on_edge(edge):
                features.append(np.concatenate([
                    point.coordinates,   # position [3]
                    point.tangent,       # tangent vector [3]
                    [point.curvature],   # curvature [1]
                ]))
        except Exception as e:
            print(f"Error extracting edge features: {e}")
        return np.array(features)

    @staticmethod
    def sample_points_on_edge(edge, num_points=32):
        """Sample ``num_points`` evenly spaced points along ``edge``.

        Calls ``edge.point_at(t * edge.length())`` for ``t`` evenly spaced in
        ``[0, 1]`` with both endpoints included. ``num_points == 1`` yields only
        the ``t = 0`` point, and ``num_points <= 0`` yields an empty list.
        Exceptions from the edge API are printed, and the points sampled so far
        are returned.
        """
        points = []
        try:
            length = edge.length()
            # Single-point case: i / (num_points - 1) would divide by zero.
            denom = num_points - 1 if num_points > 1 else 1
            for i in range(num_points):
                points.append(edge.point_at(i / denom * length))
        except Exception as e:
            print(f"Error sampling points: {e}")
        return points
|
|
|
|
class BRepDataProcessor:
    """Run a feature extractor over a B-rep model's faces and edges and pack the results.

    Args:
        feature_extractor: object exposing ``extract_face_features(face)`` and
            ``extract_edge_features(edge)``, each returning a NumPy array.
    """

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    @staticmethod
    def _stack_to_tensor(arrays):
        """Stack equally-shaped arrays into one tensor; an empty list gives ``torch.tensor([])``."""
        if not arrays:
            return torch.tensor([])
        # Build a single ndarray first: PyTorch warns that creating a tensor
        # from a list of ndarrays is extremely slow. Ragged inputs raise
        # ValueError here, just as torch.tensor(list) did.
        return torch.as_tensor(np.stack(arrays))

    def process_brep(self, brep_model):
        """Process a single B-rep model.

        Returns:
            A dict with three entries:

            - ``'face_features'``: per-face arrays stacked into one tensor.
            - ``'edge_features'``: per-edge arrays stacked into one tensor.
            - ``'topology'``: the face x edge incidence matrix from
              ``extract_topology``.

            Returns ``None`` if any step raised; the error is printed.
        """
        try:
            extractor = self.feature_extractor
            # 1. face features
            face_features = [extractor.extract_face_features(face) for face in brep_model.faces]
            # 2. edge features
            edge_features = [extractor.extract_edge_features(edge) for edge in brep_model.edges]
            # 3. assemble the output structure
            return {
                'face_features': self._stack_to_tensor(face_features),
                'edge_features': self._stack_to_tensor(edge_features),
                'topology': self.extract_topology(brep_model)
            }
        except Exception as e:
            print(f"Error processing B-rep: {e}")
            return None

    def extract_topology(self, brep_model):
        """Build the face-edge incidence matrix.

        Returns:
            float NumPy array of shape ``(num_faces, num_edges)``. Entry
            ``[i, j]`` is 1 when the j-th model edge is contained in the i-th
            face's ``edges``, and 0 otherwise.
        """
        faces = list(brep_model.faces)
        edges = list(brep_model.edges)
        face_edge_adj = np.zeros((len(faces), len(edges)))
        for i, face in enumerate(faces):
            # Read each face's edge collection once, not once per model edge.
            face_edges = list(face.edges)
            for j, edge in enumerate(edges):
                if edge in face_edges:
                    face_edge_adj[i, j] = 1
        return face_edge_adj
|
|
|
|
# 3. 主编码器
|
|
class BRepEncoder:
    """End-to-end B-rep encoder: feature extraction -> Encoder1D -> feature fusion.

    Args:
        config: object providing the attributes ``BRepFeatureExtractor`` reads
            (``in_channels``, ``out_channels``, ``block_out_channels``,
            ``layers_per_block``) plus ``encoder_params``, the keyword arguments
            for ``Encoder1D``.
    """

    def __init__(self, config):
        self.processor = BRepDataProcessor(
            BRepFeatureExtractor(config)
        )
        self.encoder = Encoder1D(**config.encoder_params)

    def _match_encoder(self, x):
        """Move ``x`` to the device and dtype of the encoder's parameters (if any).

        The processor builds its tensors from NumPy arrays, which are typically
        float64. Conv layers reject inputs whose dtype differs from their weights.
        """
        for param in self.encoder.parameters():
            return x.to(device=param.device, dtype=param.dtype)
        return x

    def encode(self, brep_model):
        """Encode a B-rep model.

        Returns:
            The output of ``combine_features``, or ``None`` if processing
            failed or any step raised (the error is printed).
        """
        try:
            # 1. process the raw data
            processed_data = self.processor.process_brep(brep_model)
            if processed_data is None:
                return None

            # 2. feature encoding
            # NOTE(review): face and edge tensors go through the same encoder,
            # whose conv_in has a fixed in_channels; confirm both layouts match it.
            face_features = self.encoder(self._match_encoder(processed_data['face_features']))
            edge_features = self.encoder(self._match_encoder(processed_data['edge_features']))

            # 3. combine features
            return self.combine_features(
                face_features,
                edge_features,
                processed_data['topology']
            )
        except Exception as e:
            print(f"Error encoding B-rep: {e}")
            return None

    def combine_features(self, face_features, edge_features, topology):
        """Fuse encoded face/edge features and the topology into one 1-D tensor.

        Args:
            face_features: encoder output for the faces; averaged over dim 1.
            edge_features: encoder output for the edges; averaged over dim 1.
            topology: face-edge incidence matrix, as a NumPy array or tensor.

        Returns:
            A 1-D tensor: the flattened face summary, then the flattened edge
            summary, then the flattened topology.
        """
        # A graph neural network or attention could replace this simple fusion.
        # NOTE(review): the result length depends on the face/edge counts; if a
        # fixed-size global descriptor is intended, pool over dim 0 instead.
        topology = torch.as_tensor(topology, dtype=face_features.dtype, device=face_features.device)
        combined = torch.cat([
            face_features.mean(dim=1).flatten(),  # global face features
            edge_features.mean(dim=1).flatten(),  # global edge features
            topology.flatten(),                   # topology information
        ], dim=-1)
        return combined
|