Browse Source

encode.py可以运行

final
王琛涵 7 months ago
parent
commit
ddc4808d02
  1. 624
      networks/encoder.py

624
networks/encoder.py

@ -1,356 +1,358 @@
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
class ResConvBlock(nn.Module):
    """Residual 1-D convolution block.

    Two conv -> GroupNorm stages with SiLU activations and an additive
    skip connection; a 1x1 convolution projects the skip path whenever
    the channel count changes.
    """

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, mid_channels, 3, padding=1)
        # NOTE(review): GroupNorm(32, c) requires c % 32 == 0 — confirm
        # callers only pass multiples of 32.
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = nn.Conv1d(mid_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.act = nn.SiLU()
        self.conv_shortcut = None
        if in_channels != out_channels:
            # Project the residual so it can be added to the main path.
            self.conv_shortcut = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_channels, length)."""
        skip = x if self.conv_shortcut is None else self.conv_shortcut(x)
        h = self.act(self.norm1(self.conv1(x)))
        h = self.norm2(self.conv2(h))
        return self.act(h + skip)
class SelfAttention1d(nn.Module):
    """Multi-head self-attention over a 1-D sequence.

    Input and output have shape (batch, channels, length); ``channels``
    must be divisible by ``num_head_channels``.
    """

    def __init__(self, channels: int, num_head_channels: int):
        super().__init__()
        self.num_heads = channels // num_head_channels
        self.scale = num_head_channels ** -0.5
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        b, c, l = x.shape
        qkv = self.qkv(x).reshape(b, 3, self.num_heads, c // self.num_heads, l)
        q, k, v = qkv.unbind(1)  # each: (b, heads, head_dim, l)
        # BUGFIX: attend over the sequence (length) axis, not the
        # head-dimension axis.  The previous code built a
        # (head_dim x head_dim) attention matrix, which mixed channels
        # but never mixed positions.  Move length to the row dimension
        # so the attention matrix is (l x l).
        q = q.transpose(-2, -1)  # (b, h, l, d)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (b, h, l, l)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(-2, -1).reshape(b, c, l)
        return self.proj(out)
class UNetMidBlock1D(nn.Module):
    """U-Net bottleneck: three (ResConvBlock -> SelfAttention1d) pairs.

    Channel flow through the resnets is in -> mid -> mid -> in.
    """

    def __init__(self, in_channels: int, mid_channels: int):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, in_channels),
        ])
        # BUGFIX: the third resnet outputs ``in_channels``, so the third
        # attention must be sized for ``in_channels``.  Building all
        # three with ``mid_channels`` crashed whenever
        # in_channels != mid_channels.
        self.attentions = nn.ModuleList([
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(in_channels, in_channels // 32),
        ])

    def forward(self, x):
        """x: (batch, in_channels, length) -> same shape."""
        for attn, resnet in zip(self.attentions, self.resnets):
            x = resnet(x)
            x = attn(x)
        return x
class Encoder1D(nn.Module):
    """1-D convolutional encoder.

    Pipeline: stem conv -> per-stage residual blocks (average-pool
    downsampling between stages) -> attention bottleneck ->
    GroupNorm/SiLU/conv projection to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 256,
        block_out_channels: Tuple[int, ...] = (64, 128, 256),
        layers_per_block: int = 2,
    ):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)

        self.down_blocks = nn.ModuleList([])
        in_ch = block_out_channels[0]
        last_stage = len(block_out_channels) - 1
        for stage, out_ch in enumerate(block_out_channels):
            block = []
            for _ in range(layers_per_block):
                block.append(ResConvBlock(in_ch, out_ch, out_ch))
                in_ch = out_ch
            # BUGFIX: decide downsampling by stage *index*, not channel
            # value.  The old test ``out_ch != block_out_channels[-1]``
            # mis-fired when an earlier stage repeated the final width.
            if stage != last_stage:
                block.append(nn.AvgPool1d(2))
            self.down_blocks.append(nn.Sequential(*block))

        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        self.conv_out = nn.Sequential(
            nn.GroupNorm(32, block_out_channels[-1]),
            nn.SiLU(),
            nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1),
        )

    def forward(self, x):
        """x: (batch, in_channels, L) -> (batch, out_channels, L / 2**(stages-1))."""
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        return self.conv_out(x)
class BRepFeatureEmbedder(nn.Module):
    """Embed B-rep surface/edge/vertex features into a shared token space
    and contextualize them with a Transformer encoder.

    When ``use_cf`` is True, surface/edge geometry embeddings are added
    to the corresponding position embeddings; otherwise only the
    position embeddings are used.
    """

    def __init__(self, use_cf: bool = True):
        super().__init__()
        self.embed_dim = 768
        self.use_cf = use_cf

        layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=12,
            norm_first=False,
            dim_feedforward=1024,
            dropout=0.1,
            # BUGFIX: inputs are (batch, seq, dim).  Without batch_first
            # the encoder treats dim 0 as the sequence and attends across
            # the *batch* axis (sibling SDFTransformer already sets it).
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers=12,
            norm=nn.LayerNorm(self.embed_dim),
            enable_nested_tensor=False,  # keep the dense code path for masked inputs
        )

        # Surface latent (3*16 = 48 dims) -> token.
        self.surfz_embed = nn.Sequential(
            nn.Linear(3 * 16, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        # Edge latent (3*4 = 12 dims) -> token.
        self.edgez_embed = nn.Sequential(
            nn.Linear(3 * 4, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        # Surface / edge / vertex position features (6 dims) -> token.
        self.surfp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.edgep_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.vertp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None):
        """
        Args:
            surf_z: surface latents [B, N, 48]
            edge_z: edge latents [B, M, 12]
            surf_p: surface position features [B, N, 6]
            edge_p: edge position features [B, M, 6]
            vert_p: vertex position features [B, K, 6]
            mask: optional key-padding mask [B, N+M+K]
        Returns:
            contextualized tokens [B, N+M+K, 768]
        """
        # Geometry embeddings.
        surf_embeds = self.surfz_embed(surf_z)
        edge_embeds = self.edgez_embed(edge_z)
        # Position embeddings.
        surf_p_embeds = self.surfp_embed(surf_p)
        edge_p_embeds = self.edgep_embed(edge_p)
        vert_p_embeds = self.vertp_embed(vert_p)

        # Combine all embeddings into one token sequence.
        if self.use_cf:
            embeds = torch.cat([
                surf_embeds + surf_p_embeds,
                edge_embeds + edge_p_embeds,
                vert_p_embeds,
            ], dim=1)
        else:
            embeds = torch.cat([
                surf_p_embeds,
                edge_p_embeds,
                vert_p_embeds,
            ], dim=1)

        return self.transformer(embeds, src_key_padding_mask=mask)
class SDFTransformer(nn.Module):
    """Stack of batch-first Transformer encoder layers (post-norm)."""

    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
            norm_first=False,  # classic post-norm layer ordering
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x, mask=None):
        """x: (batch, seq, embed_dim); mask: key-padding mask (batch, seq)."""
        return self.transformer(x, src_key_padding_mask=mask)
class SDFHead(nn.Module):
    """SDF prediction head: a tapering MLP whose final Tanh bounds the
    output to (-1, 1)."""

    def __init__(self, embed_dim: int = 768 * 2):
        super().__init__()
        half, quarter = embed_dim // 2, embed_dim // 4
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, half),
            nn.LayerNorm(half),
            nn.ReLU(),
            nn.Linear(half, quarter),
            nn.LayerNorm(quarter),
            nn.ReLU(),
            nn.Linear(quarter, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """x: (..., embed_dim) -> (..., 1) signed-distance values."""
        return self.mlp(x)
class BRepToSDF(nn.Module):
    """Map B-rep features plus 3-D query points to signed-distance values."""

    def __init__(
        self,
        brep_feature_dim: int = 48,
        use_cf: bool = True,
        embed_dim: int = 768,
        latent_dim: int = 256
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # NOTE(review): brep_feature_dim and latent_dim are accepted but
        # never read — confirm whether they should size the embedders.

        # 1. Query-point encoder: R^3 -> R^embed_dim MLP.
        self.query_encoder = nn.Sequential(
            nn.Linear(3, embed_dim // 4),
            nn.LayerNorm(embed_dim // 4),
            nn.ReLU(),
            nn.Linear(embed_dim // 4, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim)
        )

        # 2. B-rep feature embedder.
        self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf)

        # 3. Feature-fusion transformer.
        # NOTE(review): constructed but never called in forward() — dead
        # parameters; confirm whether it should refine brep_features.
        self.transformer = SDFTransformer(
            embed_dim=embed_dim,
            num_layers=6
        )

        # 4. SDF prediction head over [global | query] features.
        self.sdf_head = SDFHead(embed_dim=embed_dim * 2)

    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None):
        """
        Args:
            surf_z: surface features [B, N, 48]
            edge_z: edge features [B, M, 12]
            surf_p: surface points [B, N, 6]
            edge_p: edge points [B, M, 6]
            vert_p: vertex points [B, K, 6]
            query_points: query points [B, Q, 3]
            mask: optional attention mask
        Returns:
            sdf: [B, Q, 1]
        """
        n_query = query_points.shape[1]

        # Embed the B-rep entities into one token sequence.
        brep_features = self.feature_embedder(
            surf_z, edge_z, surf_p, edge_p, vert_p, mask
        )  # [B, N+M+K, embed_dim]

        # Encode each query point.
        query_features = self.query_encoder(query_points)  # [B, Q, embed_dim]

        # Mean-pool the tokens into a single global descriptor and pair it
        # with every query point.
        global_features = brep_features.mean(dim=1)  # [B, embed_dim]
        fused = torch.cat(
            [
                global_features.unsqueeze(1).expand(-1, n_query, -1),
                query_features,
            ],
            dim=-1,
        )  # [B, Q, embed_dim*2]

        return self.sdf_head(fused)  # [B, Q, 1]
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """Combined SDF regression loss.

    L1 distance to the ground-truth SDF plus an eikonal-style penalty
    that pushes ||d(pred_sdf)/d(points)|| toward 1.

    Args:
        pred_sdf: predicted SDF values, shape (..., 1); must be a
            differentiable function of ``points``.
        gt_sdf: ground-truth SDF values, same shape as ``pred_sdf``.
        points: query points with ``requires_grad=True``.
        grad_weight: weight of the gradient (eikonal) term.

    Returns:
        scalar loss tensor.
    """
    data_term = F.l1_loss(pred_sdf, gt_sdf)

    # Gradient of the prediction w.r.t. the query points; keep the graph
    # so the penalty itself is differentiable.
    (grad,) = torch.autograd.grad(pred_sdf.sum(), points, create_graph=True)

    eikonal_term = F.mse_loss(
        grad.norm(dim=-1),
        torch.ones_like(pred_sdf.squeeze(-1)),
    )
    return data_term + grad_weight * eikonal_term
def main():
    """Smoke test: build BRepToSDF and run one forward pass on random data."""
    model = BRepToSDF(
        brep_feature_dim=48,
        use_cf=True,
        embed_dim=768,
        latent_dim=256
    )

    # Synthetic batch sizes for the example input.
    batch_size = 4
    num_surfs = 10
    num_edges = 20
    num_verts = 8
    num_queries = 1000

    # Random example data matching the documented shapes.
    surf_z = torch.randn(batch_size, num_surfs, 48)
    edge_z = torch.randn(batch_size, num_edges, 12)
    surf_p = torch.randn(batch_size, num_surfs, 6)
    edge_p = torch.randn(batch_size, num_edges, 6)
    vert_p = torch.randn(batch_size, num_verts, 6)
    query_points = torch.randn(batch_size, num_queries, 3)

    # Forward pass.
    sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
    print(f"Output SDF shape: {sdf.shape}")

if __name__ == "__main__":
    main()
Loading…
Cancel
Save