From ddc4808d02b3ca40af60afa5347073b099a845b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E7=90=9B=E6=B6=B5?=
Date: Tue, 12 Nov 2024 23:19:08 +0800
Subject: [PATCH] =?UTF-8?q?encode.py=E5=8F=AF=E4=BB=A5=E8=BF=90=E8=A1=8C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 networks/encoder.py | 624 ++++++++++++++++++++++----------------------
 1 file changed, 313 insertions(+), 311 deletions(-)

diff --git a/networks/encoder.py b/networks/encoder.py
index 6d85c8f..5a956ce 100644
--- a/networks/encoder.py
+++ b/networks/encoder.py
@@ -1,356 +1,358 @@
 import math
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import BaseOutput, is_torch_version
-from diffusers.utils.accelerate_utils import apply_forward_hook
-from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
-from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
-from diffusers.models.attention_processor import SpatialNorm
-
-'''
-# NOTE:
-    Removed the slicing and tiling functionality
-    Use mode() instead of sample() to obtain the latent vector
-    Simplified the encoding process, keeping only the core functionality
-    Return a deterministic latent vector instead of a distribution
-'''
-
-# 1. Basic network components
-class Embedder(nn.Module):
-    def __init__(self, vocab_size, d_model):
+class ResConvBlock(nn.Module):
+    """Residual convolutional block."""
+    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
         super().__init__()
-        self.embed = nn.Embedding(vocab_size, d_model)
-        self._init_embeddings()
-
-    def _init_embeddings(self):
-        nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")
+        self.conv1 = nn.Conv1d(in_channels, mid_channels, 3, padding=1)
+        self.norm1 = nn.GroupNorm(32, mid_channels)
+        self.conv2 = nn.Conv1d(mid_channels, out_channels, 3, padding=1)
+        self.norm2 = nn.GroupNorm(32, out_channels)
+        self.act = nn.SiLU()
+
+        self.conv_shortcut = None
+        if in_channels != out_channels:
+            self.conv_shortcut = nn.Conv1d(in_channels, out_channels, 1)
 
     def forward(self, x):
-        return self.embed(x)
-
+        residual = x
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.act(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+
+        if self.conv_shortcut is not None:
+            residual = self.conv_shortcut(residual)
+
+        return self.act(x + residual)
 
-class UpBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, mid_channels=None):
+class SelfAttention1d(nn.Module):
+    """1D self-attention layer."""
+    def __init__(self, channels: int, num_head_channels: int):
         super().__init__()
-        mid_channels = in_channels if mid_channels is None else mid_channels
-
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
-
-        self.resnets = nn.ModuleList(resnets)
-        self.up = Upsample1d(kernel="cubic")
-
-    def forward(self, hidden_states, temb=None):
-        for resnet in self.resnets:
-            hidden_states = resnet(hidden_states)
-        hidden_states = self.up(hidden_states)
-        return hidden_states
+        self.num_heads = channels // num_head_channels
+        self.scale = num_head_channels ** -0.5
+
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
+        self.proj = nn.Conv1d(channels, channels, 1)
 
+    def forward(self, x):
+        b, c, l = x.shape
+        qkv = self.qkv(x).reshape(b, 3, self.num_heads, c // self.num_heads, l)
+        q, k, v = qkv.unbind(1)       # each: [b, heads, head_dim, l]
+        q = q.transpose(-2, -1)       # [b, heads, l, head_dim] so attention runs over positions
+        k = k.transpose(-2, -1)
+        v = v.transpose(-2, -1)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale  # [b, heads, l, l]
+        attn = attn.softmax(dim=-1)
+
+        x = (attn @ v).transpose(-2, -1).reshape(b, c, l)
+        x = self.proj(x)
+        return x
 
 class UNetMidBlock1D(nn.Module):
-    def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
+    """U-Net mid block."""
+    def __init__(self, in_channels: int, mid_channels: int):
         super().__init__()
-
-        out_channels = in_channels if out_channels is None else out_channels
-
-        # there is always at least one resnet
-        resnets = [
+        self.resnets = nn.ModuleList([
             ResConvBlock(in_channels, mid_channels, mid_channels),
             ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
-        attentions = [
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(out_channels, out_channels // 32),
-        ]
-
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = nn.ModuleList(resnets)
+            ResConvBlock(mid_channels, mid_channels, in_channels),
+        ])
+        self.attentions = nn.ModuleList([
+            SelfAttention1d(mid_channels, mid_channels // 32)
+            for _ in range(3)
+        ])
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, x):
         for attn, resnet in zip(self.attentions, self.resnets):
-            hidden_states = resnet(hidden_states)
-            hidden_states = attn(hidden_states)
-
-        return hidden_states
-
+            x = resnet(x)
+            x = attn(x)
+        return x
 
 class Encoder1D(nn.Module):
+    """1D encoder."""
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        down_block_types=("DownEncoderBlock1D",),
-        block_out_channels=(64,),
-        layers_per_block=2,
-        norm_num_groups=32,
-        act_fn="silu",
-        double_z=True,
+        in_channels: int = 3,
+        out_channels: int = 256,
+        block_out_channels: Tuple[int, ...] = (64, 128, 256),
+        layers_per_block: int = 2,
     ):
         super().__init__()
-        self.layers_per_block = layers_per_block
-
-        self.conv_in = torch.nn.Conv1d(
-            in_channels,
-            block_out_channels[0],
-            kernel_size=3,
-            stride=1,
-            padding=1,
-        )
-
-        self.mid_block = None
+        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)
+
         self.down_blocks = nn.ModuleList([])
-
-        # down
-        output_channel = block_out_channels[0]
-        for i, down_block_type in enumerate(down_block_types):
-            input_channel = output_channel
-            output_channel = block_out_channels[i]
-            is_final_block = i == len(block_out_channels) - 1
-
-            down_block = get_down_block(
-                down_block_type,
-                num_layers=self.layers_per_block,
-                in_channels=input_channel,
-                out_channels=output_channel,
-                add_downsample=not is_final_block,
-                temb_channels=None,
-            )
-            self.down_blocks.append(down_block)
-
-        # mid
+        in_ch = block_out_channels[0]
+        for out_ch in block_out_channels:
+            block = []
+            for _ in range(layers_per_block):
+                block.append(ResConvBlock(in_ch, out_ch, out_ch))
+                in_ch = out_ch
+            if out_ch != block_out_channels[-1]:
+                block.append(nn.AvgPool1d(2))
+            self.down_blocks.append(nn.Sequential(*block))
+
         self.mid_block = UNetMidBlock1D(
             in_channels=block_out_channels[-1],
             mid_channels=block_out_channels[-1],
         )
 
-        # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
-        self.conv_act = nn.SiLU()
-
-        conv_out_channels = 2 * out_channels if double_z else out_channels
-        self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1)
-
-        self.gradient_checkpointing = False
+        self.conv_out = nn.Sequential(
+            nn.GroupNorm(32, block_out_channels[-1]),
+            nn.SiLU(),
+            nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1),
+        )
 
     def forward(self, x):
-        sample = x
-        sample = self.conv_in(sample)
+        x = self.conv_in(x)
+        for block in self.down_blocks:
+            x = block(x)
+        x = self.mid_block(x)
+        x = self.conv_out(x)
+        return x
+
+class BRepFeatureEmbedder(nn.Module):
+    """B-rep feature embedder."""
+    def __init__(self, use_cf: bool = True):
+        super().__init__()
+        self.embed_dim = 768
+        self.use_cf = use_cf
 
-        if self.training and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
-            # down
-            if is_torch_version(">=", "1.11.0"):
-                for down_block in self.down_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(down_block), sample, use_reentrant=False
-                    )
-
-                # middle
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, use_reentrant=False
-                )
-            else:
-                for down_block in self.down_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
-                # middle
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+        layer = nn.TransformerEncoderLayer(
+            d_model=self.embed_dim,
+            nhead=12,
+            norm_first=False,
+            dim_feedforward=1024,
+            dropout=0.1,
+            batch_first=True  # inputs below are [batch, seq, dim]
+        )
+        self.transformer = nn.TransformerEncoder(
+            layer,
+            num_layers=12,
+            norm=nn.LayerNorm(self.embed_dim),
+            enable_nested_tensor=False  # disable the nested-tensor fast path
+        )
+
+        self.surfz_embed = nn.Sequential(
+            nn.Linear(3*16, self.embed_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim),
+        )
+
+        self.edgez_embed = nn.Sequential(
+            nn.Linear(3*4, self.embed_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim),
+        )
+
+        self.surfp_embed = nn.Sequential(
+            nn.Linear(6, self.embed_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim),
+        )
+
+        self.edgep_embed = nn.Sequential(
+            nn.Linear(6, self.embed_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim),
+        )
+
+        self.vertp_embed = nn.Sequential(
+            nn.Linear(6, self.embed_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim),
+        )
 
+    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None):
+        # Feature embeddings
+        surf_embeds = self.surfz_embed(surf_z)
+        edge_embeds = self.edgez_embed(edge_z)
+
+        # Position embeddings
+        surf_p_embeds = self.surfp_embed(surf_p)
+        edge_p_embeds = self.edgep_embed(edge_p)
+        vert_p_embeds = self.vertp_embed(vert_p)
+
+        # Combine all embeddings
+        if self.use_cf:
+            embeds = torch.cat([
+                surf_embeds + surf_p_embeds,
+                edge_embeds + edge_p_embeds,
+                vert_p_embeds
+            ], dim=1)
         else:
-            # down
-            for down_block in self.down_blocks:
-                sample = down_block(sample)[0]
+            embeds = torch.cat([
+                surf_p_embeds,
+                edge_p_embeds,
+                vert_p_embeds
+            ], dim=1)
 
-            # middle
-            sample = self.mid_block(sample)
-
-        # post-process
-        sample = self.conv_norm_out(sample)
-        sample = self.conv_act(sample)
-        sample = self.conv_out(sample)
-        return sample
+        output = self.transformer(embeds,
+                                  src_key_padding_mask=mask)
+        return output
 
-# 2. B-rep feature processing
-class BRepFeatureExtractor:
-    def __init__(self, config):
-        self.encoder = Encoder1D(
-            in_channels=config.in_channels,  # set according to the feature dimensionality
-            out_channels=config.out_channels,
-            block_out_channels=config.block_out_channels,
-            layers_per_block=config.layers_per_block
+class SDFTransformer(nn.Module):
+    """SDF Transformer encoder."""
+    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
+        super().__init__()
+        layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=8,
+            dim_feedforward=1024,
+            dropout=0.1,
+            batch_first=True,
+            norm_first=False  # changed here: set to False (post-norm)
         )
-
-    def extract_face_features(self, face):
-        """Extract face features."""
-        features = []
-        try:
-            # Basic geometric features
-            center = face.center_point
-            normal = face.normal_vector
-
-            # Boundary features
-            bounds = face.bounds()
-
-            # Surface type feature
-            surface_type = face.surface_type
-
-            # Combined feature
-            feature = np.concatenate([
-                center,            # [3]
-                normal,            # [3]
-                bounds.flatten(),
-                [surface_type]     # could be one-hot encoded
-            ])
-            features.append(feature)
-
-        except Exception as e:
-            print(f"Error extracting face features: {e}")
-
-        return np.array(features)
+        self.transformer = nn.TransformerEncoder(layer, num_layers)
 
-    def extract_edge_features(self, edge):
-        """Extract edge features."""
-        features = []
-        try:
-            # Sample points
-            points = self.sample_points_on_edge(edge)
-
-            for point in points:
-                # Position
-                pos = point.coordinates
-                # Tangent vector
-                tangent = point.tangent
-                # Curvature
-                curvature = point.curvature
-
-                point_feature = np.concatenate([
-                    pos,          # [3]
-                    tangent,      # [3]
-                    [curvature]   # [1]
-                ])
-                features.append(point_feature)
-
-        except Exception as e:
-            print(f"Error extracting edge features: {e}")
-
-        return np.array(features)
+    def forward(self, x, mask=None):
+        return self.transformer(x, src_key_padding_mask=mask)
 
-    @staticmethod
-    def sample_points_on_edge(edge, num_points=32):
-        """Uniformly sample points along an edge."""
-        points = []
-        try:
-            length = edge.length()
-            for i in range(num_points):
-                t = i / (num_points - 1)
-                point = edge.point_at(t * length)
-                points.append(point)
-        except Exception as e:
-            print(f"Error sampling points: {e}")
-        return points
+class SDFHead(nn.Module):
+    """SDF prediction head."""
+    def __init__(self, embed_dim: int = 768*2):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim//2),
+            nn.LayerNorm(embed_dim//2),
+            nn.ReLU(),
+            nn.Linear(embed_dim//2, embed_dim//4),
+            nn.LayerNorm(embed_dim//4),
+            nn.ReLU(),
+            nn.Linear(embed_dim//4, 1),
+            nn.Tanh()
+        )
 
-class BRepDataProcessor:
-    def __init__(self, feature_extractor):
-        self.feature_extractor = feature_extractor
-
-    def process_brep(self, brep_model):
-        """Process a single B-rep model."""
-        try:
-            # 1. Extract face features
-            face_features = []
-            for face in brep_model.faces:
-                feat = self.feature_extractor.extract_face_features(face)
-                face_features.append(feat)
-
-            # 2. Extract edge features
-            edge_features = []
-            for edge in brep_model.edges:
-                feat = self.feature_extractor.extract_edge_features(edge)
-                edge_features.append(feat)
-
-            # 3. Organize the data structure
-            return {
-                'face_features': torch.tensor(face_features),
-                'edge_features': torch.tensor(edge_features),
-                'topology': self.extract_topology(brep_model)
-            }
-
-        except Exception as e:
-            print(f"Error processing B-rep: {e}")
-            return None
-
-    def extract_topology(self, brep_model):
-        """Extract topological relations."""
-        # Face-edge adjacency matrix
-        face_edge_adj = np.zeros((len(brep_model.faces), len(brep_model.edges)))
-        # Fill in the adjacency relations
-        for i, face in enumerate(brep_model.faces):
-            for j, edge in enumerate(brep_model.edges):
-                if edge in face.edges:
-                    face_edge_adj[i,j] = 1
-        return face_edge_adj
+    def forward(self, x):
+        return self.mlp(x)
 
-# 3. Main encoder
-class BRepEncoder:
-    def __init__(self, config):
-        self.processor = BRepDataProcessor(
-            BRepFeatureExtractor(config)
+class BRepToSDF(nn.Module):
+    def __init__(
+        self,
+        brep_feature_dim: int = 48,
+        use_cf: bool = True,
+        embed_dim: int = 768,
+        latent_dim: int = 256
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+
+        # 1. Query point encoder
+        self.query_encoder = nn.Sequential(
+            nn.Linear(3, embed_dim//4),
+            nn.LayerNorm(embed_dim//4),
+            nn.ReLU(),
+            nn.Linear(embed_dim//4, embed_dim//2),
+            nn.LayerNorm(embed_dim//2),
+            nn.ReLU(),
+            nn.Linear(embed_dim//2, embed_dim)
         )
-        self.encoder = Encoder1D(**config.encoder_params)
 
-    def encode(self, brep_model):
-        """Encode a B-rep model."""
-        try:
-            # 1. Process the raw data
-            processed_data = self.processor.process_brep(brep_model)
-            if processed_data is None:
-                return None
-
-            # 2. Feature encoding
-            face_features = self.encoder(processed_data['face_features'])
-            edge_features = self.encoder(processed_data['edge_features'])
-
-            # 3. Combine features
-            combined_features = self.combine_features(
-                face_features,
-                edge_features,
-                processed_data['topology']
-            )
-
-            return combined_features
-
-        except Exception as e:
-            print(f"Error encoding B-rep: {e}")
-            return None
-
-    def combine_features(self, face_features, edge_features, topology):
-        """Combine the different feature types."""
-        # A graph network or attention mechanism could be used to combine the features
-        combined = torch.cat([
-            face_features.mean(dim=1),  # global face features
-            edge_features.mean(dim=1),  # global edge features
-            topology.flatten()          # topology information
-        ], dim=-1)
-        return combined
\ No newline at end of file
+        # 2. B-rep feature embedder
+        self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf)
+
+        # 3. Feature fusion transformer (not used in forward yet)
+        self.transformer = SDFTransformer(
+            embed_dim=embed_dim,
+            num_layers=6
+        )
+
+        # 4. SDF prediction head
+        self.sdf_head = SDFHead(embed_dim=embed_dim*2)
+
+    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None):
+        """
+        Args:
+            surf_z: surface features [B, N, 48]
+            edge_z: edge features [B, M, 12]
+            surf_p: surface points [B, N, 6]
+            edge_p: edge points [B, M, 6]
+            vert_p: vertex points [B, K, 6]
+            query_points: query points [B, Q, 3]
+            mask: attention mask
+        Returns:
+            sdf: [B, Q, 1]
+        """
+        B, Q, _ = query_points.shape
+
+        # 1. B-rep feature embedding
+        brep_features = self.feature_embedder(
+            surf_z, edge_z, surf_p, edge_p, vert_p, mask
+        )  # [B, N+M+K, embed_dim]
+
+        # 2. Query point encoding
+        query_features = self.query_encoder(query_points)  # [B, Q, embed_dim]
+
+        # 3. Extract a global feature
+        global_features = brep_features.mean(dim=1)  # [B, embed_dim]
+
+        # 4. Broadcast the global feature to every query point
+        expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1)  # [B, Q, embed_dim]
+
+        # 5. Concatenate query-point features and the global feature
+        combined_features = torch.cat([
+            expanded_features,  # [B, Q, embed_dim]
+            query_features      # [B, Q, embed_dim]
+        ], dim=-1)  # [B, Q, embed_dim*2]
+
+        # 6. SDF prediction
+        sdf = self.sdf_head(combined_features)  # [B, Q, 1]
+
+        return sdf
+
+def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
+    """SDF loss function."""
+    # L1 loss
+    l1_loss = F.l1_loss(pred_sdf, gt_sdf)
+
+    # Gradient (eikonal) constraint loss; `points` must have requires_grad=True
+    grad = torch.autograd.grad(
+        pred_sdf.sum(),
+        points,
+        create_graph=True
+    )[0]
+    grad_constraint = F.mse_loss(
+        torch.norm(grad, dim=-1),
+        torch.ones_like(pred_sdf.squeeze(-1))
+    )
+
+    return l1_loss + grad_weight * grad_constraint
+
+def main():
+    # Initialize the model
+    model = BRepToSDF(
+        brep_feature_dim=48,
+        use_cf=True,
+        embed_dim=768,
+        latent_dim=256
+    )
+
+    # Example input sizes
+    batch_size = 4
+    num_surfs = 10
+    num_edges = 20
+    num_verts = 8
+    num_queries = 1000
+
+    # Generate example data
+    surf_z = torch.randn(batch_size, num_surfs, 48)
+    edge_z = torch.randn(batch_size, num_edges, 12)
+    surf_p = torch.randn(batch_size, num_surfs, 6)
+    edge_p = torch.randn(batch_size, num_edges, 6)
+    vert_p = torch.randn(batch_size, num_verts, 6)
+    query_points = torch.randn(batch_size, num_queries, 3)
+
+    # Forward pass
+    sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
+    print(f"Output SDF shape: {sdf.shape}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
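
Note: the gradient term in sdf_loss differentiates pred_sdf with respect to the query points, which main() above does not exercise (its query_points are created without requires_grad). A minimal training-step sketch follows, assuming the patched file is importable as networks.encoder and using a placeholder gt_sdf tensor and an AdamW optimizer (these are assumptions, not part of the patch):

import torch

from networks.encoder import BRepToSDF, sdf_loss  # import path assumed from this patch

model = BRepToSDF(brep_feature_dim=48, use_cf=True, embed_dim=768, latent_dim=256)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # optimizer/lr are assumptions

# Dummy batch shaped like main() above; gt_sdf stands in for real ground-truth SDF values.
surf_z = torch.randn(2, 10, 48)
edge_z = torch.randn(2, 20, 12)
surf_p = torch.randn(2, 10, 6)
edge_p = torch.randn(2, 20, 6)
vert_p = torch.randn(2, 8, 6)
query_points = torch.randn(2, 1000, 3, requires_grad=True)  # grad needed by sdf_loss
gt_sdf = torch.randn(2, 1000, 1).clamp(-1.0, 1.0)           # placeholder labels in [-1, 1]

pred_sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
loss = sdf_loss(pred_sdf, gt_sdf, query_points, grad_weight=0.1)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")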