From 5b98c59270bf7024a04ef558db188783060c9939 Mon Sep 17 00:00:00 2001 From: mckay Date: Sat, 23 Nov 2024 15:21:33 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=A8=A1=E5=9E=8B=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E5=87=BD=E6=95=B0=E6=8B=86=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/decoder.py | 403 +++-------------------------------- brep2sdf/networks/network.py | 295 +++++++++++++++---------- brep2sdf/train.py | 2 +- 3 files changed, 217 insertions(+), 483 deletions(-) diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py index deebda0..26f26d7 100644 --- a/brep2sdf/networks/decoder.py +++ b/brep2sdf/networks/decoder.py @@ -1,383 +1,42 @@ -import math import torch import torch.nn as nn -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import BaseOutput, is_torch_version -from diffusers.utils.accelerate_utils import apply_forward_hook -from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder -from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d -from diffusers.models.attention_processor import SpatialNorm - - +import torch.nn.functional as F +from typing import Dict, Optional, Tuple, Union +from brep2sdf.config.default_config import get_default_config +from brep2sdf.utils.logger import logger -class Decoder1D(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - norm_type="group", # group, spatial - ): - ''' - 这是第一阶段的解码器,用于处理B-rep特征 - 包含三个主要部分: - conv_in: 输入卷积层,处理初始特征 - mid_block: 中间处理块 - up_blocks: 上采样块序列 - 支持梯度检查点功能(gradient checkpointing)以节省内存 - 输出维度: [B, C, L] - # NOTE: - 1. 移除了分片(slicing)和平铺(tiling)功能 - 2. 直接使用mode()而不是sample()获取潜在向量 - 3. 简化了编码过程,只保留核心功能 - 4. 返回确定性的潜在向量而不是分布 - ''' +class SDFHead(nn.Module): + """SDF预测头""" + def __init__(self, embed_dim: int = 768*2): super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv1d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock1D( - in_channels=block_out_channels[-1], - mid_channels=block_out_channels[-1], + self.mlp = nn.Sequential( + nn.Linear(embed_dim, embed_dim//2), + nn.LayerNorm(embed_dim//2), + nn.ReLU(), + nn.Linear(embed_dim//2, embed_dim//4), + nn.LayerNorm(embed_dim//4), + nn.ReLU(), + nn.Linear(embed_dim//4, 1), + nn.Tanh() ) - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = UpBlock1D( - in_channels=prev_output_channel, - out_channels=output_channel, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1) - - self.gradient_checkpointing = False - - - def forward(self, z, latent_embeds=None): - sample = z - sample = self.conv_in(sample) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False - ) - # sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False - ) - else: - # middle - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.mid_block), sample, latent_embeds - ) - # sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) - else: - # middle - sample = self.mid_block(sample, latent_embeds) - # sample = sample.to(upscale_dtype) - # up - for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds) + def forward(self, x): + return self.mlp(x) - # post-process - if latent_embeds is None: - sample = self.conv_norm_out(sample) - else: - sample = self.conv_norm_out(sample, latent_embeds) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - - -class SDFDecoder(nn.Module): - def __init__( - self, - latent_size, - dims, - dropout=None, - dropout_prob=0.0, - norm_layers=(), - latent_in=(), - weight_norm=False, - xyz_in_all=None, - use_tanh=False, - latent_dropout=False, - ): - ''' - 这是第二阶段的解码器,用于生成SDF值 - 使用多层MLP结构 - 特点: - 支持在不同层注入latent信息(通过latent_in参数) - 可以在每层添加xyz坐标(通过xyz_in_all参数) - 支持权重归一化和dropout - 输入维度: [N, latent_size+3] - 输出维度: [N, 1] - - ''' - super(SDFDecoder, self).__init__() - - def make_sequence(): - return [] - - dims = [latent_size + 3] + dims + [1] - - self.num_layers = len(dims) - self.norm_layers = norm_layers - self.latent_in = latent_in - self.latent_dropout = latent_dropout - if self.latent_dropout: - self.lat_dp = nn.Dropout(0.2) - - self.xyz_in_all = xyz_in_all - self.weight_norm = weight_norm - - for layer in range(0, self.num_layers - 1): - if layer + 1 in latent_in: - out_dim = dims[layer + 1] - dims[0] - else: - out_dim = dims[layer + 1] - if self.xyz_in_all and layer != self.num_layers - 2: - out_dim -= 3 - - if weight_norm and layer in self.norm_layers: - setattr( - self, - "lin" + str(layer), - nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)), - ) - else: - setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim)) - - if ( - (not weight_norm) - and self.norm_layers is not None - and layer in self.norm_layers - ): - setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim)) - - self.use_tanh = use_tanh - if use_tanh: - self.tanh = nn.Tanh() - self.relu = nn.ReLU() - - self.dropout_prob = dropout_prob - self.dropout = dropout - self.th = nn.Tanh() - - # input: N x (L+3) - def forward(self, input): - xyz = input[:, -3:] - - if input.shape[1] > 3 and self.latent_dropout: - latent_vecs = input[:, :-3] - latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training) - x = torch.cat([latent_vecs, xyz], 1) - else: - x = input - - for layer in range(0, self.num_layers - 1): - lin = getattr(self, "lin" + str(layer)) - if layer in self.latent_in: - x = torch.cat([x, input], 1) - elif layer != 0 and self.xyz_in_all: - x = torch.cat([x, xyz], 1) - x = lin(x) - # last layer Tanh - if layer == self.num_layers - 2 and self.use_tanh: - x = self.tanh(x) - if layer < self.num_layers - 2: - if ( - self.norm_layers is not None - and layer in self.norm_layers - and not self.weight_norm - ): - bn = getattr(self, "bn" + str(layer)) - x = bn(x) - x = self.relu(x) - if self.dropout is not None and layer in self.dropout: - x = F.dropout(x, p=self.dropout_prob, training=self.training) - - if hasattr(self, "th"): - x = self.th(x) - - return x - - -class BRep2SdfDecoder(nn.Module): - def __init__( - self, - latent_size=256, - feature_dims=[512, 512, 256, 128], # 特征解码器维度 - sdf_dims=[512, 512, 512, 512], # SDF解码器维度 - up_block_types=("UpDecoderBlock2D",), - layers_per_block=2, - norm_num_groups=32, - norm_type="group", - dropout=None, - dropout_prob=0.0, - norm_layers=(), - latent_in=(), - weight_norm=False, - xyz_in_all=True, - use_tanh=True, - ): +class SDFTransformer(nn.Module): + """SDF Transformer编码器""" + def __init__(self, embed_dim: int = 768, num_layers: int = 6): super().__init__() - - # 1. 特征解码器 (使用Decoder1D结构) - self.feature_decoder = Decoder1D( - in_channels=latent_size, - out_channels=feature_dims[-1], - up_block_types=up_block_types, - block_out_channels=feature_dims, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - norm_type=norm_type - ) - - # 2. SDF解码器 (使用原始Decoder结构) - self.sdf_decoder = SDFDecoder( - latent_size=feature_dims[-1], # 使用特征解码器的输出维度 - dims=sdf_dims, - dropout=dropout, - dropout_prob=dropout_prob, - norm_layers=norm_layers, - latent_in=latent_in, - weight_norm=weight_norm, - xyz_in_all=xyz_in_all, - use_tanh=use_tanh, - latent_dropout=False + layer = nn.TransformerEncoderLayer( + d_model=embed_dim, + nhead=8, + dim_feedforward=1024, + dropout=0.1, + batch_first=True, + norm_first=False # 修改这里:设置为False ) - - # 3. 特征转换层 (将特征解码器的输出转换为SDF解码器需要的格式) - self.feature_transform = nn.Sequential( - nn.Linear(feature_dims[-1], feature_dims[-1]), - nn.LayerNorm(feature_dims[-1]), - nn.SiLU() - ) - - def forward(self, latent, query_points, latent_embeds=None): - """ - Args: - latent: [B, C, L] B-rep特征 - query_points: [B, N, 3] 查询点 - latent_embeds: 可选的条件嵌入 - Returns: - sdf: [B, N, 1] SDF值 - """ - # 1. 特征解码 - features = self.feature_decoder(latent, latent_embeds) # [B, C, L] - - # 2. 特征转换 - B, C, L = features.shape - features = features.permute(0, 2, 1) # [B, L, C] - features = self.feature_transform(features) # [B, L, C] - - # 3. 准备SDF解码器输入 - _, N, _ = query_points.shape - features = features.unsqueeze(1).expand(-1, N, -1, -1) # [B, N, L, C] - query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1) # [B, N, L, 3] - - # 4. 合并特征和坐标 - sdf_input = torch.cat([ - features.reshape(B*N*L, -1), # [B*N*L, C] - query_points.reshape(B*N*L, -1) # [B*N*L, 3] - ], dim=-1) - - # 5. SDF生成 - sdf = self.sdf_decoder(sdf_input) # [B*N*L, 1] - sdf = sdf.reshape(B, N, L, 1) # [B, N, L, 1] - - # 6. 聚合多尺度SDF - sdf = sdf.mean(dim=2) # [B, N, 1] - - return sdf + self.transformer = nn.TransformerEncoder(layer, num_layers) -# 使用示例 -if __name__ == "__main__": - # 创建模型 - decoder = BRepDecoder( - latent_size=256, - feature_dims=[512, 256, 128, 64], - sdf_dims=[512, 512, 512, 512], - up_block_types=("UpDecoderBlock2D",), - layers_per_block=2, - norm_num_groups=32, - dropout=None, - dropout_prob=0.0, - norm_layers=[0, 1, 2, 3], - latent_in=[4], - weight_norm=True, - xyz_in_all=True, - use_tanh=True - ) - - # 测试数据 - batch_size = 4 - seq_len = 32 - num_points = 1000 - - latent = torch.randn(batch_size, 256, seq_len) - query_points = torch.randn(batch_size, num_points, 3) - latent_embeds = torch.randn(batch_size, 256) - - # 前向传播 - sdf = decoder(latent, query_points, latent_embeds) - print(f"Input latent shape: {latent.shape}") - print(f"Query points shape: {query_points.shape}") - print(f"Output SDF shape: {sdf.shape}") + def forward(self, x, mask=None): + return self.transformer(x, src_key_padding_mask=mask) diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 5c0147a..21279bc 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -1,127 +1,202 @@ import torch import torch.nn as nn -from encoder import BRepEncoder -from decoder import BRep2SdfDecoder +import torch.nn.functional as F +from typing import Dict, Optional, Tuple, Union +from brep2sdf.config.default_config import get_default_config +from brep2sdf.utils.logger import logger -class BRep2SDF(nn.Module): - def __init__( - self, - # 编码器参数 - in_channels=3, - latent_size=256, - encoder_block_out_channels=(512, 256, 128, 64), - # 解码器参数 - decoder_feature_dims=(512, 256, 128, 64), - sdf_dims=(512, 512, 512, 512), - # 共享参数 - layers_per_block=2, - norm_num_groups=32, - # SDF特定参数 - dropout=None, - dropout_prob=0.0, - norm_layers=(0, 1, 2, 3), - latent_in=(4,), - weight_norm=True, - xyz_in_all=True, - use_tanh=True, - ): +from brep2sdf.networks.encoder import BRepFeatureEmbedder +from brep2sdf.networks.decoder import SDFHead, SDFTransformer + + +class BRepToSDF(nn.Module): + def __init__(self, config=None): super().__init__() + # 获取配置 + if config is None: + self.config = get_default_config() + else: + self.config = config + + # 从配置中读取参数 + self.embed_dim = self.config.model.embed_dim + self.brep_feature_dim = self.config.model.brep_feature_dim + self.latent_dim = self.config.model.latent_dim + self.use_cf = self.config.model.use_cf - # 1. 编码器配置 - encoder_config = type('Config', (), { - 'in_channels': in_channels, - 'out_channels': latent_size, - 'block_out_channels': encoder_block_out_channels, - 'layers_per_block': layers_per_block, - 'norm_num_groups': norm_num_groups, - 'encoder_params': { - 'in_channels': in_channels, - 'out_channels': latent_size, - 'block_out_channels': encoder_block_out_channels, - 'layers_per_block': layers_per_block, - 'norm_num_groups': norm_num_groups, - } - })() + # 1. 查询点编码器 + self.query_encoder = nn.Sequential( + nn.Linear(3, self.embed_dim//4), + nn.LayerNorm(self.embed_dim//4), + nn.ReLU(), + nn.Linear(self.embed_dim//4, self.embed_dim//2), + nn.LayerNorm(self.embed_dim//2), + nn.ReLU(), + nn.Linear(self.embed_dim//2, self.embed_dim) + ) - # 2. 解码器配置 - decoder_config = { - 'latent_size': latent_size, - 'feature_dims': decoder_feature_dims, - 'sdf_dims': sdf_dims, - 'layers_per_block': layers_per_block, - 'norm_num_groups': norm_num_groups, - 'dropout': dropout, - 'dropout_prob': dropout_prob, - 'norm_layers': norm_layers, - 'latent_in': latent_in, - 'weight_norm': weight_norm, - 'xyz_in_all': xyz_in_all, - 'use_tanh': use_tanh, - } + # 2. B-rep特征编码器 + self.brep_embedder = BRepFeatureEmbedder() - # 3. 创建编码器和解码器 - self.encoder = BRepEncoder(encoder_config) - self.decoder = BRep2SdfDecoder(**decoder_config) + # 3. 特征融合Transformer + self.transformer = SDFTransformer( + embed_dim=self.embed_dim, + num_layers=6 # 这个参数也可以移到配置文件中 + ) - def encode(self, brep_model): - """编码B-rep模型为潜在特征""" - return self.encoder.encode(brep_model) + # 4. SDF预测头 + self.sdf_head = SDFHead(embed_dim=self.embed_dim*2) + + def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): + """B-rep到SDF的前向传播 - def decode(self, latent, query_points, latent_embeds=None): - """从潜在特征解码SDF值""" - return self.decoder(latent, query_points, latent_embeds) - - def forward(self, brep_model, query_points): - """完整的前向传播过程""" - # 1. 编码B-rep模型 - latent = self.encode(brep_model) - if latent is None: - return None + Args: + edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] + edge_pos: 边位置 [B, max_face, max_edge, 6] + edge_mask: 边掩码 [B, max_face, max_edge] + surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] + surf_pos: 面位置 [B, max_face, 6] + vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] + query_points: 查询点 [B, num_queries, 3] + data_class: (可选) 类别标签 + + Returns: + sdf: 预测的SDF值 [B, num_queries, 1] + """ + B, Q = query_points.shape[:2] # B: batch_size, Q: num_queries + + try: + # 确保query_points需要梯度 + if not query_points.requires_grad: + query_points = query_points.detach().requires_grad_(True) + + + # 1. B-rep特征编码 + brep_features = self.brep_embedder( + edge_ncs=edge_ncs, # [B, max_face, max_edge, num_edge_points, 3] + edge_pos=edge_pos, # [B, max_face, max_edge, 6] + edge_mask=edge_mask, # [B, max_face, max_edge] + surf_ncs=surf_ncs, # [B, max_face, num_surf_points, 3] + surf_pos=surf_pos, # [B, max_face, 6] + vertex_pos=vertex_pos, # [B, max_face, max_edge, 2, 3] + data_class=data_class + ) # [B, max_face*(max_edge+1), embed_dim] - # 2. 解码SDF值 - sdf = self.decode(latent, query_points) - return sdf + # 2. 查询点编码 + query_features = self.query_encoder(query_points) # [B, Q, embed_dim] + + # 3. 提取全局特征 + global_features = brep_features.mean(dim=1) # [B, embed_dim] + + # 4. 为每个查询点准备特征 + expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim] + + # 5. 连接查询点特征和全局特征 + combined_features = torch.cat([ + expanded_features, # [B, Q, embed_dim] + query_features # [B, Q, embed_dim] + ], dim=-1) # [B, Q, embed_dim*2] + + # 6. SDF预测 + sdf = self.sdf_head(combined_features) # [B, Q, 1] -# 使用示例 -if __name__ == "__main__": - # 创建模型 - model = BRep2SDF( - in_channels=3, - latent_size=256, - encoder_block_out_channels=(512, 256, 128, 64), - decoder_feature_dims=(512, 256, 128, 64), - sdf_dims=(512, 512, 512, 512), - layers_per_block=2, - norm_num_groups=32, - ) + if not sdf.requires_grad: + logger.warning("SDF output does not require grad!") + + + return sdf + + except Exception as e: + logger.error(f"Error in BRepToSDF forward pass:") + logger.error(f" Error message: {str(e)}") + logger.error(f" Input shapes:") + logger.error(f" edge_ncs: {edge_ncs.shape}") + logger.error(f" edge_pos: {edge_pos.shape}") + logger.error(f" edge_mask: {edge_mask.shape}") + logger.error(f" surf_ncs: {surf_ncs.shape}") + logger.error(f" surf_pos: {surf_pos.shape}") + logger.error(f" vertex_pos: {vertex_pos.shape}") + logger.error(f" query_points: {query_points.shape}") + raise + +def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): + """SDF损失函数""" + # 确保points需要梯度 + if not points.requires_grad: + points = points.detach().requires_grad_(True) - # 测试数据 - batch_size = 4 - seq_len = 32 - num_points = 1000 + # L1损失 + l1_loss = F.l1_loss(pred_sdf, gt_sdf) - # 模拟B-rep模型数据 - class MockBRep: - def __init__(self): - self.faces = [MockFace() for _ in range(10)] - self.edges = [MockEdge() for _ in range(20)] - - class MockFace: - def __init__(self): - self.center_point = torch.randn(3) - self.normal_vector = torch.randn(3) - self.surface_type = 0 - self.edges = [] + try: + # 梯度约束损失 + grad = torch.autograd.grad( + pred_sdf.sum(), + points, + create_graph=True, + retain_graph=True, + allow_unused=True + )[0] + + if grad is not None: + grad_constraint = F.mse_loss( + torch.norm(grad, dim=-1), + torch.ones_like(pred_sdf.squeeze(-1)) + ) + else: + grad_constraint = torch.tensor(0.0, device=pred_sdf.device) - class MockEdge: - def __init__(self): - self.length = lambda: 1.0 - self.point_at = lambda t: torch.randn(3) + except Exception as e: + logger.warning(f"Gradient computation failed: {str(e)}") + grad_constraint = torch.tensor(0.0, device=pred_sdf.device) + + return l1_loss + grad_weight * grad_constraint + +def main(): + # 获取配置 + config = get_default_config() + + # 初始化模型 + model = BRepToSDF(config=config) - brep_model = MockBRep() - query_points = torch.randn(batch_size, num_points, 3) + # 从配置获取参数 + batch_size = config.train.batch_size + max_face = config.data.max_face + max_edge = config.data.max_edge + num_surf_points = config.model.num_surf_points + num_edge_points = config.model.num_edge_points + + # 生成测试数据 + test_data = { + 'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), + 'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), + 'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), + 'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), + 'surf_pos': torch.randn(batch_size, max_face, 6), + 'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), + 'query_points': torch.randn(batch_size, 1000, 3) # 1000个查询点 + } + + # 打印输入数据形状 + logger.info("Input shapes:") + for name, tensor in test_data.items(): + logger.info(f" {name}: {tensor.shape}") # 前向传播 - sdf = model(brep_model, query_points) - if sdf is not None: - print(f"Output SDF shape: {sdf.shape}") \ No newline at end of file + try: + sdf = model(**test_data) + logger.info(f"\nOutput SDF shape: {sdf.shape}") + + # 计算模型参数量 + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"\nModel statistics:") + logger.info(f" Total parameters: {total_params:,}") + logger.info(f" Trainable parameters: {trainable_params:,}") + + except Exception as e: + logger.error(f"Error during forward pass: {str(e)}") + raise + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 1a85692..b955840 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from brep2sdf.data.data import BRepSDFDataset -from brep2sdf.networks.encoder import BRepToSDF, sdf_loss +from brep2sdf.networks.network import BRepToSDF, sdf_loss from brep2sdf.utils.logger import logger from brep2sdf.config.default_config import get_default_config, load_config import wandb