import math
import numpy as np
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm
'''
NOTE:
- Slicing and tiling support has been removed.
- Latents are obtained with mode() instead of sample().
- The encoding path is simplified to the core functionality only.
- A deterministic latent vector is returned rather than a distribution.
'''
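
# Illustrative sketch (not part of the original model code): why returning mode()
# instead of sample() gives deterministic latents. DiagonalGaussianDistribution is
# imported above; the tensor shape below is an assumption for the example only.
def _example_mode_vs_sample():
    moments = torch.randn(1, 8, 64)  # concatenated [mean, logvar] along the channel dim
    posterior = DiagonalGaussianDistribution(moments)
    z_stochastic = posterior.sample()    # draws noise, differs between calls
    z_deterministic = posterior.mode()   # returns the mean, reproducible
    return z_stochastic.shape, z_deterministic.shape

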
# 1. Basic network components
class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")

    def forward(self, x):
        return self.embed(x)


class UpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, temb=None):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        hidden_states = self.up(hidden_states)
        return hidden_states


class UNetMidBlock1D(nn.Module):
    def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels

        # there is always at least one resnet
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        for attn, resnet in zip(self.attentions, self.resnets):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)
        return hidden_states


class Encoder1D(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock1D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv1d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        sample = x
        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, use_reentrant=False
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, use_reentrant=False
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)[0]
            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


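# Minimal shape-check sketch (illustrative, not the project's training code): run
# Encoder1D on a dummy sequence. The channel/length values and the "DownBlock1D"
# block type are assumptions for the example; with double_z=True the output stacks
# mean and logvar, giving 2 * out_channels channels.
def _example_encoder1d_shapes():
    encoder = Encoder1D(
        in_channels=16,
        out_channels=4,
        down_block_types=("DownBlock1D",),
        block_out_channels=(64,),
        layers_per_block=2,
    )
    dummy = torch.randn(2, 16, 128)  # (batch, channels, sequence length)
    with torch.no_grad():
        moments = encoder(dummy)
    return moments.shape  # expected (2, 8, 64) with these assumed sizes

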
# 2. B-rep feature processing
class BRepFeatureExtractor:
    def __init__(self, config):
        self.encoder = Encoder1D(
            in_channels=config.in_channels,  # set according to the feature dimensionality
            out_channels=config.out_channels,
            block_out_channels=config.block_out_channels,
            layers_per_block=config.layers_per_block,
        )

    def extract_face_features(self, face):
        """Extract features for a single face."""
        features = []
        try:
            # basic geometric features
            center = face.center_point
            normal = face.normal_vector
            # boundary features
            bounds = face.bounds()
            # surface-type feature
            surface_type = face.surface_type
            # combine into one feature vector
            feature = np.concatenate([
                center,           # [3]
                normal,           # [3]
                bounds.flatten(),
                [surface_type],   # could be one-hot encoded instead
            ])
            features.append(feature)
        except Exception as e:
            print(f"Error extracting face features: {e}")
        return np.array(features)

    def extract_edge_features(self, edge):
        """Extract features for a single edge."""
        features = []
        try:
            # sample points along the edge
            points = self.sample_points_on_edge(edge)
            for point in points:
                # position
                pos = point.coordinates
                # tangent vector
                tangent = point.tangent
                # curvature
                curvature = point.curvature
                point_feature = np.concatenate([
                    pos,          # [3]
                    tangent,      # [3]
                    [curvature],  # [1]
                ])
                features.append(point_feature)
        except Exception as e:
            print(f"Error extracting edge features: {e}")
        return np.array(features)

    @staticmethod
    def sample_points_on_edge(edge, num_points=32):
        """Sample points uniformly along an edge."""
        points = []
        try:
            length = edge.length()
            for i in range(num_points):
                t = i / (num_points - 1)
                point = edge.point_at(t * length)
                points.append(point)
        except Exception as e:
            print(f"Error sampling points: {e}")
        return points


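# Layout sketch for the per-face feature vector assembled above (illustrative; the
# B-rep accessors such as face.center_point come from whatever CAD kernel the project
# uses): with axis-aligned 2x3 min/max bounds this yields 3 + 3 + 6 + 1 = 13 values.
def _example_face_feature_layout():
    center = np.array([0.0, 0.0, 0.0])
    normal = np.array([0.0, 0.0, 1.0])
    bounds = np.array([[-1.0, -1.0, 0.0], [1.0, 1.0, 0.0]])  # hypothetical min/max corners
    surface_type = 0  # e.g. 0 = plane; could be one-hot encoded instead
    feature = np.concatenate([center, normal, bounds.flatten(), [surface_type]])
    return feature.shape  # (13,)

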
class BRepDataProcessor:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def process_brep(self, brep_model):
        """Process a single B-rep model."""
        try:
            # 1. extract face features
            face_features = []
            for face in brep_model.faces:
                feat = self.feature_extractor.extract_face_features(face)
                face_features.append(feat)
            # 2. extract edge features
            edge_features = []
            for edge in brep_model.edges:
                feat = self.feature_extractor.extract_edge_features(edge)
                edge_features.append(feat)
            # 3. assemble the data structure (cast to float32 so the features match the conv weights)
            return {
                'face_features': torch.tensor(np.array(face_features), dtype=torch.float32),
                'edge_features': torch.tensor(np.array(edge_features), dtype=torch.float32),
                'topology': self.extract_topology(brep_model),
            }
        except Exception as e:
            print(f"Error processing B-rep: {e}")
            return None

    def extract_topology(self, brep_model):
        """Extract the topological relations."""
        # face-edge adjacency matrix
        face_edge_adj = np.zeros((len(brep_model.faces), len(brep_model.edges)))
        # fill in the adjacency
        for i, face in enumerate(brep_model.faces):
            for j, edge in enumerate(brep_model.edges):
                if edge in face.edges:
                    face_edge_adj[i, j] = 1
        return face_edge_adj


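# Illustrative sketch of the face-edge adjacency produced by extract_topology (values
# here are made up; the real matrix is filled from the B-rep): for a hypothetical box
# with 6 faces and 12 edges, each edge borders exactly two faces, so every column of
# the (6, 12) matrix sums to 2.
def _example_face_edge_adjacency(num_faces=6, num_edges=12):
    face_edge_adj = np.zeros((num_faces, num_edges))
    for j in range(num_edges):
        # arbitrary pairing, just to show the structure of the matrix
        face_edge_adj[j % num_faces, j] = 1
        face_edge_adj[(j + 1) % num_faces, j] = 1
    return face_edge_adj.sum(axis=0)  # array of twos, one entry per edge

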
# 3. Main encoder
class BRepEncoder:
    def __init__(self, config):
        self.processor = BRepDataProcessor(
            BRepFeatureExtractor(config)
        )
        self.encoder = Encoder1D(**config.encoder_params)

    def encode(self, brep_model):
        """Encode a B-rep model."""
        try:
            # 1. process the raw data
            processed_data = self.processor.process_brep(brep_model)
            if processed_data is None:
                return None
            # 2. encode the features
            face_features = self.encoder(processed_data['face_features'])
            edge_features = self.encoder(processed_data['edge_features'])
            # 3. combine the features
            combined_features = self.combine_features(
                face_features,
                edge_features,
                processed_data['topology']
            )
            return combined_features
        except Exception as e:
            print(f"Error encoding B-rep: {e}")
            return None

    def combine_features(self, face_features, edge_features, topology):
        """Combine the different feature types."""
        # a graph neural network or attention mechanism could also be used here
        combined = torch.cat([
            face_features.mean(dim=1),  # global face features
            edge_features.mean(dim=1),  # global edge features
            torch.as_tensor(topology, dtype=face_features.dtype).flatten(),  # topology information
        ], dim=-1)
        return combined


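# End-to-end usage sketch (comments only, since `config` and `brep_model` come from the
# surrounding project): `config` is assumed to expose in_channels, out_channels,
# block_out_channels, layers_per_block and encoder_params, and `brep_model` must provide
# faces/edges with the geometric accessors used by BRepFeatureExtractor. The helper
# names below are hypothetical.
#
#   config = load_config("brep_encoder.yaml")   # hypothetical config loader
#   brep_model = load_step_file("part.step")    # hypothetical CAD loader
#   encoder = BRepEncoder(config)
#   latent = encoder.encode(brep_model)         # None if processing or encoding failed
#   if latent is not None:
#       print(latent.shape)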