Browse Source

feat: 模型架构

final
mckay 7 months ago
parent
commit
c47dd74796
  1. 0
      networks/__init__.py
  2. 383
      networks/decoder.py
  3. 109
      networks/deep_sdf_decoder.py
  4. 356
      networks/encoder.py
  5. 127
      networks/network.py

0
networks/__init__.py

383
networks/decoder.py

@ -0,0 +1,383 @@
import math
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm
class Decoder1D(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group", # group, spatial
):
'''
这是第一阶段的解码器用于处理B-rep特征
包含三个主要部分
conv_in: 输入卷积层处理初始特征
mid_block: 中间处理块
up_blocks: 上采样块序列
支持梯度检查点功能gradient checkpointing以节省内存
输出维度: [B, C, L]
# NOTE:
1. 移除了分片(slicing)和平铺(tiling)功能
2. 直接使用mode()而不是sample()获取潜在向量
3. 简化了编码过程只保留核心功能
4. 返回确定性的潜在向量而不是分布
'''
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv1d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock1D(
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpBlock1D(
in_channels=prev_output_channel,
out_channels=output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, z, latent_embeds=None):
sample = z
sample = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
)
# sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
# sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
# sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class SDFDecoder(nn.Module):
def __init__(
self,
latent_size,
dims,
dropout=None,
dropout_prob=0.0,
norm_layers=(),
latent_in=(),
weight_norm=False,
xyz_in_all=None,
use_tanh=False,
latent_dropout=False,
):
'''
这是第二阶段的解码器用于生成SDF值
使用多层MLP结构
特点
支持在不同层注入latent信息通过latent_in参数
可以在每层添加xyz坐标通过xyz_in_all参数
支持权重归一化和dropout
输入维度: [N, latent_size+3]
输出维度: [N, 1]
'''
super(SDFDecoder, self).__init__()
def make_sequence():
return []
dims = [latent_size + 3] + dims + [1]
self.num_layers = len(dims)
self.norm_layers = norm_layers
self.latent_in = latent_in
self.latent_dropout = latent_dropout
if self.latent_dropout:
self.lat_dp = nn.Dropout(0.2)
self.xyz_in_all = xyz_in_all
self.weight_norm = weight_norm
for layer in range(0, self.num_layers - 1):
if layer + 1 in latent_in:
out_dim = dims[layer + 1] - dims[0]
else:
out_dim = dims[layer + 1]
if self.xyz_in_all and layer != self.num_layers - 2:
out_dim -= 3
if weight_norm and layer in self.norm_layers:
setattr(
self,
"lin" + str(layer),
nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
)
else:
setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))
if (
(not weight_norm)
and self.norm_layers is not None
and layer in self.norm_layers
):
setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))
self.use_tanh = use_tanh
if use_tanh:
self.tanh = nn.Tanh()
self.relu = nn.ReLU()
self.dropout_prob = dropout_prob
self.dropout = dropout
self.th = nn.Tanh()
# input: N x (L+3)
def forward(self, input):
xyz = input[:, -3:]
if input.shape[1] > 3 and self.latent_dropout:
latent_vecs = input[:, :-3]
latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
x = torch.cat([latent_vecs, xyz], 1)
else:
x = input
for layer in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(layer))
if layer in self.latent_in:
x = torch.cat([x, input], 1)
elif layer != 0 and self.xyz_in_all:
x = torch.cat([x, xyz], 1)
x = lin(x)
# last layer Tanh
if layer == self.num_layers - 2 and self.use_tanh:
x = self.tanh(x)
if layer < self.num_layers - 2:
if (
self.norm_layers is not None
and layer in self.norm_layers
and not self.weight_norm
):
bn = getattr(self, "bn" + str(layer))
x = bn(x)
x = self.relu(x)
if self.dropout is not None and layer in self.dropout:
x = F.dropout(x, p=self.dropout_prob, training=self.training)
if hasattr(self, "th"):
x = self.th(x)
return x
class BRep2SdfDecoder(nn.Module):
def __init__(
self,
latent_size=256,
feature_dims=[512, 512, 256, 128], # 特征解码器维度
sdf_dims=[512, 512, 512, 512], # SDF解码器维度
up_block_types=("UpDecoderBlock2D",),
layers_per_block=2,
norm_num_groups=32,
norm_type="group",
dropout=None,
dropout_prob=0.0,
norm_layers=(),
latent_in=(),
weight_norm=False,
xyz_in_all=True,
use_tanh=True,
):
super().__init__()
# 1. 特征解码器 (使用Decoder1D结构)
self.feature_decoder = Decoder1D(
in_channels=latent_size,
out_channels=feature_dims[-1],
up_block_types=up_block_types,
block_out_channels=feature_dims,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
norm_type=norm_type
)
# 2. SDF解码器 (使用原始Decoder结构)
self.sdf_decoder = SDFDecoder(
latent_size=feature_dims[-1], # 使用特征解码器的输出维度
dims=sdf_dims,
dropout=dropout,
dropout_prob=dropout_prob,
norm_layers=norm_layers,
latent_in=latent_in,
weight_norm=weight_norm,
xyz_in_all=xyz_in_all,
use_tanh=use_tanh,
latent_dropout=False
)
# 3. 特征转换层 (将特征解码器的输出转换为SDF解码器需要的格式)
self.feature_transform = nn.Sequential(
nn.Linear(feature_dims[-1], feature_dims[-1]),
nn.LayerNorm(feature_dims[-1]),
nn.SiLU()
)
def forward(self, latent, query_points, latent_embeds=None):
"""
Args:
latent: [B, C, L] B-rep特征
query_points: [B, N, 3] 查询点
latent_embeds: 可选的条件嵌入
Returns:
sdf: [B, N, 1] SDF值
"""
# 1. 特征解码
features = self.feature_decoder(latent, latent_embeds) # [B, C, L]
# 2. 特征转换
B, C, L = features.shape
features = features.permute(0, 2, 1) # [B, L, C]
features = self.feature_transform(features) # [B, L, C]
# 3. 准备SDF解码器输入
_, N, _ = query_points.shape
features = features.unsqueeze(1).expand(-1, N, -1, -1) # [B, N, L, C]
query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1) # [B, N, L, 3]
# 4. 合并特征和坐标
sdf_input = torch.cat([
features.reshape(B*N*L, -1), # [B*N*L, C]
query_points.reshape(B*N*L, -1) # [B*N*L, 3]
], dim=-1)
# 5. SDF生成
sdf = self.sdf_decoder(sdf_input) # [B*N*L, 1]
sdf = sdf.reshape(B, N, L, 1) # [B, N, L, 1]
# 6. 聚合多尺度SDF
sdf = sdf.mean(dim=2) # [B, N, 1]
return sdf
# 使用示例
if __name__ == "__main__":
# 创建模型
decoder = BRepDecoder(
latent_size=256,
feature_dims=[512, 256, 128, 64],
sdf_dims=[512, 512, 512, 512],
up_block_types=("UpDecoderBlock2D",),
layers_per_block=2,
norm_num_groups=32,
dropout=None,
dropout_prob=0.0,
norm_layers=[0, 1, 2, 3],
latent_in=[4],
weight_norm=True,
xyz_in_all=True,
use_tanh=True
)
# 测试数据
batch_size = 4
seq_len = 32
num_points = 1000
latent = torch.randn(batch_size, 256, seq_len)
query_points = torch.randn(batch_size, num_points, 3)
latent_embeds = torch.randn(batch_size, 256)
# 前向传播
sdf = decoder(latent, query_points, latent_embeds)
print(f"Input latent shape: {latent.shape}")
print(f"Query points shape: {query_points.shape}")
print(f"Output SDF shape: {sdf.shape}")

