import torch
import torch.nn as nn

from encoder import BRepEncoder
from decoder import BRep2SdfDecoder


class BRep2SDF(nn.Module):
    """Encoder/decoder model mapping a B-rep model and query points to SDF values.

    A ``BRepEncoder`` turns the B-rep model into a latent code; a
    ``BRep2SdfDecoder`` then evaluates SDF values at the query points,
    conditioned on that latent code.
    """

    def __init__(
        self,
        # Encoder parameters
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        # Decoder parameters
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        # Parameters shared by encoder and decoder
        layers_per_block=2,
        norm_num_groups=32,
        # SDF-specific parameters (forwarded unchanged to the decoder)
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(0, 1, 2, 3),
        latent_in=(4,),
        weight_norm=True,
        xyz_in_all=True,
        use_tanh=True,
    ):
        """Build the encoder and decoder sub-modules.

        Args:
            in_channels: Passed to the encoder as ``in_channels``.
            latent_size: Encoder ``out_channels`` and decoder ``latent_size``;
                the two must agree so the decoder can consume the encoder's
                output.
            encoder_block_out_channels: Encoder ``block_out_channels``.
            decoder_feature_dims: Decoder ``feature_dims``.
            sdf_dims: Decoder ``sdf_dims``.
            layers_per_block: Shared by encoder and decoder.
            norm_num_groups: Shared by encoder and decoder.
            dropout, dropout_prob, norm_layers, latent_in, weight_norm,
            xyz_in_all, use_tanh: Forwarded as-is to ``BRep2SdfDecoder``;
                their exact semantics are defined by the decoder.
        """
        super().__init__()

        # 1. Encoder configuration. BRepEncoder is handed an attribute-style
        #    config object (an instance of an ad-hoc "Config" class). The same
        #    values are also exposed as a plain dict under ``encoder_params``.
        encoder_params = {
            'in_channels': in_channels,
            'out_channels': latent_size,
            'block_out_channels': encoder_block_out_channels,
            'layers_per_block': layers_per_block,
            'norm_num_groups': norm_num_groups,
        }
        encoder_config = type(
            'Config', (), {**encoder_params, 'encoder_params': encoder_params}
        )()

        # 2. Decoder configuration, passed as keyword arguments.
        decoder_config = {
            'latent_size': latent_size,
            'feature_dims': decoder_feature_dims,
            'sdf_dims': sdf_dims,
            'layers_per_block': layers_per_block,
            'norm_num_groups': norm_num_groups,
            'dropout': dropout,
            'dropout_prob': dropout_prob,
            'norm_layers': norm_layers,
            'latent_in': latent_in,
            'weight_norm': weight_norm,
            'xyz_in_all': xyz_in_all,
            'use_tanh': use_tanh,
        }

        # 3. Instantiate the encoder and decoder.
        self.encoder = BRepEncoder(encoder_config)
        self.decoder = BRep2SdfDecoder(**decoder_config)

    def encode(self, brep_model):
        """Encode a B-rep model into a latent code.

        Delegates to ``self.encoder.encode``; the returned value (possibly
        ``None``, see ``forward``) comes straight from the encoder.
        """
        return self.encoder.encode(brep_model)

    def decode(self, latent, query_points, latent_embeds=None):
        """Decode SDF values at ``query_points`` from a latent code.

        ``latent_embeds`` is forwarded positionally to the decoder; its
        meaning is defined by ``BRep2SdfDecoder``.
        """
        return self.decoder(latent, query_points, latent_embeds)

    def forward(self, brep_model, query_points):
        """Run the full encode -> decode pipeline.

        Returns:
            The decoder's SDF output, or ``None`` if the encoder produced no
            latent code for ``brep_model``.
        """
        # 1. Encode the B-rep model.
        latent = self.encode(brep_model)
        if latent is None:
            # The encoder signals "nothing to encode" with None; propagate it.
            return None

        # 2. Decode SDF values at the query points.
        sdf = self.decode(latent, query_points)
        return sdf


# Usage example
if __name__ == "__main__":
    # Build the model (explicitly spelling out the default hyperparameters).
    model = BRep2SDF(
        in_channels=3,
        latent_size=256,
        encoder_block_out_channels=(512, 256, 128, 64),
        decoder_feature_dims=(512, 256, 128, 64),
        sdf_dims=(512, 512, 512, 512),
        layers_per_block=2,
        norm_num_groups=32,
    )

    # Test data sizes.
    batch_size = 4
    num_points = 1000

    # Minimal stand-ins for B-rep entities. They expose only the attributes
    # set below; whether the real encoder needs more is not visible here.
    class MockFace:
        def __init__(self):
            self.center_point = torch.randn(3)
            self.normal_vector = torch.randn(3)
            self.surface_type = 0
            self.edges = []

    class MockEdge:
        def __init__(self):
            self.length = lambda: 1.0
            self.point_at = lambda t: torch.randn(3)

    class MockBRep:
        def __init__(self):
            self.faces = [MockFace() for _ in range(10)]
            self.edges = [MockEdge() for _ in range(20)]

    brep_model = MockBRep()
    query_points = torch.randn(batch_size, num_points, 3)

    # Forward pass. This is inference only, so no autograd graph is needed.
    with torch.no_grad():
        sdf = model(brep_model, query_points)

    if sdf is not None:
        print(f"Output SDF shape: {sdf.shape}")
    else:
        print("Encoder returned no latent code; SDF was not computed.")