1 changed files with 313 additions and 311 deletions
@ -1,356 +1,358 @@ |
|||
import math |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
from dataclasses import dataclass |
|||
from typing import Dict, Optional, Tuple, Union |
|||
|
|||
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|||
from diffusers.utils import BaseOutput, is_torch_version |
|||
from diffusers.utils.accelerate_utils import apply_forward_hook |
|||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor |
|||
from diffusers.models.modeling_utils import ModelMixin |
|||
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder |
|||
from diffusers.models.unets.unet_1d_blocks import ResConvBlock, SelfAttention1d, get_down_block, get_up_block, Upsample1d |
|||
from diffusers.models.attention_processor import SpatialNorm |
|||
|
|||
''' |
|||
# NOTE: |
|||
移除了分片(slicing)和平铺(tiling)功能 |
|||
直接使用mode()而不是sample()获取潜在向量 |
|||
简化了编码过程,只保留核心功能 |
|||
返回确定性的潜在向量而不是分布 |
|||
''' |
|||
|
|||
# 1. 基础网络组件 |
|||
class Embedder(nn.Module): |
|||
def __init__(self, vocab_size, d_model): |
|||
class ResConvBlock(nn.Module): |
|||
"""残差卷积块""" |
|||
def __init__(self, in_channels: int, mid_channels: int, out_channels: int): |
|||
super().__init__() |
|||
self.embed = nn.Embedding(vocab_size, d_model) |
|||
self._init_embeddings() |
|||
|
|||
def _init_embeddings(self): |
|||
nn.init.kaiming_normal_(self.embed.weight, mode="fan_in") |
|||
self.conv1 = nn.Conv1d(in_channels, mid_channels, 3, padding=1) |
|||
self.norm1 = nn.GroupNorm(32, mid_channels) |
|||
self.conv2 = nn.Conv1d(mid_channels, out_channels, 3, padding=1) |
|||
self.norm2 = nn.GroupNorm(32, out_channels) |
|||
self.act = nn.SiLU() |
|||
|
|||
self.conv_shortcut = None |
|||
if in_channels != out_channels: |
|||
self.conv_shortcut = nn.Conv1d(in_channels, out_channels, 1) |
|||
|
|||
def forward(self, x): |
|||
return self.embed(x) |
|||
|
|||
residual = x |
|||
x = self.conv1(x) |
|||
x = self.norm1(x) |
|||
x = self.act(x) |
|||
x = self.conv2(x) |
|||
x = self.norm2(x) |
|||
|
|||
if self.conv_shortcut is not None: |
|||
residual = self.conv_shortcut(residual) |
|||
|
|||
return self.act(x + residual) |
|||
|
|||
class UpBlock1D(nn.Module): |
|||
def __init__(self, in_channels, out_channels, mid_channels=None): |
|||
class SelfAttention1d(nn.Module): |
|||
"""一维自注意力层""" |
|||
def __init__(self, channels: int, num_head_channels: int): |
|||
super().__init__() |
|||
mid_channels = in_channels if mid_channels is None else mid_channels |
|||
|
|||
resnets = [ |
|||
ResConvBlock(in_channels, mid_channels, mid_channels), |
|||
ResConvBlock(mid_channels, mid_channels, mid_channels), |
|||
ResConvBlock(mid_channels, mid_channels, out_channels), |
|||
] |
|||
|
|||
self.resnets = nn.ModuleList(resnets) |
|||
self.up = Upsample1d(kernel="cubic") |
|||
|
|||
def forward(self, hidden_states, temb=None): |
|||
for resnet in self.resnets: |
|||
hidden_states = resnet(hidden_states) |
|||
hidden_states = self.up(hidden_states) |
|||
return hidden_states |
|||
self.num_heads = channels // num_head_channels |
|||
self.scale = num_head_channels ** -0.5 |
|||
|
|||
self.qkv = nn.Conv1d(channels, channels * 3, 1) |
|||
self.proj = nn.Conv1d(channels, channels, 1) |
|||
|
|||
def forward(self, x): |
|||
b, c, l = x.shape |
|||
qkv = self.qkv(x).reshape(b, 3, self.num_heads, c // self.num_heads, l) |
|||
q, k, v = qkv.unbind(1) |
|||
|
|||
attn = (q @ k.transpose(-2, -1)) * self.scale |
|||
attn = attn.softmax(dim=-1) |
|||
|
|||
x = (attn @ v).reshape(b, c, l) |
|||
x = self.proj(x) |
|||
return x |
|||
|
|||
class UNetMidBlock1D(nn.Module): |
|||
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): |
|||
"""U-Net中间块""" |
|||
def __init__(self, in_channels: int, mid_channels: int): |
|||
super().__init__() |
|||
|
|||
out_channels = in_channels if out_channels is None else out_channels |
|||
|
|||
# there is always at least one resnet |
|||
resnets = [ |
|||
self.resnets = nn.ModuleList([ |
|||
ResConvBlock(in_channels, mid_channels, mid_channels), |
|||
ResConvBlock(mid_channels, mid_channels, mid_channels), |
|||
ResConvBlock(mid_channels, mid_channels, mid_channels), |
|||
ResConvBlock(mid_channels, mid_channels, mid_channels), |
|||
ResConvBlock(mid_channels, mid_channels, mid_channels), |
|||
ResConvBlock(mid_channels, mid_channels, out_channels), |
|||
] |
|||
attentions = [ |
|||
SelfAttention1d(mid_channels, mid_channels // 32), |
|||
SelfAttention1d(mid_channels, mid_channels // 32), |
|||
SelfAttention1d(mid_channels, mid_channels // 32), |
|||
SelfAttention1d(mid_channels, mid_channels // 32), |
|||
SelfAttention1d(mid_channels, mid_channels // 32), |
|||
SelfAttention1d(out_channels, out_channels // 32), |
|||
] |
|||
|
|||
self.attentions = nn.ModuleList(attentions) |
|||
self.resnets = nn.ModuleList(resnets) |
|||
ResConvBlock(mid_channels, mid_channels, in_channels), |
|||
]) |
|||
self.attentions = nn.ModuleList([ |
|||
SelfAttention1d(mid_channels, mid_channels // 32) |
|||
for _ in range(3) |
|||
]) |
|||
|
|||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: |
|||
def forward(self, x): |
|||
for attn, resnet in zip(self.attentions, self.resnets): |
|||
hidden_states = resnet(hidden_states) |
|||
hidden_states = attn(hidden_states) |
|||
|
|||
return hidden_states |
|||
|
|||
x = resnet(x) |
|||
x = attn(x) |
|||
return x |
|||
|
|||
class Encoder1D(nn.Module): |
|||
"""一维编码器""" |
|||
def __init__( |
|||
self, |
|||
in_channels=3, |
|||
out_channels=3, |
|||
down_block_types=("DownEncoderBlock1D",), |
|||
block_out_channels=(64,), |
|||
layers_per_block=2, |
|||
norm_num_groups=32, |
|||
act_fn="silu", |
|||
double_z=True, |
|||
in_channels: int = 3, |
|||
out_channels: int = 256, |
|||
block_out_channels: Tuple[int] = (64, 128, 256), |
|||
layers_per_block: int = 2, |
|||
): |
|||
super().__init__() |
|||
self.layers_per_block = layers_per_block |
|||
|
|||
self.conv_in = torch.nn.Conv1d( |
|||
in_channels, |
|||
block_out_channels[0], |
|||
kernel_size=3, |
|||
stride=1, |
|||
padding=1, |
|||
) |
|||
|
|||
self.mid_block = None |
|||
self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1) |
|||
|
|||
self.down_blocks = nn.ModuleList([]) |
|||
|
|||
# down |
|||
output_channel = block_out_channels[0] |
|||
for i, down_block_type in enumerate(down_block_types): |
|||
input_channel = output_channel |
|||
output_channel = block_out_channels[i] |
|||
is_final_block = i == len(block_out_channels) - 1 |
|||
|
|||
down_block = get_down_block( |
|||
down_block_type, |
|||
num_layers=self.layers_per_block, |
|||
in_channels=input_channel, |
|||
out_channels=output_channel, |
|||
add_downsample=not is_final_block, |
|||
temb_channels=None, |
|||
) |
|||
self.down_blocks.append(down_block) |
|||
|
|||
# mid |
|||
in_ch = block_out_channels[0] |
|||
for out_ch in block_out_channels: |
|||
block = [] |
|||
for _ in range(layers_per_block): |
|||
block.append(ResConvBlock(in_ch, out_ch, out_ch)) |
|||
in_ch = out_ch |
|||
if out_ch != block_out_channels[-1]: |
|||
block.append(nn.AvgPool1d(2)) |
|||
self.down_blocks.append(nn.Sequential(*block)) |
|||
|
|||
self.mid_block = UNetMidBlock1D( |
|||
in_channels=block_out_channels[-1], |
|||
mid_channels=block_out_channels[-1], |
|||
) |
|||
|
|||
# out |
|||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) |
|||
self.conv_act = nn.SiLU() |
|||
|
|||
conv_out_channels = 2 * out_channels if double_z else out_channels |
|||
self.conv_out = nn.Conv1d(block_out_channels[-1], conv_out_channels, 3, padding=1) |
|||
|
|||
self.gradient_checkpointing = False |
|||
self.conv_out = nn.Sequential( |
|||
nn.GroupNorm(32, block_out_channels[-1]), |
|||
nn.SiLU(), |
|||
nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1), |
|||
) |
|||
|
|||
def forward(self, x): |
|||
sample = x |
|||
sample = self.conv_in(sample) |
|||
x = self.conv_in(x) |
|||
for block in self.down_blocks: |
|||
x = block(x) |
|||
x = self.mid_block(x) |
|||
x = self.conv_out(x) |
|||
return x |
|||
|
|||
class BRepFeatureEmbedder(nn.Module): |
|||
"""B-rep特征嵌入器""" |
|||
def __init__(self, use_cf: bool = True): |
|||
super().__init__() |
|||
self.embed_dim = 768 |
|||
self.use_cf = use_cf |
|||
|
|||
if self.training and self.gradient_checkpointing: |
|||
|
|||
def create_custom_forward(module): |
|||
def custom_forward(*inputs): |
|||
return module(*inputs) |
|||
|
|||
return custom_forward |
|||
|
|||
# down |
|||
if is_torch_version(">=", "1.11.0"): |
|||
for down_block in self.down_blocks: |
|||
sample = torch.utils.checkpoint.checkpoint( |
|||
create_custom_forward(down_block), sample, use_reentrant=False |
|||
) |
|||
|
|||
# middle |
|||
sample = torch.utils.checkpoint.checkpoint( |
|||
create_custom_forward(self.mid_block), sample, use_reentrant=False |
|||
) |
|||
else: |
|||
for down_block in self.down_blocks: |
|||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) |
|||
# middle |
|||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) |
|||
layer = nn.TransformerEncoderLayer( |
|||
d_model=self.embed_dim, |
|||
nhead=12, |
|||
norm_first=False, |
|||
dim_feedforward=1024, |
|||
dropout=0.1 |
|||
) |
|||
self.transformer = nn.TransformerEncoder( |
|||
layer, |
|||
num_layers=12, |
|||
norm=nn.LayerNorm(self.embed_dim), |
|||
enable_nested_tensor=False # 添加这个参数 |
|||
) |
|||
|
|||
self.surfz_embed = nn.Sequential( |
|||
nn.Linear(3*16, self.embed_dim), |
|||
nn.LayerNorm(self.embed_dim), |
|||
nn.SiLU(), |
|||
nn.Linear(self.embed_dim, self.embed_dim), |
|||
) |
|||
|
|||
self.edgez_embed = nn.Sequential( |
|||
nn.Linear(3*4, self.embed_dim), |
|||
nn.LayerNorm(self.embed_dim), |
|||
nn.SiLU(), |
|||
nn.Linear(self.embed_dim, self.embed_dim), |
|||
) |
|||
|
|||
self.surfp_embed = nn.Sequential( |
|||
nn.Linear(6, self.embed_dim), |
|||
nn.LayerNorm(self.embed_dim), |
|||
nn.SiLU(), |
|||
nn.Linear(self.embed_dim, self.embed_dim), |
|||
) |
|||
|
|||
self.edgep_embed = nn.Sequential( |
|||
nn.Linear(6, self.embed_dim), |
|||
nn.LayerNorm(self.embed_dim), |
|||
nn.SiLU(), |
|||
nn.Linear(self.embed_dim, self.embed_dim), |
|||
) |
|||
|
|||
self.vertp_embed = nn.Sequential( |
|||
nn.Linear(6, self.embed_dim), |
|||
nn.LayerNorm(self.embed_dim), |
|||
nn.SiLU(), |
|||
nn.Linear(self.embed_dim, self.embed_dim), |
|||
) |
|||
|
|||
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, mask=None): |
|||
# 特征嵌入 |
|||
surf_embeds = self.surfz_embed(surf_z) |
|||
edge_embeds = self.edgez_embed(edge_z) |
|||
|
|||
# 点嵌入 |
|||
surf_p_embeds = self.surfp_embed(surf_p) |
|||
edge_p_embeds = self.edgep_embed(edge_p) |
|||
vert_p_embeds = self.vertp_embed(vert_p) |
|||
|
|||
# 组合所有嵌入 |
|||
if self.use_cf: |
|||
embeds = torch.cat([ |
|||
surf_embeds + surf_p_embeds, |
|||
edge_embeds + edge_p_embeds, |
|||
vert_p_embeds |
|||
], dim=1) |
|||
else: |
|||
# down |
|||
for down_block in self.down_blocks: |
|||
sample = down_block(sample)[0] |
|||
embeds = torch.cat([ |
|||
surf_p_embeds, |
|||
edge_p_embeds, |
|||
vert_p_embeds |
|||
], dim=1) |
|||
|
|||
# middle |
|||
sample = self.mid_block(sample) |
|||
|
|||
# post-process |
|||
sample = self.conv_norm_out(sample) |
|||
sample = self.conv_act(sample) |
|||
sample = self.conv_out(sample) |
|||
return sample |
|||
output = self.transformer(embeds, src_key_padding_mask=mask) |
|||
return output |
|||
|
|||
# 2. B-rep特征处理 |
|||
class BRepFeatureExtractor: |
|||
def __init__(self, config): |
|||
self.encoder = Encoder1D( |
|||
in_channels=config.in_channels, # 根据特征维度设置 |
|||
out_channels=config.out_channels, |
|||
block_out_channels=config.block_out_channels, |
|||
layers_per_block=config.layers_per_block |
|||
class SDFTransformer(nn.Module): |
|||
"""SDF Transformer编码器""" |
|||
def __init__(self, embed_dim: int = 768, num_layers: int = 6): |
|||
super().__init__() |
|||
layer = nn.TransformerEncoderLayer( |
|||
d_model=embed_dim, |
|||
nhead=8, |
|||
dim_feedforward=1024, |
|||
dropout=0.1, |
|||
batch_first=True, |
|||
norm_first=False # 修改这里:设置为False |
|||
) |
|||
|
|||
def extract_face_features(self, face): |
|||
"""提取面的特征""" |
|||
features = [] |
|||
try: |
|||
# 基本几何特征 |
|||
center = face.center_point |
|||
normal = face.normal_vector |
|||
|
|||
# 边界特征 |
|||
bounds = face.bounds() |
|||
|
|||
# 曲面类型特征 |
|||
surface_type = face.surface_type |
|||
|
|||
# 组合特征 |
|||
feature = np.concatenate([ |
|||
center, # [3] |
|||
normal, # [3] |
|||
bounds.flatten(), |
|||
[surface_type] # 可以用one-hot编码 |
|||
]) |
|||
features.append(feature) |
|||
|
|||
except Exception as e: |
|||
print(f"Error extracting face features: {e}") |
|||
|
|||
return np.array(features) |
|||
self.transformer = nn.TransformerEncoder(layer, num_layers) |
|||
|
|||
def extract_edge_features(self, edge): |
|||
"""提取边的特征""" |
|||
features = [] |
|||
try: |
|||
# 采样点 |
|||
points = self.sample_points_on_edge(edge) |
|||
|
|||
for point in points: |
|||
# 位置 |
|||
pos = point.coordinates |
|||
# 切向量 |
|||
tangent = point.tangent |
|||
# 曲率 |
|||
curvature = point.curvature |
|||
|
|||
point_feature = np.concatenate([ |
|||
pos, # [3] |
|||
tangent, # [3] |
|||
[curvature] # [1] |
|||
]) |
|||
features.append(point_feature) |
|||
|
|||
except Exception as e: |
|||
print(f"Error extracting edge features: {e}") |
|||
|
|||
return np.array(features) |
|||
def forward(self, x, mask=None): |
|||
return self.transformer(x, src_key_padding_mask=mask) |
|||
|
|||
@staticmethod |
|||
def sample_points_on_edge(edge, num_points=32): |
|||
"""在边上均匀采样点""" |
|||
points = [] |
|||
try: |
|||
length = edge.length() |
|||
for i in range(num_points): |
|||
t = i / (num_points - 1) |
|||
point = edge.point_at(t * length) |
|||
points.append(point) |
|||
except Exception as e: |
|||
print(f"Error sampling points: {e}") |
|||
return points |
|||
class SDFHead(nn.Module): |
|||
"""SDF预测头""" |
|||
def __init__(self, embed_dim: int = 768*2): |
|||
super().__init__() |
|||
self.mlp = nn.Sequential( |
|||
nn.Linear(embed_dim, embed_dim//2), |
|||
nn.LayerNorm(embed_dim//2), |
|||
nn.ReLU(), |
|||
nn.Linear(embed_dim//2, embed_dim//4), |
|||
nn.LayerNorm(embed_dim//4), |
|||
nn.ReLU(), |
|||
nn.Linear(embed_dim//4, 1), |
|||
nn.Tanh() |
|||
) |
|||
|
|||
class BRepDataProcessor: |
|||
def __init__(self, feature_extractor): |
|||
self.feature_extractor = feature_extractor |
|||
|
|||
def process_brep(self, brep_model): |
|||
"""处理单个B-rep模型""" |
|||
try: |
|||
# 1. 提取面特征 |
|||
face_features = [] |
|||
for face in brep_model.faces: |
|||
feat = self.feature_extractor.extract_face_features(face) |
|||
face_features.append(feat) |
|||
|
|||
# 2. 提取边特征 |
|||
edge_features = [] |
|||
for edge in brep_model.edges: |
|||
feat = self.feature_extractor.extract_edge_features(edge) |
|||
edge_features.append(feat) |
|||
|
|||
# 3. 组织数据结构 |
|||
return { |
|||
'face_features': torch.tensor(face_features), |
|||
'edge_features': torch.tensor(edge_features), |
|||
'topology': self.extract_topology(brep_model) |
|||
} |
|||
|
|||
except Exception as e: |
|||
print(f"Error processing B-rep: {e}") |
|||
return None |
|||
|
|||
def extract_topology(self, brep_model): |
|||
"""提取拓扑关系""" |
|||
# 面-边关系矩阵 |
|||
face_edge_adj = np.zeros((len(brep_model.faces), len(brep_model.edges))) |
|||
# 填充邻接关系 |
|||
for i, face in enumerate(brep_model.faces): |
|||
for j, edge in enumerate(brep_model.edges): |
|||
if edge in face.edges: |
|||
face_edge_adj[i,j] = 1 |
|||
return face_edge_adj |
|||
def forward(self, x): |
|||
return self.mlp(x) |
|||
|
|||
# 3. 主编码器 |
|||
class BRepEncoder: |
|||
def __init__(self, config): |
|||
self.processor = BRepDataProcessor( |
|||
BRepFeatureExtractor(config) |
|||
class BRepToSDF(nn.Module): |
|||
def __init__( |
|||
self, |
|||
brep_feature_dim: int = 48, |
|||
use_cf: bool = True, |
|||
embed_dim: int = 768, |
|||
latent_dim: int = 256 |
|||
): |
|||
super().__init__() |
|||
self.embed_dim = embed_dim |
|||
|
|||
# 1. 查询点编码器 |
|||
self.query_encoder = nn.Sequential( |
|||
nn.Linear(3, embed_dim//4), |
|||
nn.LayerNorm(embed_dim//4), |
|||
nn.ReLU(), |
|||
nn.Linear(embed_dim//4, embed_dim//2), |
|||
nn.LayerNorm(embed_dim//2), |
|||
nn.ReLU(), |
|||
nn.Linear(embed_dim//2, embed_dim) |
|||
) |
|||
self.encoder = Encoder1D(**config.encoder_params) |
|||
|
|||
def encode(self, brep_model): |
|||
"""编码B-rep模型""" |
|||
try: |
|||
# 1. 处理原始数据 |
|||
processed_data = self.processor.process_brep(brep_model) |
|||
if processed_data is None: |
|||
return None |
|||
|
|||
# 2. 特征编码 |
|||
face_features = self.encoder(processed_data['face_features']) |
|||
edge_features = self.encoder(processed_data['edge_features']) |
|||
|
|||
# 3. 组合特征 |
|||
combined_features = self.combine_features( |
|||
face_features, |
|||
edge_features, |
|||
processed_data['topology'] |
|||
) |
|||
|
|||
return combined_features |
|||
|
|||
except Exception as e: |
|||
print(f"Error encoding B-rep: {e}") |
|||
return None |
|||
|
|||
def combine_features(self, face_features, edge_features, topology): |
|||
"""组合不同类型的特征""" |
|||
# 可以使用图神经网络或者注意力机制来组合特征 |
|||
combined = torch.cat([ |
|||
face_features.mean(dim=1), # 全局面特征 |
|||
edge_features.mean(dim=1), # 全局边特征 |
|||
topology.flatten() # 拓扑信息 |
|||
], dim=-1) |
|||
return combined |
|||
# 2. B-rep特征编码器 |
|||
self.feature_embedder = BRepFeatureEmbedder(use_cf=use_cf) |
|||
|
|||
# 3. 特征融合Transformer |
|||
self.transformer = SDFTransformer( |
|||
embed_dim=embed_dim, |
|||
num_layers=6 |
|||
) |
|||
|
|||
# 4. SDF预测头 |
|||
self.sdf_head = SDFHead(embed_dim=embed_dim*2) |
|||
|
|||
def forward(self, surf_z, edge_z, surf_p, edge_p, vert_p, query_points, mask=None): |
|||
""" |
|||
Args: |
|||
surf_z: 表面特征 [B, N, 48] |
|||
edge_z: 边特征 [B, M, 12] |
|||
surf_p: 表面点 [B, N, 6] |
|||
edge_p: 边点 [B, M, 6] |
|||
vert_p: 顶点点 [B, K, 6] |
|||
query_points: 查询点 [B, Q, 3] |
|||
mask: 注意力掩码 |
|||
Returns: |
|||
sdf: [B, Q, 1] |
|||
""" |
|||
B, Q, _ = query_points.shape |
|||
|
|||
# 1. B-rep特征嵌入 |
|||
brep_features = self.feature_embedder( |
|||
surf_z, edge_z, surf_p, edge_p, vert_p, mask |
|||
) # [B, N+M+K, embed_dim] |
|||
|
|||
# 2. 查询点编码 |
|||
query_features = self.query_encoder(query_points) # [B, Q, embed_dim] |
|||
|
|||
# 3. 提取全局特征 |
|||
global_features = brep_features.mean(dim=1) # [B, embed_dim] |
|||
|
|||
# 4. 为每个查询点准备特征 |
|||
expanded_features = global_features.unsqueeze(1).expand(-1, Q, -1) # [B, Q, embed_dim] |
|||
|
|||
# 5. 连接查询点特征和全局特征 |
|||
combined_features = torch.cat([ |
|||
expanded_features, # [B, Q, embed_dim] |
|||
query_features # [B, Q, embed_dim] |
|||
], dim=-1) # [B, Q, embed_dim*2] |
|||
|
|||
# 6. SDF预测 |
|||
sdf = self.sdf_head(combined_features) # [B, Q, 1] |
|||
|
|||
return sdf |
|||
|
|||
def sdf_loss(pred_sdf, gt_sdf, points, grad_weight: float = 0.1): |
|||
"""SDF损失函数""" |
|||
# L1损失 |
|||
l1_loss = F.l1_loss(pred_sdf, gt_sdf) |
|||
|
|||
# 梯度约束损失 |
|||
grad = torch.autograd.grad( |
|||
pred_sdf.sum(), |
|||
points, |
|||
create_graph=True |
|||
)[0] |
|||
grad_constraint = F.mse_loss( |
|||
torch.norm(grad, dim=-1), |
|||
torch.ones_like(pred_sdf.squeeze(-1)) |
|||
) |
|||
|
|||
return l1_loss + grad_weight * grad_constraint |
|||
|
|||
def main(): |
|||
# 初始化模型 |
|||
model = BRepToSDF( |
|||
brep_feature_dim=48, |
|||
use_cf=True, |
|||
embed_dim=768, |
|||
latent_dim=256 |
|||
) |
|||
|
|||
# 示例输入 |
|||
batch_size = 4 |
|||
num_surfs = 10 |
|||
num_edges = 20 |
|||
num_verts = 8 |
|||
num_queries = 1000 |
|||
|
|||
# 生成示例数据 |
|||
surf_z = torch.randn(batch_size, num_surfs, 48) |
|||
edge_z = torch.randn(batch_size, num_edges, 12) |
|||
surf_p = torch.randn(batch_size, num_surfs, 6) |
|||
edge_p = torch.randn(batch_size, num_edges, 6) |
|||
vert_p = torch.randn(batch_size, num_verts, 6) |
|||
query_points = torch.randn(batch_size, num_queries, 3) |
|||
|
|||
# 前向传播 |
|||
sdf = model(surf_z, edge_z, surf_p, edge_p, vert_p, query_points) |
|||
print(f"Output SDF shape: {sdf.shape}") |
|||
|
|||
if __name__ == "__main__": |
|||
main() |
Loading…
Reference in new issue