3 changed files with 217 additions and 483 deletions
			
			
		@ -1,383 +1,42 @@ | 
				
			|||
import math | 
				
			|||
import torch | 
				
			|||
import torch.nn as nn | 
				
			|||
from dataclasses import dataclass | 
				
			|||
from typing import Dict, Optional, Tuple, Union | 
				
			|||
 | 
				
			|||
from diffusers.configuration_utils import ConfigMixin, register_to_config | 
				
			|||
from diffusers.utils import BaseOutput, is_torch_version | 
				
			|||
from diffusers.utils.accelerate_utils import apply_forward_hook | 
				
			|||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor | 
				
			|||
from diffusers.models.modeling_utils import ModelMixin | 
				
			|||
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder | 
				
			|||
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d | 
				
			|||
from diffusers.models.attention_processor import SpatialNorm | 
				
			|||
 | 
				
			|||
 | 
				
			|||
import torch.nn.functional as F | 
				
			|||
 | 
				
			|||
from typing import Dict, Optional, Tuple, Union | 
				
			|||
from brep2sdf.config.default_config import get_default_config | 
				
			|||
from brep2sdf.utils.logger import logger | 
				
			|||
 | 
				
			|||
class Decoder1D(nn.Module): | 
				
			|||
    def __init__( | 
				
			|||
        self, | 
				
			|||
        in_channels=3, | 
				
			|||
        out_channels=3, | 
				
			|||
        up_block_types=("UpDecoderBlock2D",), | 
				
			|||
        block_out_channels=(64,), | 
				
			|||
        layers_per_block=2, | 
				
			|||
        norm_num_groups=32, | 
				
			|||
        act_fn="silu", | 
				
			|||
        norm_type="group",  # group, spatial | 
				
			|||
    ): | 
				
			|||
    ''' | 
				
			|||
    这是第一阶段的解码器,用于处理B-rep特征 | 
				
			|||
    包含三个主要部分: | 
				
			|||
    conv_in: 输入卷积层,处理初始特征 | 
				
			|||
    mid_block: 中间处理块 | 
				
			|||
    up_blocks: 上采样块序列 | 
				
			|||
    支持梯度检查点功能(gradient checkpointing)以节省内存 | 
				
			|||
    输出维度: [B, C, L] | 
				
			|||
    # NOTE:  | 
				
			|||
    1. 移除了分片(slicing)和平铺(tiling)功能 | 
				
			|||
    2. 直接使用mode()而不是sample()获取潜在向量 | 
				
			|||
    3. 简化了编码过程,只保留核心功能 | 
				
			|||
    4. 返回确定性的潜在向量而不是分布 | 
				
			|||
    ''' | 
				
			|||