109
networks/deep_sdf_decoder.py

@ -1,109 +0,0 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import torch.nn as nn
import torch
import torch.nn.functional as F
class Decoder(nn.Module):
def __init__(
self,
latent_size,
dims,
dropout=None,
dropout_prob=0.0,
norm_layers=(),
latent_in=(),
weight_norm=False,
xyz_in_all=None,
use_tanh=False,
latent_dropout=False,
):
super(Decoder, self).__init__()
def make_sequence():
return []
dims = [latent_size + 3] + dims + [1]
self.num_layers = len(dims)
self.norm_layers = norm_layers
self.latent_in = latent_in
self.latent_dropout = latent_dropout
if self.latent_dropout:
self.lat_dp = nn.Dropout(0.2)
self.xyz_in_all = xyz_in_all
self.weight_norm = weight_norm
for layer in range(0, self.num_layers - 1):
if layer + 1 in latent_in:
out_dim = dims[layer + 1] - dims[0]
else:
out_dim = dims[layer + 1]
if self.xyz_in_all and layer != self.num_layers - 2:
out_dim -= 3
if weight_norm and layer in self.norm_layers:
setattr(
self,
"lin" + str(layer),
nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
)
else:
setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))
if (
(not weight_norm)
and self.norm_layers is not None
and layer in self.norm_layers
):
setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))
self.use_tanh = use_tanh
if use_tanh:
self.tanh = nn.Tanh()
self.relu = nn.ReLU()
self.dropout_prob = dropout_prob
self.dropout = dropout
self.th = nn.Tanh()
# input: N x (L+3)
def forward(self, input):
xyz = input[:, -3:]
if input.shape[1] > 3 and self.latent_dropout:
latent_vecs = input[:, :-3]
latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
x = torch.cat([latent_vecs, xyz], 1)
else:
x = input
for layer in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(layer))
if layer in self.latent_in:
x = torch.cat([x, input], 1)
elif layer != 0 and self.xyz_in_all:
x = torch.cat([x, xyz], 1)
x = lin(x)
# last layer Tanh
if layer == self.num_layers - 2 and self.use_tanh:
x = self.tanh(x)
if layer < self.num_layers - 2:
if (
self.norm_layers is not None
and layer in self.norm_layers
and not self.weight_norm
):
bn = getattr(self, "bn" + str(layer))
x = bn(x)
x = self.relu(x)
if self.dropout is not None and layer in self.dropout:
x = F.dropout(x, p=self.dropout_prob, training=self.training)
if hasattr(self, "th"):
x = self.th(x)
return x

