import math
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm

# NOTE (author's notes, translated; they describe the VAE-derived design this
# module was simplified from):
# - Slicing and tiling support has been removed.
# - Latents are taken with mode() rather than sample().
# - The encoding path is simplified to keep only the core functionality.
# - A deterministic latent vector is returned instead of a distribution.


# 1. Basic network components
class Embedder(nn.Module):
    """Token embedding layer whose weights are Kaiming-normal initialised."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        # Replace nn.Embedding's default N(0, 1) init of the (vocab_size, d_model) table.
        nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")

    def forward(self, x):
        """Look up the embedding vectors for the integer indices in ``x``."""
        return self.embed(x)


class UpBlock1D(nn.Module):
    """Three ResConvBlocks followed by cubic 1-D upsampling.

    ``mid_channels`` defaults to ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, temb=None):
        """Apply the residual blocks, then upsample. ``temb`` is accepted but unused."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)
        hidden_states = self.up(hidden_states)
        return hidden_states


class UNetMidBlock1D(nn.Module):
    """Mid block made of six (ResConvBlock -> SelfAttention1d) stages.

    The first stage maps ``in_channels`` -> ``mid_channels`` and the last maps
    ``mid_channels`` -> ``out_channels`` (default: ``in_channels``). Attention
    uses ``channels // 32`` heads, clamped to at least one head so that blocks
    narrower than 32 channels do not end up with zero heads.
    """

    def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        mid_heads = max(1, mid_channels // 32)
        out_heads = max(1, out_channels // 32)

        resnets = [ResConvBlock(in_channels, mid_channels, mid_channels)]
        resnets += [ResConvBlock(mid_channels, mid_channels, mid_channels) for _ in range(4)]
        resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

        attentions = [SelfAttention1d(mid_channels, mid_heads) for _ in range(5)]
        attentions.append(SelfAttention1d(out_channels, out_heads))

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        """Run each resnet followed by its attention layer. ``temb`` is unused."""
        for attn, resnet in zip(self.attentions, self.resnets):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)
        return hidden_states


class Encoder1D(nn.Module):
    """1-D convolutional encoder: conv_in -> down blocks -> mid block -> GroupNorm/SiLU -> conv_out.

    Args:
        in_channels: Channels of the ``Conv1d`` input.
        out_channels: Output channels; the result has ``2 * out_channels``
            channels when ``double_z`` is true.
        down_block_types: Names passed to diffusers' 1-D ``get_down_block``.
            NOTE(review): the default ``"DownEncoderBlock1D"`` does not appear
            to be a name ``unet_1d_blocks.get_down_block`` accepts (it knows
            e.g. ``"DownBlock1D"`` / ``"DownResnetBlock1D"``) — confirm; callers
            probably have to pass this explicitly.
        block_out_channels: Output channels per down block; the last entry is
            also the width of the mid block and the output norm.
        layers_per_block: Forwarded to ``get_down_block`` as ``num_layers``.
        norm_num_groups: Number of groups of the output ``GroupNorm``.
        act_fn: Currently ignored — the output activation is always SiLU.
        double_z: Whether ``conv_out`` produces ``2 * out_channels`` channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock1D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv1d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()  # NOTE: act_fn is not consulted here

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        """Encode ``x``, an ``nn.Conv1d`` input of shape (batch, in_channels, length).

        Returns the ``conv_out`` activation (``2 * out_channels`` channels when
        ``double_z``, else ``out_channels``).
        """
        sample = self.conv_in(x)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # use_reentrant is only passed on torch >= 1.11, as before.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # down: take [0] of each down block's output exactly like the
            # non-checkpointed path below; previously the whole return value was
            # fed into the next block.
            for down_block in self.down_blocks:
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), sample, **ckpt_kwargs
                )[0]
            # middle
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), sample, **ckpt_kwargs
            )
        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)[0]
            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample
# 2. B-rep feature processing
# numpy is used throughout this section but was never imported by the file;
# import it here so the section is self-contained.
import numpy as np  # noqa: E402


class BRepFeatureExtractor:
    """Turn B-rep faces and edges into numpy feature arrays.

    The attributes read from faces, edges and points (``center_point``,
    ``normal_vector``, ``bounds()``, ``surface_type``, ``length()``,
    ``point_at()``, ``coordinates``, ``tangent``, ``curvature``) come from the
    caller's B-rep objects; their types are not defined in this module.
    """

    def __init__(self, config):
        # NOTE(review): this encoder is not used anywhere in this class —
        # BRepEncoder builds its own from ``config.encoder_params``. Kept so
        # code reading ``extractor.encoder`` keeps working.
        self.encoder = Encoder1D(
            in_channels=config.in_channels,  # set according to the feature dimension
            out_channels=config.out_channels,
            block_out_channels=config.block_out_channels,
            layers_per_block=config.layers_per_block,
        )

    def extract_face_features(self, face):
        """Extract the features of one face.

        Returns:
            A single-row array ``[[*center, *normal, *bounds (flattened), surface_type]]``,
            or an empty array if extraction fails (the error is printed).
        """
        features = []
        try:
            # Basic geometric features.
            center = face.center_point
            normal = face.normal_vector
            # Boundary features.
            bounds = face.bounds()
            # Surface type (NOTE(review): could be one-hot encoded instead).
            surface_type = face.surface_type

            # Combined feature; np.ravel also accepts plain nested sequences.
            feature = np.concatenate([
                center,            # presumably [3]
                normal,            # presumably [3]
                np.ravel(bounds),
                [surface_type],
            ])
            features.append(feature)
        except Exception as e:
            print(f"Error extracting face features: {e}")

        return np.array(features)

    def extract_edge_features(self, edge):
        """Extract per-sample-point features along one edge.

        Returns:
            An array with one row ``[*position, *tangent, curvature]`` per
            sampled point, or the rows collected before an error (the error is
            printed).
        """
        features = []
        try:
            for point in self.sample_points_on_edge(edge):
                point_feature = np.concatenate([
                    point.coordinates,   # position, presumably [3]
                    point.tangent,       # tangent vector, presumably [3]
                    [point.curvature],   # curvature [1]
                ])
                features.append(point_feature)
        except Exception as e:
            print(f"Error extracting edge features: {e}")

        return np.array(features)

    @staticmethod
    def sample_points_on_edge(edge, num_points=32):
        """Sample ``num_points`` points evenly along ``edge``.

        Points are requested via ``edge.point_at(t * edge.length())`` for ``t``
        evenly spaced over [0, 1]. A single point is taken at ``t = 0``
        (avoiding a division by zero); ``num_points <= 0`` yields ``[]``. On
        error the points gathered so far are returned and the error is printed.
        """
        points = []
        try:
            length = edge.length()
            denom = max(num_points - 1, 1)  # guard the num_points == 1 case
            for i in range(num_points):
                t = i / denom
                points.append(edge.point_at(t * length))
        except Exception as e:
            print(f"Error sampling points: {e}")

        return points