class SDFHead(nn.Module): | 
				
			|||
    """SDF预测头""" | 
				
			|||
    def __init__(self, embed_dim: int = 768*2): | 
				
			|||
        super().__init__() | 
				
			|||
        self.layers_per_block = layers_per_block | 
				
			|||
 | 
				
			|||
        self.conv_in = nn.Conv1d( | 
				
			|||
            in_channels, | 
				
			|||
            block_out_channels[-1], | 
				
			|||
            kernel_size=3, | 
				
			|||
            stride=1, | 
				
			|||
            padding=1, | 
				
			|||
        self.mlp = nn.Sequential( | 
				
			|||
            nn.Linear(embed_dim, embed_dim//2), | 
				
			|||
            nn.LayerNorm(embed_dim//2), | 
				
			|||
            nn.ReLU(), | 
				
			|||
            nn.Linear(embed_dim//2, embed_dim//4), | 
				
			|||
            nn.LayerNorm(embed_dim//4), | 
				
			|||
            nn.ReLU(), | 
				
			|||
            nn.Linear(embed_dim//4, 1), | 
				
			|||
            nn.Tanh() | 
				
			|||
        ) | 
				
			|||
 | 
				
			|||
        self.mid_block = None | 
				
			|||
        self.up_blocks = nn.ModuleList([]) | 
				
			|||
 | 
				
			|||
        temb_channels = in_channels if norm_type == "spatial" else None | 
				
			|||
 | 
				
			|||
        # mid | 
				
			|||
        self.mid_block = UNetMidBlock1D( | 
				
			|||
            in_channels=block_out_channels[-1], | 
				
			|||
            mid_channels=block_out_channels[-1], | 
				
			|||
        ) | 
				
			|||
 | 
				
			|||
        # up | 
				
			|||
        reversed_block_out_channels = list(reversed(block_out_channels)) | 
				
			|||
        output_channel = reversed_block_out_channels[0] | 
				
			|||
        for i, up_block_type in enumerate(up_block_types): | 
				
			|||
            prev_output_channel = output_channel | 
				
			|||
            output_channel = reversed_block_out_channels[i] | 
				
			|||
 | 
				
			|||
            is_final_block = i == len(block_out_channels) - 1 | 
				
			|||
             | 
				
			|||
            up_block = UpBlock1D( | 
				
			|||
                in_channels=prev_output_channel, | 
				
			|||
                out_channels=output_channel, | 
				
			|||
            ) | 
				
			|||
            self.up_blocks.append(up_block) | 
				
			|||
            prev_output_channel = output_channel | 
				
			|||
 | 
				
			|||
        # out | 
				
			|||
        if norm_type == "spatial": | 
				
			|||
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) | 
				
			|||
        else: | 
				
			|||
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) | 
				
			|||
        self.conv_act = nn.SiLU() | 
				
			|||
        self.conv_out = nn.Conv1d(block_out_channels[0], out_channels, 3, padding=1) | 
				
			|||
 | 
				
			|||
        self.gradient_checkpointing = False | 
				
			|||
 | 
				
			|||
 | 
				
			|||
    def forward(self, z, latent_embeds=None): | 
				
			|||
        sample = z | 
				
			|||
        sample = self.conv_in(sample) | 
				
			|||
 | 
				
			|||
        if self.training and self.gradient_checkpointing: | 
				
			|||
 | 
				
			|||
            def create_custom_forward(module): | 
				
			|||
                def custom_forward(*inputs): | 
				
			|||
                    return module(*inputs) | 
				
			|||
 | 
				
			|||
                return custom_forward | 
				
			|||
 | 
				
			|||
            if is_torch_version(">=", "1.11.0"): | 
				
			|||
                # middle | 
				
			|||
                sample = torch.utils.checkpoint.checkpoint( | 
				
			|||
                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False | 
				
			|||
                ) | 
				
			|||
                # sample = sample.to(upscale_dtype) | 
				
			|||
 | 
				
			|||
                # up | 
				
			|||
                for up_block in self.up_blocks: | 
				
			|||
                    sample = torch.utils.checkpoint.checkpoint( | 
				
			|||
                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False | 
				
			|||
                    ) | 
				
			|||
            else: | 
				
			|||
                # middle | 
				
			|||
                sample = torch.utils.checkpoint.checkpoint( | 
				
			|||
                    create_custom_forward(self.mid_block), sample, latent_embeds | 
				
			|||
                ) | 
				
			|||
                # sample = sample.to(upscale_dtype) | 
				
			|||
 | 
				
			|||
                # up | 
				
			|||
                for up_block in self.up_blocks: | 
				
			|||
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) | 
				
			|||
        else: | 
				
			|||
            # middle | 
				
			|||
            sample = self.mid_block(sample, latent_embeds) | 
				
			|||
            # sample = sample.to(upscale_dtype) | 
				
			|||
            # up | 
				
			|||
            for up_block in self.up_blocks: | 
				
			|||
                sample = up_block(sample, latent_embeds) | 
				
			|||
         | 
				
			|||
        # post-process | 
				
			|||
        if latent_embeds is None: | 
				
			|||
            sample = self.conv_norm_out(sample) | 
				
			|||
        else: | 
				
			|||
            sample = self.conv_norm_out(sample, latent_embeds) | 
				
			|||
        sample = self.conv_act(sample) | 
				
			|||
        sample = self.conv_out(sample) | 
				
			|||
 | 
				
			|||
        return sample | 
				
			|||
 | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class SDFDecoder(nn.Module): | 
				
			|||
    def __init__( | 
				
			|||
        self, | 
				
			|||
        latent_size, | 
				
			|||
        dims, | 
				
			|||
        dropout=None, | 
				
			|||
        dropout_prob=0.0, | 
				
			|||
        norm_layers=(), | 
				
			|||
        latent_in=(), | 
				
			|||
        weight_norm=False, | 
				
			|||
        xyz_in_all=None, | 
				
			|||
        use_tanh=False, | 
				
			|||
        latent_dropout=False, | 
				
			|||
    ):   | 
				
			|||
        ''' | 
				
			|||
        这是第二阶段的解码器,用于生成SDF值 | 
				
			|||
        使用多层MLP结构 | 
				
			|||
        特点: | 
				
			|||
        支持在不同层注入latent信息(通过latent_in参数) | 
				
			|||
        可以在每层添加xyz坐标(通过xyz_in_all参数) | 
				
			|||
        支持权重归一化和dropout | 
				
			|||
        输入维度: [N, latent_size+3] | 
				
			|||
        输出维度: [N, 1] | 
				
			|||
         | 
				
			|||
        ''' | 
				
			|||
        super(SDFDecoder, self).__init__()  | 
				
			|||
 | 
				
			|||
        def make_sequence(): | 
				
			|||
            return [] | 
				
			|||
    def forward(self, x): | 
				
			|||
        return self.mlp(x) | 
				
			|||
         | 
				
			|||
        dims = [latent_size + 3] + dims + [1] | 
				
			|||
 | 
				
			|||
        self.num_layers = len(dims) | 
				
			|||
        self.norm_layers = norm_layers | 
				
			|||
        self.latent_in = latent_in | 
				
			|||
        self.latent_dropout = latent_dropout | 
				
			|||
        if self.latent_dropout: | 
				
			|||
            self.lat_dp = nn.Dropout(0.2) | 
				
			|||
 | 
				
			|||
        self.xyz_in_all = xyz_in_all | 
				
			|||
        self.weight_norm = weight_norm | 
				
			|||
 | 
				
			|||
        for layer in range(0, self.num_layers - 1): | 
				
			|||
            if layer + 1 in latent_in: | 
				
			|||
                out_dim = dims[layer + 1] - dims[0] | 
				
			|||
            else: | 
				
			|||
                out_dim = dims[layer + 1] | 
				
			|||
                if self.xyz_in_all and layer != self.num_layers - 2: | 
				
			|||
                    out_dim -= 3 | 
				
			|||
 | 
				
			|||
            if weight_norm and layer in self.norm_layers: | 
				
			|||
                setattr( | 
				
			|||
                    self, | 
				
			|||
                    "lin" + str(layer), | 
				
			|||
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)), | 
				
			|||
                ) | 
				
			|||
            else: | 
				
			|||
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim)) | 
				
			|||
 | 
				
			|||
            if ( | 
				
			|||
                (not weight_norm) | 
				
			|||
                and self.norm_layers is not None | 
				
			|||
                and layer in self.norm_layers | 
				
			|||
            ): | 
				
			|||
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim)) | 
				
			|||
 | 
				
			|||
        self.use_tanh = use_tanh | 
				
			|||
        if use_tanh: | 
				
			|||
            self.tanh = nn.Tanh() | 
				
			|||
        self.relu = nn.ReLU() | 
				
			|||
 | 
				
			|||
        self.dropout_prob = dropout_prob | 
				
			|||
        self.dropout = dropout | 
				
			|||
        self.th = nn.Tanh() | 
				
			|||
 | 
				
			|||
    # input: N x (L+3) | 
				
			|||
    def forward(self, input): | 
				
			|||
        xyz = input[:, -3:] | 
				
			|||
 | 
				
			|||
        if input.shape[1] > 3 and self.latent_dropout: | 
				
			|||
            latent_vecs = input[:, :-3] | 
				
			|||
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training) | 
				
			|||
            x = torch.cat([latent_vecs, xyz], 1) | 
				
			|||
        else: | 
				
			|||
            x = input | 
				
			|||
 | 
				
			|||
        for layer in range(0, self.num_layers - 1): | 
				
			|||
            lin = getattr(self, "lin" + str(layer)) | 
				
			|||
            if layer in self.latent_in: | 
				
			|||
                x = torch.cat([x, input], 1) | 
				
			|||
            elif layer != 0 and self.xyz_in_all: | 
				
			|||
                x = torch.cat([x, xyz], 1) | 
				
			|||
            x = lin(x) | 
				
			|||
            # last layer Tanh | 
				
			|||
            if layer == self.num_layers - 2 and self.use_tanh: | 
				
			|||
                x = self.tanh(x) | 
				
			|||
            if layer < self.num_layers - 2: | 
				
			|||
                if ( | 
				
			|||
                    self.norm_layers is not None | 
				
			|||
                    and layer in self.norm_layers | 
				
			|||
                    and not self.weight_norm | 
				
			|||
                ): | 
				
			|||
                    bn = getattr(self, "bn" + str(layer)) | 
				
			|||
                    x = bn(x) | 
				
			|||
                x = self.relu(x) | 
				
			|||
                if self.dropout is not None and layer in self.dropout: | 
				
			|||
                    x = F.dropout(x, p=self.dropout_prob, training=self.training) | 
				
			|||
 | 
				
			|||
        if hasattr(self, "th"): | 
				
			|||
            x = self.th(x) | 
				
			|||
 | 
				
			|||
        return x | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class BRep2SdfDecoder(nn.Module): | 
				
			|||
    def __init__( | 
				
			|||
        self, | 
				
			|||
        latent_size=256, | 
				
			|||
        feature_dims=[512, 512, 256, 128],  # 特征解码器维度 | 
				
			|||
        sdf_dims=[512, 512, 512, 512],      # SDF解码器维度 | 
				
			|||
        up_block_types=("UpDecoderBlock2D",), | 
				
			|||
        layers_per_block=2, | 
				
			|||
        norm_num_groups=32, | 
				
			|||
        norm_type="group", | 
				
			|||
        dropout=None, | 
				
			|||
        dropout_prob=0.0, | 
				
			|||
        norm_layers=(), | 
				
			|||
        latent_in=(), | 
				
			|||
        weight_norm=False, | 
				
			|||
        xyz_in_all=True, | 
				
			|||
        use_tanh=True, | 
				
			|||
    ): | 
				
			|||
