
encode.py can run

final
王琛涵 7 months ago
parent
commit
ddc4808d02
608
networks/encoder.py

@@ -1,356 +1,358 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm
'''
# NOTE:
Removed the slicing and tiling functionality.
Obtains the latent vector with mode() directly instead of sample().
Simplified the encoding process to keep only the core functionality.
Returns a deterministic latent vector instead of a distribution.
'''
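# Illustrative sketch of the "deterministic latent" behaviour described in the
# note above, assuming an encoder that outputs mean/logvar stacked along dim=1
# (the names below are placeholders, not part of this module):
#
#   moments = torch.randn(2, 2 * 256, 16)             # [B, 2*latent_dim, L]
#   posterior = DiagonalGaussianDistribution(moments)
#   z = posterior.mode()                               # deterministic latent (the mean)
#   # z = posterior.sample()                           # stochastic alternative, not used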
# Basic network components
class ResConvBlock(nn.Module):
    """Residual convolution block."""
    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, mid_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = nn.Conv1d(mid_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.act = nn.SiLU()

        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        return self.act(x + residual)
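# Shape check for ResConvBlock (illustrative; the channel counts below are
# arbitrary, but GroupNorm(32, ...) requires mid/out channels divisible by 32).
def _example_res_conv_block():
    block = ResConvBlock(in_channels=64, mid_channels=128, out_channels=256)
    x = torch.randn(2, 64, 100)    # [B, C_in, L]
    y = block(x)                   # [B, C_out, L] -> torch.Size([2, 256, 100])
    return y.shape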
class SelfAttention1d(nn.Module):
    """1D self-attention layer."""
    def __init__(self, channels: int, num_head_channels: int):
        super().__init__()
        self.num_heads = channels // num_head_channels
        self.scale = num_head_channels ** -0.5
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        b, c, l = x.shape
        qkv = self.qkv(x).reshape(b, 3, self.num_heads, c // self.num_heads, l)
        q, k, v = qkv.unbind(1)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).reshape(b, c, l)
        x = self.proj(x)
        return x
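# Shape check for SelfAttention1d (illustrative). The layer is shape-preserving:
# with channels=256 and num_head_channels=32 it uses 256 // 32 = 8 heads and
# returns a tensor of the same [B, C, L] shape.
def _example_self_attention_1d():
    attn = SelfAttention1d(channels=256, num_head_channels=32)
    x = torch.randn(2, 256, 64)
    return attn(x).shape           # torch.Size([2, 256, 64])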
class UNetMidBlock1D(nn.Module):
    """U-Net mid block."""
    def __init__(self, in_channels: int, mid_channels: int):
        super().__init__()
        self.resnets = nn.ModuleList([
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, in_channels),
        ])
        self.attentions = nn.ModuleList([
            SelfAttention1d(mid_channels, mid_channels // 32)
            for _ in range(3)
        ])

    def forward(self, x):
        for attn, resnet in zip(self.attentions, self.resnets):
            x = resnet(x)
            x = attn(x)
        return x
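# Shape check for UNetMidBlock1D (illustrative). With in_channels == mid_channels,
# the three resnet/attention pairs keep the [B, C, L] shape unchanged.
def _example_unet_mid_block_1d():
    mid = UNetMidBlock1D(in_channels=256, mid_channels=256)
    x = torch.randn(2, 256, 16)
    return mid(x).shape            # torch.Size([2, 256, 16])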
class Encoder1D(nn.Module):
    """1D encoder."""
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 256,
        block_out_channels: Tuple[int] = (64, 128, 256),
        layers_per_block: int = 2,
    ):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)

        # down
        self.down_blocks = nn.ModuleList([])
        in_ch = block_out_channels[0]
        for out_ch in block_out_channels:
            block = []
            for _ in range(layers_per_block):
                block.append(ResConvBlock(in_ch, out_ch, out_ch))
                in_ch = out_ch
            if out_ch != block_out_channels[-1]:
                block.append(nn.AvgPool1d(2))
            self.down_blocks.append(nn.Sequential(*block))

        # mid
        self.mid_block = UNetMidBlock1D(
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
        )

        # out
        self.conv_out = nn.Sequential(
            nn.GroupNorm(32, block_out_channels[-1]),
            nn.SiLU(),
            nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1),
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_out(x)
        return x
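# Shape check for Encoder1D (illustrative). With the default
# block_out_channels=(64, 128, 256), the two non-final stages each apply
# AvgPool1d(2), so the output length is the input length divided by 4.
def _example_encoder_1d():
    enc = Encoder1D(in_channels=3, out_channels=256)
    x = torch.randn(2, 3, 64)      # [B, C_in, L]
    return enc(x).shape            # torch.Size([2, 256, 16])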
class BRepFeatureEmbedder(nn.Module):
    """B-rep feature embedder."""
    def __init__(self, use_cf: bool = True):
        super().__init__()
        self.embed_dim = 768
        self.use_cf = use_cf

        layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=12,
            norm_first=False,
            dim_feedforward=1024,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(
            layer,
            num_layers=12,
            norm=nn.LayerNorm(self.embed_dim),
            enable_nested_tensor=False  # added this parameter
        )

        self.surfz_embed = nn.Sequential(
            nn.Linear(3 * 16, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.edgez_embed = nn.Sequential(
            nn.Linear(3 * 4, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.surfp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.edgep_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )
        self.vertp_embed = nn.Sequential(
            nn.Linear(6, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None):
        # latent feature embeddings
        surf_embeds = self.surfz_embed(surf_z)
        edge_embeds = self.edgez_embed(edge_z)
        # point embeddings
        surf_p_embeds = self.surfp_embed(surf_p)
        edge_p_embeds = self.edgep_embed(edge_p)
        vert_p_embeds = self.vertp_embed(vert_p)
        # combine all embeddings
        if self.use_cf:
            embeds = torch.cat([
                surf_embeds + surf_p_embeds,
                edge_embeds + edge_p_embeds,
                vert_p_embeds
            ], dim=1)
        else:
            embeds = torch.cat([
                surf_p_embeds,
                edge_p_embeds,
                vert_p_embeds
            ], dim=1)
        output = self.transformer(embeds, src_key_padding_mask=mask)
        return output
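# Shape check for BRepFeatureEmbedder (illustrative). The encoder layer above is
# built without batch_first=True (PyTorch defaults to the [seq, batch, dim]
# convention), but a TransformerEncoder always returns a tensor of the same shape
# as its input, so the result below is [B, N + M + K, 768].
def _example_brep_feature_embedder():
    emb = BRepFeatureEmbedder(use_cf=True)
    B, N, M, K = 2, 10, 20, 8
    out = emb(
        torch.randn(B, N, 48), torch.randn(B, M, 12),
        torch.randn(B, N, 6), torch.randn(B, M, 6), torch.randn(B, K, 6),
    )
    return out.shape               # torch.Size([2, 38, 768])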
class SDFTransformer(nn.Module):
    """SDF Transformer encoder."""
    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
            norm_first=False  # set to False
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x, mask=None):
        return self.transformer(x, src_key_padding_mask=mask)
class SDFHead(nn.Module):
    """SDF prediction head."""
    def __init__(self, embed_dim: int = 768 * 2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.LayerNorm(embed_dim // 4),
            nn.ReLU(),
            nn.Linear(embed_dim // 4, 1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.mlp(x)
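# Shape check for SDFHead (illustrative): maps per-query features of width
# 2 * 768 to a single Tanh-bounded SDF value in (-1, 1).
def _example_sdf_head():
    head = SDFHead(embed_dim=768 * 2)
    feats = torch.randn(2, 1000, 768 * 2)
    return head(feats).shape       # torch.Size([2, 1000, 1])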
class BRepToSDF(nn.Module):
    def __init__(
        self,
        brep_feature_dim: int = 48,
        use_cf: bool = True,
        embed_dim: int = 768,
        latent_dim: int = 256
    ):
        super().__init__()
        self.embed_dim = embed_dim

        # 1. query point encoder
        self.query_encoder = nn.Sequential(
            nn.Linear(3, embed_dim // 4),
            nn.LayerNorm(embed_dim // 4),
            nn.ReLU(),
            nn.Linear(embed_dim // 4, embed_dim // 2),
            nn.LayerNorm(embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, embed_dim)
        )
        # 2. B-rep feature embedder
        self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf)
        # 3. feature-fusion Transformer (not used in the forward pass below)
        self.transformer = SDFTransformer(
            embed_dim=embed_dim,
            num_layers=6
        )
        # 4. SDF prediction head
        self.sdf_head = SDFHead(embed_dim=embed_dim * 2)

    def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None):
        """
        Args:
            surf_z: surface features [B, N, 48]
            edge_z: edge features [B, M, 12]
            surf_p: surface points [B, N, 6]
            edge_p: edge points [B, M, 6]
            vert_p: vertex points [B, K, 6]
            query_points: query points [B, Q, 3]
            mask: attention mask
        Returns:
            sdf: [B, Q, 1]
        """
        B, Q, _ = query_points.shape
        # 1. B-rep feature embedding
        brep_features = self.feature_embedder(
            surf_z, edge_z, surf_p, edge_p, vert_p, mask
        )  # [B, N+M+K, embed_dim]
        # 2. query point encoding
        query_features = self.query_encoder(query_points)  # [B, Q, embed_dim]
        # 3. extract global features
        global_features = brep_features.mean(dim=1)  # [B, embed_dim]
        # 4. prepare features for each query point
        expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1)  # [B, Q, embed_dim]
        # 5. concatenate query-point and global features
        combined_features = torch.cat([
            expanded_features,  # [B, Q, embed_dim]
            query_features      # [B, Q, embed_dim]
        ], dim=-1)  # [B, Q, embed_dim*2]
        # 6. SDF prediction
        sdf = self.sdf_head(combined_features)  # [B, Q, 1]
        return sdf
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1):
    """SDF loss function (`points` must have requires_grad=True)."""
    # L1 loss
    l1_loss = F.l1_loss(pred_sdf, gt_sdf)
    # gradient (eikonal) constraint loss
    grad = torch.autograd.grad(
        pred_sdf.sum(),
        points,
        create_graph=True
    )[0]
    grad_constraint = F.mse_loss(
        torch.norm(grad, dim=-1),
        torch.ones_like(pred_sdf.squeeze(-1))
    )
    return l1_loss + grad_weight * grad_constraint
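# Usage sketch for sdf_loss (illustrative; `model`, `brep_inputs`, and `gt_sdf`
# are placeholders supplied by the caller). The gradient penalty differentiates
# pred_sdf with respect to the query points, so they must require gradients.
def _example_sdf_loss(model, brep_inputs, gt_sdf):
    points = torch.randn(2, 1000, 3, requires_grad=True)
    pred_sdf = model(*brep_inputs, points)   # [B, Q, 1]
    return sdf_loss(pred_sdf, gt_sdf, points, grad_weight=0.1)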
def main():
    # initialize the model
    model = BRepToSDF(
        brep_feature_dim=48,
        use_cf=True,
        embed_dim=768,
        latent_dim=256
    )

    # example inputs
    batch_size = 4
    num_surfs = 10
    num_edges = 20
    num_verts = 8
    num_queries = 1000

    # generate example data
    surf_z = torch.randn(batch_size, num_surfs, 48)
    edge_z = torch.randn(batch_size, num_edges, 12)
    surf_p = torch.randn(batch_size, num_surfs, 6)
    edge_p = torch.randn(batch_size, num_edges, 6)
    vert_p = torch.randn(batch_size, num_verts, 6)
    query_points = torch.randn(batch_size, num_queries, 3)

    # forward pass
    sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points)
    print(f"Output SDF shape: {sdf.shape}")


if __name__ == "__main__":
    main()