356
networks/encoder.py

@ -0,0 +1,356 @@
import math
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d
from diffusers.models.attention_processor import SpatialNorm
'''
# NOTE:
移除了分片(slicing)和平铺(tiling)功能
直接使用mode()而不是sample()获取潜在向量
简化了编码过程只保留核心功能
返回确定性的潜在向量而不是分布
'''
# 1. 基础网络组件
class Embedder(nn.Module):
def __init__(self, vocab_size, d_model):
super().__init__()
self.embed = nn.Embedding(vocab_size, d_model)
self._init_embeddings()
def _init_embeddings(self):
nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")
def forward(self, x):
return self.embed(x)
class UpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, temb=None):
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
return hidden_states
class Encoder1D(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock1D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = torch.nn.Conv1d(
in_channels,
block_out_channels[0],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
temb_channels=None,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock1D(
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x):
sample = x
sample = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
if is_torch_version(">=", "1.11.0"):
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample, use_reentrant=False
)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, use_reentrant=False
)
else:
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
else:
# down
for down_block in self.down_blocks:
sample = down_block(sample)[0]
# middle
sample = self.mid_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
# 2. B-rep特征处理
class BRepFeatureExtractor:
def __init__(self, config):
self.encoder = Encoder1D(
in_channels=config.in_channels, # 根据特征维度设置
out_channels=config.out_channels,
block_out_channels=config.block_out_channels,
layers_per_block=config.layers_per_block
)
def extract_face_features(self, face):
"""提取面的特征"""
features = []
try:
# 基本几何特征
center = face.center_point
normal = face.normal_vector
# 边界特征
bounds = face.bounds()
# 曲面类型特征
surface_type = face.surface_type
# 组合特征
feature = np.concatenate([
center, # [3]
normal, # [3]
bounds.flatten(),
[surface_type] # 可以用one-hot编码
])
features.append(feature)
except Exception as e:
print(f"Error extracting face features: {e}")
return np.array(features)
def extract_edge_features(self, edge):
"""提取边的特征"""
features = []
try:
# 采样点
points = self.sample_points_on_edge(edge)
for point in points:
# 位置
pos = point.coordinates
# 切向量
tangent = point.tangent
# 曲率
curvature = point.curvature
point_feature = np.concatenate([
pos, # [3]
tangent, # [3]
[curvature] # [1]
])
features.append(point_feature)
except Exception as e:
print(f"Error extracting edge features: {e}")
return np.array(features)
@staticmethod
def sample_points_on_edge(edge, num_points=32):
"""在边上均匀采样点"""
points = []
try:
length = edge.length()
for i in range(num_points):
t = i / (num_points - 1)
point = edge.point_at(t * length)
points.append(point)
except Exception as e:
print(f"Error sampling points: {e}")
return points
class BRepDataProcessor:
def __init__(self, feature_extractor):
self.feature_extractor = feature_extractor
def process_brep(self, brep_model):
"""处理单个B-rep模型"""
try:
# 1. 提取面特征
face_features = []
for face in brep_model.faces:
feat = self.feature_extractor.extract_face_features(face)
face_features.append(feat)
# 2. 提取边特征
edge_features = []
for edge in brep_model.edges:
feat = self.feature_extractor.extract_edge_features(edge)
edge_features.append(feat)
# 3. 组织数据结构
return {
'face_features': torch.tensor(face_features),
'edge_features': torch.tensor(edge_features),
'topology': self.extract_topology(brep_model)
}
except Exception as e:
print(f"Error processing B-rep: {e}")
return None
def extract_topology(self, brep_model):
"""提取拓扑关系"""
# 面-边关系矩阵
face_edge_adj = np.zeros((len(brep_model.faces), len(brep_model.edges)))
# 填充邻接关系
for i, face in enumerate(brep_model.faces):
for j, edge in enumerate(brep_model.edges):
if edge in face.edges:
face_edge_adj[i,j] = 1
return face_edge_adj
# 3. 主编码器
class BRepEncoder:
def __init__(self, config):
self.processor = BRepDataProcessor(
BRepFeatureExtractor(config)
)
self.encoder = Encoder1D(**config.encoder_params)
def encode(self, brep_model):
"""编码B-rep模型"""
try:
# 1. 处理原始数据
processed_data = self.processor.process_brep(brep_model)
if processed_data is None:
return None
# 2. 特征编码
face_features = self.encoder(processed_data['face_features'])
edge_features = self.encoder(processed_data['edge_features'])
# 3. 组合特征
combined_features = self.combine_features(
face_features,
edge_features,
processed_data['topology']
)
return combined_features
except Exception as e:
print(f"Error encoding B-rep: {e}")
return None
def combine_features(self, face_features, edge_features, topology):
"""组合不同类型的特征"""
# 可以使用图神经网络或者注意力机制来组合特征
combined = torch.cat([
face_features.mean(dim=1), # 全局面特征
edge_features.mean(dim=1), # 全局边特征
topology.flatten() # 拓扑信息
], dim=-1)
return combined

