You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

127 lines
3.9 KiB

from types import SimpleNamespace

import torch
import torch.nn as nn

from encoder import BRepEncoder
from decoder import BRep2SdfDecoder
class BRep2SDF(nn.Module):
    """End-to-end model that predicts SDF values for a B-rep model.

    A ``BRepEncoder`` turns the B-rep model into a latent feature, and a
    ``BRep2SdfDecoder`` evaluates SDF values at query points conditioned on
    that latent.

    Args:
        in_channels: Encoder input channel count.
        latent_size: Encoder ``out_channels``; also passed to the decoder as
            ``latent_size``.
        encoder_block_out_channels: Encoder ``block_out_channels``.
        decoder_feature_dims: Passed to the decoder as ``feature_dims``.
        sdf_dims: Passed to the decoder as ``sdf_dims``.
        layers_per_block: Shared by the encoder and decoder configs.
        norm_num_groups: Shared by the encoder and decoder configs.
        dropout, dropout_prob, norm_layers, latent_in, weight_norm,
        xyz_in_all, use_tanh: SDF-head options forwarded unchanged to the
            decoder; their exact meaning is defined by ``BRep2SdfDecoder``.
    """

    def __init__(
        self,
        # Encoder parameters
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        # Decoder parameters
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        # Parameters shared by encoder and decoder
        layers_per_block=2,
        norm_num_groups=32,
        # SDF-specific parameters
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(0, 1, 2, 3),
        latent_in=(4,),
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True,
    ):
        super().__init__()

        # 1. Encoder configuration. BRepEncoder takes an attribute-style
        # config object. The same settings are exposed both as top-level
        # attributes and as the ``encoder_params`` dict; which form the
        # encoder actually reads is not visible from this file.
        encoder_params = {
            'in_channels': in_channels,
            'out_channels': latent_size,
            'block_out_channels': encoder_block_out_channels,
            'layers_per_block': layers_per_block,
            'norm_num_groups': norm_num_groups,
        }
        encoder_config = SimpleNamespace(
            **encoder_params,
            encoder_params=dict(encoder_params),
        )

        # 2. Decoder configuration, forwarded as keyword arguments.
        decoder_config = {
            'latent_size': latent_size,
            'feature_dims': decoder_feature_dims,
            'sdf_dims': sdf_dims,
            'layers_per_block': layers_per_block,
            'norm_num_groups': norm_num_groups,
            'dropout': dropout,
            'dropout_prob': dropout_prob,
            'norm_layers': norm_layers,
            'latent_in': latent_in,
            'weight_norm': weight_norm,
            'xyz_in_all': xyz_in_all,
            'use_tanh': use_tanh,
        }

        # 3. Build the encoder and decoder submodules.
        self.encoder = BRepEncoder(encoder_config)
        self.decoder = BRep2SdfDecoder(**decoder_config)

    def encode(self, brep_model):
        """Encode a B-rep model into a latent feature.

        Delegates to ``self.encoder.encode``. The result may be ``None``
        (``forward`` checks for it) -- presumably when the B-rep cannot be
        encoded; confirm against ``BRepEncoder``.
        """
        return self.encoder.encode(brep_model)

    def decode(self, latent, query_points, latent_embeds=None):
        """Evaluate SDF values at ``query_points`` conditioned on ``latent``.

        ``latent_embeds`` is passed positionally as the decoder's third
        argument; its meaning is defined by ``BRep2SdfDecoder``.
        """
        return self.decoder(latent, query_points, latent_embeds)

    def forward(self, brep_model, query_points, latent_embeds=None):
        """Run the full pipeline: encode ``brep_model``, then decode SDFs.

        Args:
            brep_model: The B-rep model handed to ``encode``.
            query_points: Points at which to evaluate the SDF.
            latent_embeds: Optional extra input forwarded to ``decode``
                (defaults to ``None``, matching the two-argument call).

        Returns:
            The decoder output, or ``None`` if encoding returned ``None``.
        """
        # 1. Encode the B-rep model; bail out if the encoder produced nothing.
        latent = self.encode(brep_model)
        if latent is None:
            return None
        # 2. Decode SDF values at the query points.
        return self.decode(latent, query_points, latent_embeds)
# Usage example
if __name__ == "__main__":
    # Build the model.
    model = BRep2SDF(
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        layers_per_block=2,
        norm_num_groups=32,
    )

    # Test-data sizes.
    batch_size = 4
    num_points = 1000

    # Minimal stand-ins for a B-rep model. They expose only the attributes
    # defined here; whether these cover everything BRepEncoder.encode reads
    # is not visible from this file.
    class MockFace:
        def __init__(self):
            self.center_point = torch.randn(3)
            self.normal_vector = torch.randn(3)
            self.surface_type = 0
            self.edges = []

    class MockEdge:
        def __init__(self):
            self.length = lambda: 1.0
            self.point_at = lambda t: torch.randn(3)

    class MockBRep:
        def __init__(self):
            self.faces = [MockFace() for _ in range(10)]
            self.edges = [MockEdge() for _ in range(20)]

    brep_model = MockBRep()
    query_points = torch.randn(batch_size, num_points, 3)

    # Forward pass.
    sdf = model(brep_model, query_points)
    if sdf is None:
        # Make the failure visible instead of exiting silently.
        print("Encoding returned None; no SDF was produced.")
    else:
        print(f"Output SDF shape: {sdf.shape}")