class BRepDataProcessor:
    """Convert a B-rep model into feature tensors plus a face-edge incidence matrix."""

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    @staticmethod
    def _to_float_tensor(arrays):
        """Stack per-item numpy feature arrays into one float32 tensor.

        float32 matches the default dtype of the Conv1d weights these tensors
        are fed to; ``torch.tensor(list_of_ndarrays)`` produced float64 and is
        slow for lists of arrays.
        """
        return torch.from_numpy(np.asarray(arrays, dtype=np.float32))

    def process_brep(self, brep_model):
        """Process a single B-rep model.

        Returns:
            A dict with ``'face_features'`` and ``'edge_features'`` (float32
            tensors stacking the extractor's per-face / per-edge arrays) and
            ``'topology'`` (the numpy face-edge incidence matrix), or ``None``
            if any step fails (the error is printed).
        """
        try:
            # 1. Face features.
            face_features = [
                self.feature_extractor.extract_face_features(face)
                for face in brep_model.faces
            ]
            # 2. Edge features.
            edge_features = [
                self.feature_extractor.extract_edge_features(edge)
                for edge in brep_model.edges
            ]
            # 3. Assemble the output structure.
            return {
                'face_features': self._to_float_tensor(face_features),
                'edge_features': self._to_float_tensor(edge_features),
                'topology': self.extract_topology(brep_model),
            }
        except Exception as e:
            print(f"Error processing B-rep: {e}")
            return None

    def extract_topology(self, brep_model):
        """Build the face-edge incidence matrix.

        Returns:
            A float numpy array of shape ``(len(faces), len(edges))`` where
            ``[i, j]`` is 1 when ``edges[j] in faces[i].edges``, else 0.
        """
        face_edge_adj = np.zeros((len(brep_model.faces), len(brep_model.edges)))
        for i, face in enumerate(brep_model.faces):
            for j, edge in enumerate(brep_model.edges):
                if edge in face.edges:
                    face_edge_adj[i, j] = 1
        return face_edge_adj


# 3. Main encoder
class BRepEncoder:
    """Encode a B-rep model into a single flat feature tensor."""

    def __init__(self, config):
        self.processor = BRepDataProcessor(BRepFeatureExtractor(config))
        self.encoder = Encoder1D(**config.encoder_params)

    def encode(self, brep_model):
        """Encode a B-rep model; returns ``None`` on any failure (the error is printed)."""
        try:
            # 1. Process the raw data.
            processed_data = self.processor.process_brep(brep_model)
            if processed_data is None:
                return None

            # 2. Feature encoding.
            # NOTE(review): one Encoder1D (fixed in_channels) is applied to both
            # tensors, whose dim 1 differs: faces stack as (n_faces, 1, F) and
            # edges as (n_edges, num_points, 7). Confirm the intended
            # (N, C, L) layout — edges may need .transpose(1, 2).
            face_features = self.encoder(processed_data['face_features'])
            edge_features = self.encoder(processed_data['edge_features'])

            # 3. Combine the features.
            return self.combine_features(
                face_features,
                edge_features,
                processed_data['topology'],
            )
        except Exception as e:
            print(f"Error encoding B-rep: {e}")
            return None

    def combine_features(self, face_features, edge_features, topology):
        """Concatenate pooled face/edge features and the topology into one 1-D tensor.

        ``topology`` (e.g. the numpy matrix from ``extract_topology``) is
        converted to a tensor matching ``face_features``' dtype and device,
        because ``torch.cat`` rejects numpy arrays. Every part is flattened to
        1-D so parts of different rank can be concatenated.
        NOTE(review): a graph network or attention could combine these instead;
        the mean over dim 1 is not a full global pool for >2-D inputs.
        """
        topology = torch.as_tensor(topology, dtype=face_features.dtype, device=face_features.device)
        combined = torch.cat([
            face_features.mean(dim=1).flatten(),  # face features averaged over dim 1
            edge_features.mean(dim=1).flatten(),  # edge features averaged over dim 1
            topology.flatten(),                   # topology information
        ], dim=-1)
        return combined