127
networks/network.py

@ -0,0 +1,127 @@
import torch
import torch.nn as nn
from encoder import BRepEncoder
from decoder import BRep2SdfDecoder
class BRep2SDF(nn.Module):
def __init__(
self,
# 编码器参数
in_channels=3,
latent_size=256,
encoder_block_out_channels=(512, 256, 128, 64),
# 解码器参数
decoder_feature_dims=(512, 256, 128, 64),
sdf_dims=(512, 512, 512, 512),
# 共享参数
layers_per_block=2,
norm_num_groups=32,
# SDF特定参数
dropout=None,
dropout_prob=0.0,
norm_layers=(0, 1, 2, 3),
latent_in=(4,),
weight_norm=True,
xyz_in_all=True,
use_tanh=True,
):
super().__init__()
# 1. 编码器配置
encoder_config = type('Config', (), {
'in_channels': in_channels,
'out_channels': latent_size,
'block_out_channels': encoder_block_out_channels,
'layers_per_block': layers_per_block,
'norm_num_groups': norm_num_groups,
'encoder_params': {
'in_channels': in_channels,
'out_channels': latent_size,
'block_out_channels': encoder_block_out_channels,
'layers_per_block': layers_per_block,
'norm_num_groups': norm_num_groups,
}
})()
# 2. 解码器配置
decoder_config = {
'latent_size': latent_size,
'feature_dims': decoder_feature_dims,
'sdf_dims': sdf_dims,
'layers_per_block': layers_per_block,
'norm_num_groups': norm_num_groups,
'dropout': dropout,
'dropout_prob': dropout_prob,
'norm_layers': norm_layers,
'latent_in': latent_in,
'weight_norm': weight_norm,
'xyz_in_all': xyz_in_all,
'use_tanh': use_tanh,
}
# 3. 创建编码器和解码器
self.encoder = BRepEncoder(encoder_config)
self.decoder = BRep2SdfDecoder(**decoder_config)
def encode(self, brep_model):
"""编码B-rep模型为潜在特征"""
return self.encoder.encode(brep_model)
def decode(self, latent, query_points, latent_embeds=None):
"""从潜在特征解码SDF值"""
return self.decoder(latent, query_points, latent_embeds)
def forward(self, brep_model, query_points):
"""完整的前向传播过程"""
# 1. 编码B-rep模型
latent = self.encode(brep_model)
if latent is None:
return None
# 2. 解码SDF值
sdf = self.decode(latent, query_points)
return sdf
# 使用示例
if __name__ == "__main__":
# 创建模型
model = BRep2SDF(
in_channels=3,
latent_size=256,
encoder_block_out_channels=(512, 256, 128, 64),
decoder_feature_dims=(512, 256, 128, 64),
sdf_dims=(512, 512, 512, 512),
layers_per_block=2,
norm_num_groups=32,
)
# 测试数据
batch_size = 4
seq_len = 32
num_points = 1000
# 模拟B-rep模型数据
class MockBRep:
def __init__(self):
self.faces = [MockFace() for _ in range(10)]
self.edges = [MockEdge() for _ in range(20)]
class MockFace:
def __init__(self):
self.center_point = torch.randn(3)
self.normal_vector = torch.randn(3)
self.surface_type = 0
self.edges = []
class MockEdge:
def __init__(self):
self.length = lambda: 1.0
self.point_at = lambda t: torch.randn(3)
brep_model = MockBRep()
query_points = torch.randn(batch_size, num_points, 3)
# 前向传播
sdf = model(brep_model, query_points)
if sdf is not None:
print(f"Output SDF shape: {sdf.shape}")
Loading…
Cancel
Save