class SDFTransformer(nn.Module): | 
				
			|||
    """SDF Transformer编码器""" | 
				
			|||
    def __init__(self, embed_dim: int = 768, num_layers: int = 6): | 
				
			|||
        super().__init__() | 
				
			|||
         | 
				
			|||
        # 1. 特征解码器 (使用Decoder1D结构) | 
				
			|||
        self.feature_decoder = Decoder1D( | 
				
			|||
            in_channels=latent_size, | 
				
			|||
            out_channels=feature_dims[-1], | 
				
			|||
            up_block_types=up_block_types, | 
				
			|||
            block_out_channels=feature_dims, | 
				
			|||
            layers_per_block=layers_per_block, | 
				
			|||
            norm_num_groups=norm_num_groups, | 
				
			|||
            norm_type=norm_type | 
				
			|||
        layer = nn.TransformerEncoderLayer( | 
				
			|||
            d_model=embed_dim, | 
				
			|||
            nhead=8, | 
				
			|||
            dim_feedforward=1024, | 
				
			|||
            dropout=0.1, | 
				
			|||
            batch_first=True, | 
				
			|||
            norm_first=False  # 修改这里:设置为False | 
				
			|||
        ) | 
				
			|||
        self.transformer = nn.TransformerEncoder(layer, num_layers) | 
				
			|||
 | 
				
			|||
        # 2. SDF解码器 (使用原始Decoder结构) | 
				
			|||
        self.sdf_decoder = SDFDecoder( | 
				
			|||
            latent_size=feature_dims[-1],  # 使用特征解码器的输出维度 | 
				
			|||
            dims=sdf_dims, | 
				
			|||
            dropout=dropout, | 
				
			|||
            dropout_prob=dropout_prob, | 
				
			|||
            norm_layers=norm_layers, | 
				
			|||
            latent_in=latent_in, | 
				
			|||
            weight_norm=weight_norm, | 
				
			|||
            xyz_in_all=xyz_in_all, | 
				
			|||
            use_tanh=use_tanh, | 
				
			|||
            latent_dropout=False | 
				
			|||
        ) | 
				
			|||
         | 
				
			|||
        # 3. 特征转换层 (将特征解码器的输出转换为SDF解码器需要的格式) | 
				
			|||
        self.feature_transform = nn.Sequential( | 
				
			|||
            nn.Linear(feature_dims[-1], feature_dims[-1]), | 
				
			|||
            nn.LayerNorm(feature_dims[-1]), | 
				
			|||
            nn.SiLU() | 
				
			|||
        ) | 
				
			|||
         | 
				
			|||
    def forward(self, latent, query_points, latent_embeds=None): | 
				
			|||
        """ | 
				
			|||
        Args: | 
				
			|||
            latent: [B, C, L] B-rep特征 | 
				
			|||
            query_points: [B, N, 3] 查询点 | 
				
			|||
            latent_embeds: 可选的条件嵌入 | 
				
			|||
        Returns: | 
				
			|||
            sdf: [B, N, 1] SDF值 | 
				
			|||
        """ | 
				
			|||
        # 1. 特征解码 | 
				
			|||
        features = self.feature_decoder(latent, latent_embeds)  # [B, C, L] | 
				
			|||
         | 
				
			|||
        # 2. 特征转换 | 
				
			|||
        B, C, L = features.shape | 
				
			|||
        features = features.permute(0, 2, 1)  # [B, L, C] | 
				
			|||
        features = self.feature_transform(features)  # [B, L, C] | 
				
			|||
         | 
				
			|||
        # 3. 准备SDF解码器输入 | 
				
			|||
        _, N, _ = query_points.shape | 
				
			|||
        features = features.unsqueeze(1).expand(-1, N, -1, -1)  # [B, N, L, C] | 
				
			|||
        query_points = query_points.unsqueeze(2).expand(-1, -1, L, -1)  # [B, N, L, 3] | 
				
			|||
         | 
				
			|||
        # 4. 合并特征和坐标 | 
				
			|||
        sdf_input = torch.cat([ | 
				
			|||
            features.reshape(B*N*L, -1),  # [B*N*L, C] | 
				
			|||
            query_points.reshape(B*N*L, -1)  # [B*N*L, 3] | 
				
			|||
        ], dim=-1) | 
				
			|||
         | 
				
			|||
        # 5. SDF生成 | 
				
			|||
        sdf = self.sdf_decoder(sdf_input)  # [B*N*L, 1] | 
				
			|||
        sdf = sdf.reshape(B, N, L, 1)  # [B, N, L, 1] | 
				
			|||
         | 
				
			|||
        # 6. 聚合多尺度SDF | 
				
			|||
        sdf = sdf.mean(dim=2)  # [B, N, 1] | 
				
			|||
         | 
				
			|||
        return sdf | 
				
			|||
 | 
				
			|||
# 使用示例 | 
				
			|||
if __name__ == "__main__": | 
				
			|||
    # 创建模型 | 
				
			|||
    decoder = BRepDecoder( | 
				
			|||
        latent_size=256, | 
				
			|||
        feature_dims=[512, 256, 128, 64], | 
				
			|||
        sdf_dims=[512, 512, 512, 512], | 
				
			|||
        up_block_types=("UpDecoderBlock2D",), | 
				
			|||
        layers_per_block=2, | 
				
			|||
        norm_num_groups=32, | 
				
			|||
        dropout=None, | 
				
			|||
        dropout_prob=0.0, | 
				
			|||
        norm_layers=[0, 1, 2, 3], | 
				
			|||
        latent_in=[4], | 
				
			|||
        weight_norm=True, | 
				
			|||
        xyz_in_all=True, | 
				
			|||
        use_tanh=True | 
				
			|||
    ) | 
				
			|||
     | 
				
			|||
    # 测试数据 | 
				
			|||
    batch_size = 4 | 
				
			|||
    seq_len = 32 | 
				
			|||
    num_points = 1000 | 
				
			|||
     | 
				
			|||
    latent = torch.randn(batch_size, 256, seq_len) | 
				
			|||
    query_points = torch.randn(batch_size, num_points, 3) | 
				
			|||
    latent_embeds = torch.randn(batch_size, 256) | 
				
			|||
     | 
				
			|||
    # 前向传播 | 
				
			|||
    sdf = decoder(latent, query_points, latent_embeds) | 
				
			|||
    print(f"Input latent shape: {latent.shape}") | 
				
			|||
    print(f"Query points shape: {query_points.shape}") | 
				
			|||
    print(f"Output SDF shape: {sdf.shape}") | 
				
			|||
    def forward(self, x, mask=None): | 
				
			|||
        return self.transformer(x, src_key_padding_mask=mask) | 
				
			|||
 | 
				
			|||
@ -1,127 +1,202 @@ | 
				
			|||
import torch | 
				
			|||
import torch.nn as nn | 
				
			|||
from encoder import BRepEncoder | 
				
			|||
from decoder import BRep2SdfDecoder | 
				
			|||
 | 
				
			|||
class BRep2SDF(nn.Module): | 
				
			|||
    def __init__( | 
				
			|||
        self, | 
				
			|||
        # 编码器参数 | 
				
			|||
        in_channels=3, | 
				
			|||
        latent_size=256, | 
				
			|||
        encoder_block_out_channels=(512, 256, 128, 64), | 
				
			|||
        # 解码器参数 | 
				
			|||
        decoder_feature_dims=(512, 256, 128, 64), | 
				
			|||
        sdf_dims=(512, 512, 512, 512), | 
				
			|||
        # 共享参数 | 
				
			|||
        layers_per_block=2, | 
				
			|||
        norm_num_groups=32, | 
				
			|||
        # SDF特定参数 | 
				
			|||
        dropout=None, | 
				
			|||
        dropout_prob=0.0, | 
				
			|||
        norm_layers=(0, 1, 2, 3), | 
				
			|||
        latent_in=(4,), | 
				
			|||
        weight_norm=True, | 
				
			|||
        xyz_in_all=True, | 
				
			|||
        use_tanh=True, | 
				
			|||
    ): | 
				
			|||
import torch.nn.functional as F | 
				
			|||
from typing import Dict, Optional, Tuple, Union | 
				
			|||
from brep2sdf.config.default_config import get_default_config | 
				
			|||
from brep2sdf.utils.logger import logger | 
				
			|||
 | 
				
			|||
from brep2sdf.networks.encoder import BRepFeatureEmbedder | 
				
			|||
from brep2sdf.networks.decoder import SDFHead, SDFTransformer | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class BRepToSDF(nn.Module): | 
				
			|||
    def __init__(self, config=None): | 
				
			|||
        super().__init__() | 
				
			|||
        # 获取配置 | 
				
			|||
        if config is None: | 
				
			|||
            self.config = get_default_config() | 
				
			|||
        else: | 
				
			|||
            self.config = config | 
				
			|||
             | 
				
			|||
        # 1. 编码器配置 | 
				
			|||
        encoder_config = type('Config', (), { | 
				
			|||
            'in_channels': in_channels, | 
				
			|||
            'out_channels': latent_size, | 
				
			|||
            'block_out_channels': encoder_block_out_channels, | 
				
			|||
            'layers_per_block': layers_per_block, | 
				
			|||
            'norm_num_groups': norm_num_groups, | 
				
			|||
            'encoder_params': { | 
				
			|||
                'in_channels': in_channels, | 
				
			|||
                'out_channels': latent_size, | 
				
			|||
                'block_out_channels': encoder_block_out_channels, | 
				
			|||
                'layers_per_block': layers_per_block, | 
				
			|||
                'norm_num_groups': norm_num_groups, | 
				
			|||
            } | 
				
			|||
        })() | 
				
			|||
         | 
				
			|||
        # 2. 解码器配置 | 
				
			|||
        decoder_config = { | 
				
			|||
            'latent_size': latent_size, | 
				
			|||
            'feature_dims': decoder_feature_dims, | 
				
			|||
            'sdf_dims': sdf_dims, | 
				
			|||
            'layers_per_block': layers_per_block, | 
				
			|||
            'norm_num_groups': norm_num_groups, | 
				
			|||
            'dropout': dropout, | 
				
			|||
            'dropout_prob': dropout_prob, | 
				
			|||
            'norm_layers': norm_layers, | 
				
			|||
            'latent_in': latent_in, | 
				
			|||
            'weight_norm': weight_norm, | 
				
			|||
            'xyz_in_all': xyz_in_all, | 
				
			|||
            'use_tanh': use_tanh, | 
				
			|||
        } | 
				
			|||
         | 
				
			|||
        # 3. 创建编码器和解码器 | 
				
			|||
        self.encoder = BRepEncoder(encoder_config) | 
				
			|||
        self.decoder = BRep2SdfDecoder(**decoder_config) | 
				
			|||
         | 
				
			|||
    def encode(self, brep_model): | 
				
			|||
        """编码B-rep模型为潜在特征""" | 
				
			|||
        return self.encoder.encode(brep_model) | 
				
			|||
         | 
				
			|||
    def decode(self, latent, query_points, latent_embeds=None): | 
				
			|||
        """从潜在特征解码SDF值""" | 
				
			|||
        return self.decoder(latent, query_points, latent_embeds) | 
				
			|||
     | 
				
			|||
    def forward(self, brep_model, query_points): | 
				
			|||
        """完整的前向传播过程""" | 
				
			|||
        # 1. 编码B-rep模型 | 
				
			|||
        latent = self.encode(brep_model) | 
				
			|||
        if latent is None: | 
				
			|||
            return None | 
				
			|||
             | 
				
			|||
        # 2. 解码SDF值 | 
				
			|||
        sdf = self.decode(latent, query_points) | 
				
			|||
        return sdf | 
				
			|||
 | 
				
			|||
# 使用示例 | 
				
			|||
if __name__ == "__main__": | 
				
			|||
    # 创建模型 | 
				
			|||
    model = BRep2SDF( | 
				
			|||
        in_channels=3, | 
				
			|||
        latent_size=256, | 
				
			|||
        encoder_block_out_channels=(512, 256, 128, 64), | 
				
			|||
        decoder_feature_dims=(512, 256, 128, 64), | 
				
			|||
        sdf_dims=(512, 512, 512, 512), | 
				
			|||
        layers_per_block=2, | 
				
			|||
        norm_num_groups=32, | 
				
			|||
    ) | 
				
			|||
     | 
				
			|||
    # 测试数据 | 
				
			|||
    batch_size = 4 | 
				
			|||
    seq_len = 32 | 
				
			|||
    num_points = 1000 | 
				
			|||
     | 
				
			|||
    # 模拟B-rep模型数据 | 
				
			|||
    class MockBRep: | 
				
			|||
        def __init__(self): | 
				
			|||
            self.faces = [MockFace() for _ in range(10)] | 
				
			|||
            self.edges = [MockEdge() for _ in range(20)] | 
				
			|||
             | 
				
			|||
    class MockFace: | 
				
			|||
        def __init__(self): | 
				
			|||
            self.center_point = torch.randn(3) | 
				
			|||
            self.normal_vector = torch.randn(3) | 
				
			|||
            self.surface_type = 0 | 
				
			|||
            self.edges = [] | 
				
			|||
             | 
				
			|||
    class MockEdge: | 
				
			|||
        def __init__(self): | 
				
			|||
            self.length = lambda: 1.0 | 
				
			|||
            self.point_at = lambda t: torch.randn(3) | 
				
			|||
     | 
				
			|||
    brep_model = MockBRep() | 
				
			|||
    query_points = torch.randn(batch_size, num_points, 3) | 
				
			|||
        # 从配置中读取参数 | 
				
			|||
        self.embed_dim = self.config.model.embed_dim | 
				
			|||
        self.brep_feature_dim = self.config.model.brep_feature_dim | 
				
			|||
        self.latent_dim = self.config.model.latent_dim | 
				
			|||
        self.use_cf = self.config.model.use_cf | 
				
			|||
         | 
				
			|||
        # 1. 查询点编码器 | 
				
			|||
        self.query_encoder = nn.Sequential( | 
				
			|||
            nn.Linear(3, self.embed_dim//4), | 
				
			|||
            nn.LayerNorm(self.embed_dim//4), | 
				
			|||
            nn.ReLU(), | 
				
			|||
            nn.Linear(self.embed_dim//4, self.embed_dim//2), | 
				
			|||
            nn.LayerNorm(self.embed_dim//2), | 
				
			|||
            nn.ReLU(), | 
				
			|||
            nn.Linear(self.embed_dim//2, self.embed_dim) | 
				
			|||
        ) | 
				
			|||
         | 
				
			|||
        # 2. B-rep特征编码器 | 
				
			|||
        self.brep_embedder = BRepFeatureEmbedder() | 
				
			|||
         | 
				
			|||
        # 3. 特征融合Transformer | 
				
			|||
        self.transformer = SDFTransformer( | 
				
			|||
            embed_dim=self.embed_dim, | 
				
			|||
            num_layers=6  # 这个参数也可以移到配置文件中 | 
				
			|||
        ) | 
				
			|||
         | 
				
			|||
        # 4. SDF预测头 | 
				
			|||
        self.sdf_head = SDFHead(embed_dim=self.embed_dim*2) | 
				
			|||
 | 
				
			|||
    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, query_points, data_class=None): | 
				
			|||
        """B-rep到SDF的前向传播 | 
				
			|||
         | 
				
			|||
        Args: | 
				
			|||
            edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3] | 
				
			|||
            edge_pos: 边位置 [B, max_face, max_edge, 6] | 
				
			|||
            edge_mask: 边掩码 [B, max_face, max_edge] | 
				
			|||
            surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3] | 
				
			|||
            surf_pos: 面位置 [B, max_face, 6] | 
				
			|||
            vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3] | 
				
			|||
            query_points: 查询点 [B, num_queries, 3] | 
				
			|||
            data_class: (可选) 类别标签 | 
				
			|||
         | 
				
			|||
        Returns: | 
				
			|||
            sdf: 预测的SDF值 [B, num_queries, 1] | 
				
			|||
        """ | 
				
			|||
        B, Q = query_points.shape[:2]  # B: batch_size, Q: num_queries | 
				
			|||
         | 
				
			|||
        try: | 
				
			|||
             # 确保query_points需要梯度 | 
				
			|||
            if not query_points.requires_grad: | 
				
			|||
                query_points = query_points.detach().requires_grad_(True) | 
				
			|||
                 | 
				
			|||
        | 
				
			|||
            # 1. B-rep特征编码 | 
				
			|||
            brep_features = self.brep_embedder( | 
				
			|||
                edge_ncs=edge_ncs,         # [B, max_face, max_edge, num_edge_points, 3] | 
				
			|||
                edge_pos=edge_pos,         # [B, max_face, max_edge, 6] | 
				
			|||
                edge_mask=edge_mask,       # [B, max_face, max_edge] | 
				
			|||
                surf_ncs=surf_ncs,         # [B, max_face, num_surf_points, 3] | 
				
			|||
                surf_pos=surf_pos,         # [B, max_face, 6] | 
				
			|||
                vertex_pos=vertex_pos,     # [B, max_face, max_edge, 2, 3] | 
				
			|||
                data_class=data_class | 
				
			|||
            )  # [B, max_face*(max_edge+1), embed_dim] | 
				
			|||
             | 
				
			|||
            # 2. 查询点编码 | 
				
			|||
            query_features = self.query_encoder(query_points)  # [B, Q, embed_dim] | 
				
			|||
             | 
				
			|||
            # 3. 提取全局特征 | 
				
			|||
            global_features = brep_features.mean(dim=1)  # [B, embed_dim] | 
				
			|||
             | 
				
			|||
            # 4. 为每个查询点准备特征 | 
				
			|||
            expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1)  # [B, Q, embed_dim] | 
				
			|||
             | 
				
			|||
            # 5. 连接查询点特征和全局特征 | 
				
			|||
            combined_features = torch.cat([ | 
				
			|||
                expanded_features,  # [B, Q, embed_dim] | 
				
			|||
                query_features     # [B, Q, embed_dim] | 
				
			|||
            ], dim=-1)  # [B, Q, embed_dim*2] | 
				
			|||
             | 
				
			|||
            # 6. SDF预测 | 
				
			|||
            sdf = self.sdf_head(combined_features)  # [B, Q, 1] | 
				
			|||
 | 
				
			|||
            if not sdf.requires_grad: | 
				
			|||
                logger.warning("SDF output does not require grad!") | 
				
			|||
           | 
				
			|||
             | 
				
			|||
            return sdf | 
				
			|||
             | 
				
			|||
        except Exception as e: | 
				
			|||
            logger.error(f"Error in BRepToSDF forward pass:") | 
				
			|||
            logger.error(f"  Error message: {str(e)}") | 
				
			|||
            logger.error(f"  Input shapes:") | 
				
			|||
            logger.error(f"    edge_ncs: {edge_ncs.shape}") | 
				
			|||
            logger.error(f"    edge_pos: {edge_pos.shape}") | 
				
			|||
            logger.error(f"    edge_mask: {edge_mask.shape}") | 
				
			|||
            logger.error(f"    surf_ncs: {surf_ncs.shape}") | 
				
			|||
            logger.error(f"    surf_pos: {surf_pos.shape}") | 
				
			|||
            logger.error(f"    vertex_pos: {vertex_pos.shape}") | 
				
			|||
            logger.error(f"    query_points: {query_points.shape}") | 
				
			|||
            raise | 
				
			|||
 | 
				
			|||
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): | 
				
			|||
    """SDF损失函数""" | 
				
			|||
    # 确保points需要梯度 | 
				
			|||
    if not points.requires_grad: | 
				
			|||
        points = points.detach().requires_grad_(True) | 
				
			|||
     | 
				
			|||
    # L1损失 | 
				
			|||
    l1_loss = F.l1_loss(pred_sdf, gt_sdf) | 
				
			|||
     | 
				
			|||
    try: | 
				
			|||
        # 梯度约束损失 | 
				
			|||
        grad = torch.autograd.grad( | 
				
			|||
            pred_sdf.sum(),  | 
				
			|||
            points, | 
				
			|||
            create_graph=True, | 
				
			|||
            retain_graph=True, | 
				
			|||
            allow_unused=True | 
				
			|||
        )[0] | 
				
			|||
         | 
				
			|||
        if grad is not None: | 
				
			|||
            grad_constraint = F.mse_loss( | 
				
			|||
                torch.norm(grad, dim=-1), | 
				
			|||
                torch.ones_like(pred_sdf.squeeze(-1)) | 
				
			|||
            ) | 
				
			|||
        else: | 
				
			|||
            grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			|||
             | 
				
			|||
    except Exception as e: | 
				
			|||
        logger.warning(f"Gradient computation failed: {str(e)}") | 
				
			|||
        grad_constraint = torch.tensor(0.0, device=pred_sdf.device) | 
				
			|||
     | 
				
			|||
    return l1_loss + grad_weight * grad_constraint | 
				
			|||
 | 
				
			|||
def main(): | 
				
			|||
    # 获取配置 | 
				
			|||
    config = get_default_config() | 
				
			|||
     | 
				
			|||
    # 初始化模型 | 
				
			|||
    model = BRepToSDF(config=config) | 
				
			|||
     | 
				
			|||
    # 从配置获取参数 | 
				
			|||
    batch_size = config.train.batch_size | 
				
			|||
    max_face = config.data.max_face | 
				
			|||
    max_edge = config.data.max_edge | 
				
			|||
    num_surf_points = config.model.num_surf_points | 
				
			|||
    num_edge_points = config.model.num_edge_points | 
				
			|||
     | 
				
			|||
    # 生成测试数据 | 
				
			|||
    test_data = { | 
				
			|||
        'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3), | 
				
			|||
        'edge_pos': torch.randn(batch_size, max_face, max_edge, 6), | 
				
			|||
        'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool), | 
				
			|||
        'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3), | 
				
			|||
        'surf_pos': torch.randn(batch_size, max_face, 6), | 
				
			|||
        'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3), | 
				
			|||
        'query_points': torch.randn(batch_size, 1000, 3)  # 1000个查询点 | 
				
			|||
    } | 
				
			|||
     | 
				
			|||
    # 打印输入数据形状 | 
				
			|||
    logger.info("Input shapes:") | 
				
			|||
    for name, tensor in test_data.items(): | 
				
			|||
        logger.info(f"  {name}: {tensor.shape}") | 
				
			|||
     | 
				
			|||
    # 前向传播 | 
				
			|||
    sdf = model(brep_model, query_points) | 
				
			|||
    if sdf is not None: | 
				
			|||
        print(f"Output SDF shape: {sdf.shape}") | 
				
			|||
    try: | 
				
			|||
        sdf = model(**test_data) | 
				
			|||
        logger.info(f"\nOutput SDF shape: {sdf.shape}") | 
				
			|||
         | 
				
			|||
        # 计算模型参数量 | 
				
			|||
        total_params = sum(p.numel() for p in model.parameters()) | 
				
			|||
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | 
				
			|||
        logger.info(f"\nModel statistics:") | 
				
			|||
        logger.info(f"  Total parameters: {total_params:,}") | 
				
			|||
        logger.info(f"  Trainable parameters: {trainable_params:,}") | 
				
			|||
         | 
				
			|||
    except Exception as e: | 
				
			|||
        logger.error(f"Error during forward pass: {str(e)}") | 
				
			|||
        raise | 
				
			|||
 | 
				
			|||
if __name__ == "__main__": | 
				
			|||
    main() | 
				
			|||
					Loading…
					
					
				
		Reference in new issue