Browse Source

fix: can run, but still no grad

main
mckay 6 months ago
parent
commit
fa81eefa54
  1. 249
      brep2sdf/networks/encoder.py

249
brep2sdf/networks/encoder.py

@ -76,7 +76,7 @@ class UNetMidBlock1D(nn.Module):
x = attn(x)
return x
class Encoder1D(nn.Module):
class _Encoder1D(nn.Module):
def __init__(
self,
in_channels: int = 3,
@ -114,47 +114,40 @@ class Encoder1D(nn.Module):
x = self.mid_block(x)
x = self.conv_out(x)
return x
class Encoder1D_(nn.Module):
"""一维编码器"""
class Encoder1D(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 256,
block_out_channels: Tuple[int] = (64, 128, 256),
layers_per_block: int = 2,
num_points: int = 2, # 输入点数
in_channels: int = 3, # xyz坐标
out_dim: int = 6, # 展平后维度 (2*3)
):
super().__init__()
self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)
self.down_blocks = nn.ModuleList([])
in_ch = block_out_channels[0]
for out_ch in block_out_channels:
block = []
for _ in range(layers_per_block):
block.append(ResConvBlock(in_ch, out_ch, out_ch))
in_ch = out_ch
if out_ch != block_out_channels[-1]:
block.append(nn.AvgPool1d(2))
self.down_blocks.append(nn.Sequential(*block))
self.mid_block = UNetMidBlock1D(
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
)
self.num_points = num_points
self.in_channels = in_channels
self.out_dim = out_dim
self.conv_out = nn.Sequential(
nn.GroupNorm(32, block_out_channels[-1]),
nn.SiLU(),
nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1),
)
assert num_points * in_channels == out_dim, \
f"Flatten dimension {out_dim} must equal num_points({num_points}) * in_channels({in_channels})"
def forward(self, x):
x = self.conv_in(x)
for block in self.down_blocks:
x = block(x)
x = self.mid_block(x)
x = self.conv_out(x)
"""
Args:
x: [B, F, E, num_points, channels]
Returns:
x: [B, F, E, out_dim]
"""
orig_shape = x.shape[:-2] # [B, F, E]
# 重塑以处理所有批次
x = x.reshape(-1, self.num_points, self.in_channels) # [B*F*E, num_points, channels]
# 展平点和通道维度
x = x.reshape(x.size(0), -1) # [B*F*E, num_points*channels]
# 恢复批次维度
x = x.reshape(*orig_shape, self.out_dim) # [B, F, E, out_dim]
return x
class BRepFeatureEmbedder(nn.Module):
@ -177,62 +170,37 @@ class BRepFeatureEmbedder(nn.Module):
logger.info(f" num_edge_points: {self.num_edge_points}")
logger.info(f" embed_dim: {self.embed_dim}")
logger.info(f" use_cf: {self.use_cf}")
# Transformer编码器层
layer = nn.TransformerEncoderLayer(
d_model=self.embed_dim,
nhead=8, # 从12减少到8,使每个head的维度更大
norm_first=True, # 改为True,先进行归一化
dim_feedforward=self.embed_dim * 4, # 增大FFN维度
dropout=0.1,
activation=F.gelu # 使用GELU激活函数
)
# 添加初始化方法
def _init_weights(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight, gain=1/math.sqrt(2))
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.ones_(module.weight)
torch.nn.init.zeros_(module.bias)
self.transformer = nn.TransformerEncoder(
layer,
num_layers=6, # 从12减少到6层
norm=nn.LayerNorm(self.embed_dim),
enable_nested_tensor=False
)
# 应用初始化
self.transformer.apply(_init_weights)
# 添加位置编码
self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02) # 最大序列长度设为1000
surf_out_dim = 3*self.num_surf_points # 面特征展平后维度,8*3=24
edge_out_dim = 3*self.num_edge_points # 边特征展平后维度,2*3=6
self.surfz_encoder = Encoder1D(num_points=self.num_surf_points,in_channels=3,out_dim=surf_out_dim)
self.edgez_encoder = Encoder1D(num_points=self.num_edge_points,in_channels=3,out_dim=edge_out_dim)
# 修改为处理[num_points, 3]形状的输入
self.surfz_embed = Encoder1D(
in_channels=3,
out_channels=self.embed_dim,
block_out_channels=(64, 128, self.embed_dim),
layers_per_block=2
)
self.edgez_embed = Encoder1D(
in_channels=3,
out_channels=self.embed_dim,
block_out_channels=(64, 128, self.embed_dim),
layers_per_block=2
self.surfz_embed = nn.Sequential(
self.surfz_encoder, # [B, 16, 3] -> [B, 48]
nn.Linear(surf_out_dim, self.embed_dim), # [B, 48] -> [B, embed_dim]
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
)
# 其他嵌入层保持不变
self.surfp_embed = nn.Sequential(
nn.Linear(6, self.embed_dim),
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
)
self.edgez_embed = nn.Sequential(
self.edgez_encoder,
nn.Linear(edge_out_dim, self.embed_dim),
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
)
self.edgep_embed = nn.Sequential(
nn.Linear(6, self.embed_dim),
@ -243,86 +211,87 @@ class BRepFeatureEmbedder(nn.Module):
# 修改vertp_embed的结构
self.vertp_embed = nn.Sequential(
nn.Linear(3, self.embed_dim // 2),
nn.LayerNorm(self.embed_dim // 2),
nn.ReLU(),
nn.Linear(self.embed_dim // 2, self.embed_dim)
nn.Linear(3, self.embed_dim),
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim)
)
# 添加一个额外的投影层
self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)
# 添加 transformer 初始化
self.transformer = nn.TransformerEncoder(
encoder_layer=nn.TransformerEncoderLayer(
d_model=self.embed_dim,
nhead=8, # 注意力头数,通常是embed_dim的因子
dim_feedforward=4*self.embed_dim, # 前馈网络维度,通常是embed_dim的4倍
dropout=0.1,
activation='gelu',
batch_first=False # 因为我们用了transpose(0,1)
),
num_layers=6 # transformer层数
)
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
B, F, E, _, _ = edge_ncs.shape
# 处理顶点特征
vertex_features = vertex_pos.view(B*F*E*2, -1) # 展平顶点坐标
vertex_embed = self.vertp_embed(vertex_features)
vertex_embed = self.vertex_proj(vertex_embed) # 添加投影
vertex_embed = vertex_embed.view(B, F, E, 2, -1) # 恢复形状
# 确保顶点特征参与后续计算
edge_features = torch.cat([
self.edgep_embed(edge_pos.view(B*F*E, -1)).view(B, F, E, -1),
vertex_embed.mean(dim=3) # 将顶点特征平均池化
], dim=-1)
# 1. 处理边特征
# 重塑边点云以适应1D编码器
edge_ncs = edge_ncs.reshape(B*F*E, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points]
edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points]
edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim]
edge_embeds = edge_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim]
# 2. 处理面特征
surf_ncs = surf_ncs.reshape(B*F, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points]
surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points]
surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim]
surf_embeds = surf_embeds.reshape(B, F, -1) # [B, max_face, embed_dim]
# 3. 处理位置编码
# 边位置编码
edge_pos = edge_pos.reshape(B*F*E, -1) # [B*max_face*max_edge, 6]
edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim]
edge_p_embeds = edge_p_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim]
# 面位置编码
surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim]
"""
Args:
edge_ncs: [B, F, E, num_edge_points, 3] - 边点云
edge_pos: [B, F, E, 6] - 边位置编码
edge_mask: [B, F, E] - 边掩码
surf_ncs: [B, F, num_surf_points, 3] - 面点云
surf_pos: [B, F, 6] - 面位置编码
vertex_pos: [B, F, E, 2, 3] - 顶点位置
"""
B, F, E = edge_pos.shape[:3]
# 1. 处理顶点特征
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim]
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim]
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim]
# 2. 处理边特征
logger.info(f"edge_ncs shape: {edge_ncs.shape}")
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim]
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim]
# 3. 处理面特征
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim]
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim]
# 4. 组合特征
if self.use_cf:
# 边特征
edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim]
edge_features = edge_features.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim]
# 组合边特征
edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim]
edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim]
# 面特征
surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim]
# 组合面特征
surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim]
# 组合所有特征
# 拼接所有特征
embeds = torch.cat([
edge_features, # [B, max_face*max_edge, embed_dim]
surf_features # [B, max_face, embed_dim]
], dim=1) # [B, max_face*(max_edge+1), embed_dim]
edge_features, # [B, F*E, embed_dim]
surf_features # [B, F, embed_dim]
], dim=1) # [B, F*E+F, embed_dim]
else:
# 只使用位置编码
edge_features = edge_p_embeds.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim]
edge_features = edge_p_embeds.reshape(B, F*E, -1) # [B, F*E, embed_dim]
embeds = torch.cat([
edge_features, # [B, max_face*max_edge, embed_dim]
surf_p_embeds # [B, max_face, embed_dim]
], dim=1) # [B, max_face*(max_edge+1), embed_dim]
edge_features, # [B, F*E, embed_dim]
surf_p_embeds # [B, F, embed_dim]
], dim=1) # [B, F*E+F, embed_dim]
# 5. 处理掩码
if edge_mask is not None:
# 扩展掩码以匹配特征维度
edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge]
surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool) # [B, max_face]
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)]
edge_mask = edge_mask.reshape(B, -1) # [B, F*E]
surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool) # [B, F]
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F]
else:
mask = None
# 6. Transformer处理
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim]
return output.transpose(0, 1) # [B, F*E+F, embed_dim]
class SDFTransformer(nn.Module):
"""SDF Transformer编码器"""

Loading…
Cancel
Save