You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
3.9 KiB
127 lines
3.9 KiB
import torch
|
|
import torch.nn as nn
|
|
from encoder import BRepEncoder
|
|
from decoder import BRep2SdfDecoder
|
|
|
|
class BRep2SDF(nn.Module):
    """End-to-end model: B-rep model + query points -> SDF values.

    Composes a ``BRepEncoder`` (B-rep -> latent features) with a
    ``BRep2SdfDecoder`` (latent features + query points -> SDF). Every
    constructor argument is forwarded to one of those two sub-modules; see
    them for the precise meaning of each setting.
    """

    def __init__(
        self,
        # Encoder arguments
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        # Decoder arguments
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        # Arguments shared by encoder and decoder
        layers_per_block=2,
        norm_num_groups=32,
        # SDF-head-specific arguments
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(0, 1, 2, 3),
        latent_in=(4,),
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True,
    ):
        super().__init__()

        # Encoder settings; ``latent_size`` becomes the encoder's output width.
        encoder_settings = {
            'in_channels': in_channels,
            'out_channels': latent_size,
            'block_out_channels': encoder_block_out_channels,
            'layers_per_block': layers_per_block,
            'norm_num_groups': norm_num_groups,
        }
        # BRepEncoder is handed an attribute-style object (an instance of a
        # throwaway ``Config`` class). The same settings are exposed both as
        # flat attributes and nested under ``encoder_params``; which of the two
        # BRepEncoder actually reads is not visible from this file.
        config_namespace = {**encoder_settings, 'encoder_params': encoder_settings}
        self.encoder = BRepEncoder(type('Config', (), config_namespace)())

        # The decoder takes its configuration as plain keyword arguments.
        self.decoder = BRep2SdfDecoder(
            latent_size=latent_size,
            feature_dims=decoder_feature_dims,
            sdf_dims=sdf_dims,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            dropout=dropout,
            dropout_prob=dropout_prob,
            norm_layers=norm_layers,
            latent_in=latent_in,
            weight_norm=weight_norm,
            xyz_in_all=xyz_in_all,
            use_tanh=use_tanh,
        )

    def encode(self, brep_model):
        """Encode a B-rep model into latent features (delegates to ``encoder.encode``)."""
        return self.encoder.encode(brep_model)

    def decode(self, latent, query_points, latent_embeds=None):
        """Decode SDF values at ``query_points`` from ``latent`` features.

        ``latent_embeds`` is forwarded to the decoder unchanged.
        """
        return self.decoder(latent, query_points, latent_embeds)

    def forward(self, brep_model, query_points):
        """Full pipeline: encode ``brep_model``, then decode SDF at ``query_points``.

        Returns ``None`` without invoking the decoder when encoding yields ``None``.
        """
        latent = self.encode(brep_model)
        return None if latent is None else self.decode(latent, query_points)
|
|
|
|
# Usage example
if __name__ == "__main__":
    # Build the model (these values match the constructor defaults).
    model = BRep2SDF(
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        layers_per_block=2,
        norm_num_groups=32,
    )

    # Test data sizes.
    batch_size = 4
    num_points = 1000

    # Minimal stand-ins for a B-rep model. Only the attributes/methods defined
    # here are provided; whether they cover everything BRepEncoder.encode reads
    # is not visible from this file. Dependencies are defined before use.
    class MockEdge:
        def length(self):
            return 1.0

        def point_at(self, t):
            return torch.randn(3)

    class MockFace:
        def __init__(self):
            self.center_point = torch.randn(3)
            self.normal_vector = torch.randn(3)
            self.surface_type = 0
            self.edges = []

    class MockBRep:
        def __init__(self, num_faces=10, num_edges=20):
            self.faces = [MockFace() for _ in range(num_faces)]
            self.edges = [MockEdge() for _ in range(num_edges)]

    brep_model = MockBRep()
    query_points = torch.randn(batch_size, num_points, 3)

    # Forward pass. This is inference only: switch to eval mode and skip
    # building an autograd graph.
    model.eval()
    with torch.no_grad():
        sdf = model(brep_model, query_points)

    if sdf is not None:
        print(f"Output SDF shape: {sdf.shape}")
    else:
        # forward() returns None when the encoder produced no latent.
        print("Encoding failed: model returned None")