diff --git a/networks/__init__.py b/networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/decoder.py b/networks/decoder.py new file mode 100644 index 0000000..deebda0 --- /dev/null +++ b/networks/decoder.py @@ -0,0 +1,383 @@ +import math +import torch +import torch.nn as nn +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder +from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d +from diffusers.models.attention_processor import SpatialNorm + + + + +class Decoder1D(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + norm_type="group", # group, spatial + ): + ''' + 这是第一阶段的解码器,用于处理B-rep特征 + 包含三个主要部分: + conv_in: 输入卷积层,处理初始特征 + mid_block: 中间处理块 + up_blocks: 上采样块序列 + 支持梯度检查点功能(gradient checkpointing)以节省内存 + 输出维度: [B, C, L] + # NOTE: + 1. 移除了分片(slicing)和平铺(tiling)功能 + 2. 直接使用mode()而不是sample()获取潜在向量 + 3. 简化了编码过程,只保留核心功能 + 4. 返回确定性的潜在向量而不是分布 + ''' + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv1d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock1D( + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpBlock1D( + in_channels=prev_output_channel, + out_channels=output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + + def forward(self, z, latent_embeds=None): + sample = z + sample = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False + ) + # sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + # sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + # sample = sample.to(upscale_dtype) + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + + +class SDFDecoder(nn.Module): + def __init__( + self, + latent_size, + dims, + dropout=None, + dropout_prob=0.0, + norm_layers=(), + latent_in=(), + weight_norm=False, + xyz_in_all=None, + use_tanh=False, + latent_dropout=False, + ): + ''' + 这是第二阶段的解码器,用于生成SDF值 + 使用多层MLP结构 + 特点: + 支持在不同层注入latent信息(通过latent_in参数) + 可以在每层添加xyz坐标(通过xyz_in_all参数) + 支持权重归一化和dropout + 输入维度: [N, latent_size+3] + 输出维度: [N, 1] + + ''' + super(SDFDecoder, self).__init__() + + def make_sequence(): + return [] + + dims = [latent_size + 3] + dims + [1] + + self.num_layers = len(dims) + self.norm_layers = norm_layers + self.latent_in = latent_in + self.latent_dropout = latent_dropout + if self.latent_dropout: + self.lat_dp = nn.Dropout(0.2) + + self.xyz_in_all = xyz_in_all + self.weight_norm = weight_norm + + for layer in range(0, self.num_layers - 1): + if layer + 1 in latent_in: + out_dim = dims[layer + 1] - dims[0] + else: + out_dim = dims[layer + 1] + if self.xyz_in_all and layer != self.num_layers - 2: + out_dim -= 3 + + if weight_norm and layer in self.norm_layers: + setattr( + self, + "lin" + str(layer), + nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)), + ) + else: + setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim)) + + if ( + (not weight_norm) + and self.norm_layers is not None + and layer in self.norm_layers + ): + setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim)) + + self.use_tanh = use_tanh + if use_tanh: + self.tanh = nn.Tanh() + self.relu = nn.ReLU() + + self.dropout_prob = dropout_prob + self.dropout = dropout + self.th = nn.Tanh() + + # input: N x (L+3) + def forward(self, input): + xyz = input[:, -3:] + + if input.shape[1] > 3 and self.latent_dropout: + latent_vecs = input[:, :-3] + latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training) + x = torch.cat([latent_vecs, xyz], 1) + else: + x = input + + for layer in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(layer)) + if layer in self.latent_in: + x = torch.cat([x, input], 1) + elif layer != 0 and self.xyz_in_all: + x = torch.cat([x, xyz], 1) + x = lin(x) + # last layer Tanh + if layer == self.num_layers - 2 and self.use_tanh: + x = self.tanh(x) + if layer < self.num_layers - 2: + if ( + self.norm_layers is not None + and layer in self.norm_layers + and not self.weight_norm + ): + bn = getattr(self, "bn" + str(layer)) + x = bn(x) + x = self.relu(x) + if self.dropout is not None and layer in self.dropout: + x = F.dropout(x, p=self.dropout_prob, training=self.training) + + if hasattr(self, "th"): + x = self.th(x) + + return x + + +class BRep2SdfDecoder(nn.Module): + def __init__( + self, + latent_size=256, + feature_dims=[512, 512, 256, 128], # 特征解码器维度 + sdf_dims=[512, 512, 512, 512], # SDF解码器维度 + up_block_types=("UpDecoderBlock2D",), + layers_per_block=2, + norm_num_groups=32, + norm_type="group", + dropout=None, + dropout_prob=0.0, + norm_layers=(), + latent_in=(), + weight_norm=False, + xyz_in_all=True, + use_tanh=True, + ): + super().__init__() + + # 1. 特征解码器 (使用Decoder1D结构) + self.feature_decoder = Decoder1D( + in_channels=latent_size, + out_channels=feature_dims[-1], + up_block_types=up_block_types, + block_out_channels=feature_dims, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + norm_type=norm_type + ) + + # 2. SDF解码器 (使用原始Decoder结构) + self.sdf_decoder = SDFDecoder( + latent_size=feature_dims[-1], # 使用特征解码器的输出维度 + dims=sdf_dims, + dropout=dropout, + dropout_prob=dropout_prob, + norm_layers=norm_layers, + latent_in=latent_in, + weight_norm=weight_norm, + xyz_in_all=xyz_in_all, + use_tanh=use_tanh, + latent_dropout=False + ) + + # 3. 特征转换层 (将特征解码器的输出转换为SDF解码器需要的格式) + self.feature_transform = nn.Sequential( + nn.Linear(feature_dims[-1], feature_dims[-1]), + nn.LayerNorm(feature_dims[-1]), + nn.SiLU() + ) + + def forward(self, latent, query_points, latent_embeds=None): + """ + Args: + latent: [B, C, L] B-rep特征 + query_points: [B, N, 3] 查询点 + latent_embeds: 可选的条件嵌入 + Returns: + sdf: [B, N, 1] SDF值 + """ + # 1. 特征解码 + features = self.feature_decoder(latent, latent_embeds) # [B, C, L] + + # 2. 特征转换 + B, C, L = features.shape + features = features.permute(0, 2, 1) # [B, L, C] + features = self.feature_transform(features) # [B, L, C] + + # 3. 准备SDF解码器输入 + _, N, _ = query_points.shape + features = features.unsqueeze(1).expand(-1, N, -1, -1) # [B, N, L, C] + query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1) # [B, N, L, 3] + + # 4. 合并特征和坐标 + sdf_input = torch.cat([ + features.reshape(B*N*L, -1), # [B*N*L, C] + query_points.reshape(B*N*L, -1) # [B*N*L, 3] + ], dim=-1) + + # 5. SDF生成 + sdf = self.sdf_decoder(sdf_input) # [B*N*L, 1] + sdf = sdf.reshape(B, N, L, 1) # [B, N, L, 1] + + # 6. 聚合多尺度SDF + sdf = sdf.mean(dim=2) # [B, N, 1] + + return sdf + +# 使用示例 +if __name__ == "__main__": + # 创建模型 + decoder = BRepDecoder( + latent_size=256, + feature_dims=[512, 256, 128, 64], + sdf_dims=[512, 512, 512, 512], + up_block_types=("UpDecoderBlock2D",), + layers_per_block=2, + norm_num_groups=32, + dropout=None, + dropout_prob=0.0, + norm_layers=[0, 1, 2, 3], + latent_in=[4], + weight_norm=True, + xyz_in_all=True, + use_tanh=True + ) + + # 测试数据 + batch_size = 4 + seq_len = 32 + num_points = 1000 + + latent = torch.randn(batch_size, 256, seq_len) + query_points = torch.randn(batch_size, num_points, 3) + latent_embeds = torch.randn(batch_size, 256) + + # 前向传播 + sdf = decoder(latent, query_points, latent_embeds) + print(f"Input latent shape: {latent.shape}") + print(f"Query points shape: {query_points.shape}") + print(f"Output SDF shape: {sdf.shape}") diff --git a/networks/deep_sdf_decoder.py b/networks/deep_sdf_decoder.py deleted file mode 100644 index 71a0b02..0000000 --- a/networks/deep_sdf_decoder.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2004-present Facebook. All Rights Reserved. - -import torch.nn as nn -import torch -import torch.nn.functional as F - - -class Decoder(nn.Module): - def __init__( - self, - latent_size, - dims, - dropout=None, - dropout_prob=0.0, - norm_layers=(), - latent_in=(), - weight_norm=False, - xyz_in_all=None, - use_tanh=False, - latent_dropout=False, - ): - super(Decoder, self).__init__() - - def make_sequence(): - return [] - - dims = [latent_size + 3] + dims + [1] - - self.num_layers = len(dims) - self.norm_layers = norm_layers - self.latent_in = latent_in - self.latent_dropout = latent_dropout - if self.latent_dropout: - self.lat_dp = nn.Dropout(0.2) - - self.xyz_in_all = xyz_in_all - self.weight_norm = weight_norm - - for layer in range(0, self.num_layers - 1): - if layer + 1 in latent_in: - out_dim = dims[layer + 1] - dims[0] - else: - out_dim = dims[layer + 1] - if self.xyz_in_all and layer != self.num_layers - 2: - out_dim -= 3 - - if weight_norm and layer in self.norm_layers: - setattr( - self, - "lin" + str(layer), - nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)), - ) - else: - setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim)) - - if ( - (not weight_norm) - and self.norm_layers is not None - and layer in self.norm_layers - ): - setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim)) - - self.use_tanh = use_tanh - if use_tanh: - self.tanh = nn.Tanh() - self.relu = nn.ReLU() - - self.dropout_prob = dropout_prob - self.dropout = dropout - self.th = nn.Tanh() - - # input: N x (L+3) - def forward(self, input): - xyz = input[:, -3:] - - if input.shape[1] > 3 and self.latent_dropout: - latent_vecs = input[:, :-3] - latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training) - x = torch.cat([latent_vecs, xyz], 1) - else: - x = input - - for layer in range(0, self.num_layers - 1): - lin = getattr(self, "lin" + str(layer)) - if layer in self.latent_in: - x = torch.cat([x, input], 1) - elif layer != 0 and self.xyz_in_all: - x = torch.cat([x, xyz], 1) - x = lin(x) - # last layer Tanh - if layer == self.num_layers - 2 and self.use_tanh: - x = self.tanh(x) - if layer < self.num_layers - 2: - if ( - self.norm_layers is not None - and layer in self.norm_layers - and not self.weight_norm - ): - bn = getattr(self, "bn" + str(layer)) - x = bn(x) - x = self.relu(x) - if self.dropout is not None and layer in self.dropout: - x = F.dropout(x, p=self.dropout_prob, training=self.training) - - if hasattr(self, "th"): - x = self.th(x) - - return x diff --git a/networks/encoder.py b/networks/encoder.py new file mode 100644 index 0000000..6d85c8f --- /dev/null +++ b/networks/encoder.py @@ -0,0 +1,356 @@ +import math +import torch +import torch.nn as nn +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder +from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d +from diffusers.models.attention_processor import SpatialNorm + +''' +# NOTE: + 移除了分片(slicing)和平铺(tiling)功能 + 直接使用mode()而不是sample()获取潜在向量 + 简化了编码过程,只保留核心功能 + 返回确定性的潜在向量而不是分布 +''' + +# 1. 基础网络组件 +class Embedder(nn.Module): + def __init__(self, vocab_size, d_model): + super().__init__() + self.embed = nn.Embedding(vocab_size, d_model) + self._init_embeddings() + + def _init_embeddings(self): + nn.init.kaiming_normal_(self.embed.weight, mode="fan_in") + + def forward(self, x): + return self.embed(x) + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward(self, hidden_states, temb=None): + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + hidden_states = self.up(hidden_states) + return hidden_states + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states + + +class Encoder1D(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock1D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = torch.nn.Conv1d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock1D( + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x): + sample = x + sample = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample)[0] + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + +# 2. B-rep特征处理 +class BRepFeatureExtractor: + def __init__(self, config): + self.encoder = Encoder1D( + in_channels=config.in_channels, # 根据特征维度设置 + out_channels=config.out_channels, + block_out_channels=config.block_out_channels, + layers_per_block=config.layers_per_block + ) + + def extract_face_features(self, face): + """提取面的特征""" + features = [] + try: + # 基本几何特征 + center = face.center_point + normal = face.normal_vector + + # 边界特征 + bounds = face.bounds() + + # 曲面类型特征 + surface_type = face.surface_type + + # 组合特征 + feature = np.concatenate([ + center, # [3] + normal, # [3] + bounds.flatten(), + [surface_type] # 可以用one-hot编码 + ]) + features.append(feature) + + except Exception as e: + print(f"Error extracting face features: {e}") + + return np.array(features) + + def extract_edge_features(self, edge): + """提取边的特征""" + features = [] + try: + # 采样点 + points = self.sample_points_on_edge(edge) + + for point in points: + # 位置 + pos = point.coordinates + # 切向量 + tangent = point.tangent + # 曲率 + curvature = point.curvature + + point_feature = np.concatenate([ + pos, # [3] + tangent, # [3] + [curvature] # [1] + ]) + features.append(point_feature) + + except Exception as e: + print(f"Error extracting edge features: {e}") + + return np.array(features) + + @staticmethod + def sample_points_on_edge(edge, num_points=32): + """在边上均匀采样点""" + points = [] + try: + length = edge.length() + for i in range(num_points): + t = i / (num_points - 1) + point = edge.point_at(t * length) + points.append(point) + except Exception as e: + print(f"Error sampling points: {e}") + return points + +class BRepDataProcessor: + def __init__(self, feature_extractor): + self.feature_extractor = feature_extractor + + def process_brep(self, brep_model): + """处理单个B-rep模型""" + try: + # 1. 提取面特征 + face_features = [] + for face in brep_model.faces: + feat = self.feature_extractor.extract_face_features(face) + face_features.append(feat) + + # 2. 提取边特征 + edge_features = [] + for edge in brep_model.edges: + feat = self.feature_extractor.extract_edge_features(edge) + edge_features.append(feat) + + # 3. 组织数据结构 + return { + 'face_features': torch.tensor(face_features), + 'edge_features': torch.tensor(edge_features), + 'topology': self.extract_topology(brep_model) + } + + except Exception as e: + print(f"Error processing B-rep: {e}") + return None + + def extract_topology(self, brep_model): + """提取拓扑关系""" + # 面-边关系矩阵 + face_edge_adj = np.zeros((len(brep_model.faces), len(brep_model.edges))) + # 填充邻接关系 + for i, face in enumerate(brep_model.faces): + for j, edge in enumerate(brep_model.edges): + if edge in face.edges: + face_edge_adj[i,j] = 1 + return face_edge_adj + +# 3. 主编码器 +class BRepEncoder: + def __init__(self, config): + self.processor = BRepDataProcessor( + BRepFeatureExtractor(config) + ) + self.encoder = Encoder1D(**config.encoder_params) + + def encode(self, brep_model): + """编码B-rep模型""" + try: + # 1. 处理原始数据 + processed_data = self.processor.process_brep(brep_model) + if processed_data is None: + return None + + # 2. 特征编码 + face_features = self.encoder(processed_data['face_features']) + edge_features = self.encoder(processed_data['edge_features']) + + # 3. 组合特征 + combined_features = self.combine_features( + face_features, + edge_features, + processed_data['topology'] + ) + + return combined_features + + except Exception as e: + print(f"Error encoding B-rep: {e}") + return None + + def combine_features(self, face_features, edge_features, topology): + """组合不同类型的特征""" + # 可以使用图神经网络或者注意力机制来组合特征 + combined = torch.cat([ + face_features.mean(dim=1), # 全局面特征 + edge_features.mean(dim=1), # 全局边特征 + topology.flatten() # 拓扑信息 + ], dim=-1) + return combined \ No newline at end of file diff --git a/networks/network.py b/networks/network.py new file mode 100644 index 0000000..5c0147a --- /dev/null +++ b/networks/network.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +from encoder import BRepEncoder +from decoder import BRep2SdfDecoder + +class BRep2SDF(nn.Module): + def __init__( + self, + # 编码器参数 + in_channels=3, + latent_size=256, + encoder_block_out_channels=(512, 256, 128, 64), + # 解码器参数 + decoder_feature_dims=(512, 256, 128, 64), + sdf_dims=(512, 512, 512, 512), + # 共享参数 + layers_per_block=2, + norm_num_groups=32, + # SDF特定参数 + dropout=None, + dropout_prob=0.0, + norm_layers=(0, 1, 2, 3), + latent_in=(4,), + weight_norm=True, + xyz_in_all=True, + use_tanh=True, + ): + super().__init__() + + # 1. 编码器配置 + encoder_config = type('Config', (), { + 'in_channels': in_channels, + 'out_channels': latent_size, + 'block_out_channels': encoder_block_out_channels, + 'layers_per_block': layers_per_block, + 'norm_num_groups': norm_num_groups, + 'encoder_params': { + 'in_channels': in_channels, + 'out_channels': latent_size, + 'block_out_channels': encoder_block_out_channels, + 'layers_per_block': layers_per_block, + 'norm_num_groups': norm_num_groups, + } + })() + + # 2. 解码器配置 + decoder_config = { + 'latent_size': latent_size, + 'feature_dims': decoder_feature_dims, + 'sdf_dims': sdf_dims, + 'layers_per_block': layers_per_block, + 'norm_num_groups': norm_num_groups, + 'dropout': dropout, + 'dropout_prob': dropout_prob, + 'norm_layers': norm_layers, + 'latent_in': latent_in, + 'weight_norm': weight_norm, + 'xyz_in_all': xyz_in_all, + 'use_tanh': use_tanh, + } + + # 3. 创建编码器和解码器 + self.encoder = BRepEncoder(encoder_config) + self.decoder = BRep2SdfDecoder(**decoder_config) + + def encode(self, brep_model): + """编码B-rep模型为潜在特征""" + return self.encoder.encode(brep_model) + + def decode(self, latent, query_points, latent_embeds=None): + """从潜在特征解码SDF值""" + return self.decoder(latent, query_points, latent_embeds) + + def forward(self, brep_model, query_points): + """完整的前向传播过程""" + # 1. 编码B-rep模型 + latent = self.encode(brep_model) + if latent is None: + return None + + # 2. 解码SDF值 + sdf = self.decode(latent, query_points) + return sdf + +# 使用示例 +if __name__ == "__main__": + # 创建模型 + model = BRep2SDF( + in_channels=3, + latent_size=256, + encoder_block_out_channels=(512, 256, 128, 64), + decoder_feature_dims=(512, 256, 128, 64), + sdf_dims=(512, 512, 512, 512), + layers_per_block=2, + norm_num_groups=32, + ) + + # 测试数据 + batch_size = 4 + seq_len = 32 + num_points = 1000 + + # 模拟B-rep模型数据 + class MockBRep: + def __init__(self): + self.faces = [MockFace() for _ in range(10)] + self.edges = [MockEdge() for _ in range(20)] + + class MockFace: + def __init__(self): + self.center_point = torch.randn(3) + self.normal_vector = torch.randn(3) + self.surface_type = 0 + self.edges = [] + + class MockEdge: + def __init__(self): + self.length = lambda: 1.0 + self.point_at = lambda t: torch.randn(3) + + brep_model = MockBRep() + query_points = torch.randn(batch_size, num_points, 3) + + # 前向传播 + sdf = model(brep_model, query_points) + if sdf is not None: + print(f"Output SDF shape: {sdf.shape}") \ No newline at end of file