
fix: can run, but still no grad

main
mckay 6 months ago
parent commit fa81eefa54
249  brep2sdf/networks/encoder.py

@@ -76,7 +76,7 @@ class UNetMidBlock1D(nn.Module):
             x = attn(x)
         return x
 
-class Encoder1D(nn.Module):
+class _Encoder1D(nn.Module):
     def __init__(
         self,
         in_channels: int = 3,
@@ -114,47 +114,40 @@ class Encoder1D(nn.Module):
         x = self.mid_block(x)
         x = self.conv_out(x)
         return x
 
-class Encoder1D_(nn.Module):
+class Encoder1D(nn.Module):
+    """1D encoder"""
     def __init__(
         self,
-        in_channels: int = 3,
-        out_channels: int = 256,
-        block_out_channels: Tuple[int] = (64, 128, 256),
-        layers_per_block: int = 2,
+        num_points: int = 2,   # number of input points
+        in_channels: int = 3,  # xyz coordinates
+        out_dim: int = 6,      # flattened dimension (2*3)
     ):
         super().__init__()
-        self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 3, padding=1)
-        self.down_blocks = nn.ModuleList([])
-        in_ch = block_out_channels[0]
-        for out_ch in block_out_channels:
-            block = []
-            for _ in range(layers_per_block):
-                block.append(ResConvBlock(in_ch, out_ch, out_ch))
-                in_ch = out_ch
-            if out_ch != block_out_channels[-1]:
-                block.append(nn.AvgPool1d(2))
-            self.down_blocks.append(nn.Sequential(*block))
-        self.mid_block = UNetMidBlock1D(
-            in_channels=block_out_channels[-1],
-            mid_channels=block_out_channels[-1],
-        )
-        self.conv_out = nn.Sequential(
-            nn.GroupNorm(32, block_out_channels[-1]),
-            nn.SiLU(),
-            nn.Conv1d(block_out_channels[-1], out_channels, 3, padding=1),
-        )
+        self.num_points = num_points
+        self.in_channels = in_channels
+        self.out_dim = out_dim
+        assert num_points * in_channels == out_dim, \
+            f"Flatten dimension {out_dim} must equal num_points({num_points}) * in_channels({in_channels})"
 
     def forward(self, x):
-        x = self.conv_in(x)
-        for block in self.down_blocks:
-            x = block(x)
-        x = self.mid_block(x)
-        x = self.conv_out(x)
+        """
+        Args:
+            x: [B, F, E, num_points, channels]
+        Returns:
+            x: [B, F, E, out_dim]
+        """
+        orig_shape = x.shape[:-2]  # [B, F, E]
+        # reshape to handle all batch dimensions at once
+        x = x.reshape(-1, self.num_points, self.in_channels)  # [B*F*E, num_points, channels]
+        # flatten the point and channel dimensions
+        x = x.reshape(x.size(0), -1)  # [B*F*E, num_points*channels]
+        # restore the batch dimensions
+        x = x.reshape(*orig_shape, self.out_dim)  # [B, F, E, out_dim]
         return x
 
 class BRepFeatureEmbedder(nn.Module):
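A minimal shape check for the rewritten Encoder1D above; the tensor sizes are illustrative, not taken from the repo, and the sketch assumes the class from this hunk is in scope:

    import torch

    enc = Encoder1D(num_points=2, in_channels=3, out_dim=6)  # out_dim must equal num_points * in_channels
    x = torch.randn(4, 10, 20, 2, 3)                         # [B, F, E, num_points, channels]
    y = enc(x)                                               # pure reshape/flatten, nothing learned here
    assert y.shape == (4, 10, 20, 6)                         # [B, F, E, out_dim]
    assert len(list(enc.parameters())) == 0                  # the module itself contributes no weights
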
@@ -177,62 +170,37 @@ class BRepFeatureEmbedder(nn.Module):
         logger.info(f" num_edge_points: {self.num_edge_points}")
         logger.info(f" embed_dim: {self.embed_dim}")
         logger.info(f" use_cf: {self.use_cf}")
 
-        # transformer encoder layer
-        layer = nn.TransformerEncoderLayer(
-            d_model=self.embed_dim,
-            nhead=8,  # reduced from 12 to 8 so each head has a larger dimension
-            norm_first=True,  # changed to True: normalize first
-            dim_feedforward=self.embed_dim * 4,  # larger FFN dimension
-            dropout=0.1,
-            activation=F.gelu  # use GELU activation
-        )
-
-        # weight initialization helper
-        def _init_weights(module):
-            if isinstance(module, nn.Linear):
-                torch.nn.init.xavier_uniform_(module.weight, gain=1/math.sqrt(2))
-                if module.bias is not None:
-                    torch.nn.init.zeros_(module.bias)
-            elif isinstance(module, nn.LayerNorm):
-                torch.nn.init.ones_(module.weight)
-                torch.nn.init.zeros_(module.bias)
-
-        self.transformer = nn.TransformerEncoder(
-            layer,
-            num_layers=6,  # reduced from 12 to 6 layers
-            norm=nn.LayerNorm(self.embed_dim),
-            enable_nested_tensor=False
-        )
-        # apply the initialization
-        self.transformer.apply(_init_weights)
-
-        # add positional encoding
-        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02)  # max sequence length set to 1000
+        surf_out_dim = 3 * self.num_surf_points  # flattened surface feature dimension, 8*3=24
+        edge_out_dim = 3 * self.num_edge_points  # flattened edge feature dimension, 2*3=6
+        self.surfz_encoder = Encoder1D(num_points=self.num_surf_points, in_channels=3, out_dim=surf_out_dim)
+        self.edgez_encoder = Encoder1D(num_points=self.num_edge_points, in_channels=3, out_dim=edge_out_dim)
 
         # changed to handle inputs of shape [num_points, 3]
-        self.surfz_embed = Encoder1D(
-            in_channels=3,
-            out_channels=self.embed_dim,
-            block_out_channels=(64, 128, self.embed_dim),
-            layers_per_block=2
-        )
-        self.edgez_embed = Encoder1D(
-            in_channels=3,
-            out_channels=self.embed_dim,
-            block_out_channels=(64, 128, self.embed_dim),
-            layers_per_block=2
-        )
-
-        # the other embedding layers stay unchanged
+        self.surfz_embed = nn.Sequential(
+            self.surfz_encoder,                        # [B, 16, 3] -> [B, 48]
+            nn.Linear(surf_out_dim, self.embed_dim),   # [B, 48] -> [B, embed_dim]
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim),
+        )
+
         self.surfp_embed = nn.Sequential(
             nn.Linear(6, self.embed_dim),
             nn.LayerNorm(self.embed_dim),
             nn.SiLU(),
             nn.Linear(self.embed_dim, self.embed_dim),
         )
 
+        self.edgez_embed = nn.Sequential(
+            self.edgez_encoder,
+            nn.Linear(edge_out_dim, self.embed_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim),
+        )
+
         self.edgep_embed = nn.Sequential(
             nn.Linear(6, self.embed_dim),
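A hedged sketch of how the new surfz_embed composes the flatten-only Encoder1D with an MLP; num_surf_points and embed_dim are placeholder values, not the repo's config, and Encoder1D is assumed to be the class defined earlier in this diff:

    import torch
    import torch.nn as nn

    num_surf_points, embed_dim = 16, 256
    surf_out_dim = 3 * num_surf_points
    surfz_embed = nn.Sequential(
        Encoder1D(num_points=num_surf_points, in_channels=3, out_dim=surf_out_dim),  # [..., 16, 3] -> [..., 48]
        nn.Linear(surf_out_dim, embed_dim),
        nn.LayerNorm(embed_dim),
        nn.SiLU(),
        nn.Linear(embed_dim, embed_dim),
    )
    surf_ncs = torch.randn(2, 8, num_surf_points, 3)  # [B, F, num_surf_points, 3]
    out = surfz_embed(surf_ncs)
    assert out.shape == (2, 8, embed_dim)             # [B, F, embed_dim]
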
@@ -243,86 +211,87 @@ class BRepFeatureEmbedder(nn.Module):
         # changed structure of vertp_embed
         self.vertp_embed = nn.Sequential(
-            nn.Linear(3, self.embed_dim // 2),
-            nn.LayerNorm(self.embed_dim // 2),
-            nn.ReLU(),
-            nn.Linear(self.embed_dim // 2, self.embed_dim)
+            nn.Linear(3, self.embed_dim),
+            nn.LayerNorm(self.embed_dim),
+            nn.SiLU(),
+            nn.Linear(self.embed_dim, self.embed_dim)
         )
         # an extra projection layer
         self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
+        # transformer initialization
+        self.transformer = nn.TransformerEncoder(
+            encoder_layer=nn.TransformerEncoderLayer(
+                d_model=self.embed_dim,
+                nhead=8,  # number of attention heads, usually a factor of embed_dim
+                dim_feedforward=4*self.embed_dim,  # feed-forward dimension, usually 4x embed_dim
+                dropout=0.1,
+                activation='gelu',
+                batch_first=False  # because we use transpose(0, 1)
+            ),
+            num_layers=6  # number of transformer layers
+        )
+
     def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
-        B, F, E, _, _ = edge_ncs.shape
-
-        # process vertex features
-        vertex_features = vertex_pos.view(B*F*E*2, -1)  # flatten vertex coordinates
-        vertex_embed = self.vertp_embed(vertex_features)
-        vertex_embed = self.vertex_proj(vertex_embed)  # add projection
-        vertex_embed = vertex_embed.view(B, F, E, 2, -1)  # restore shape
-
-        # make sure the vertex features take part in the later computation
-        edge_features = torch.cat([
-            self.edgep_embed(edge_pos.view(B*F*E, -1)).view(B, F, E, -1),
-            vertex_embed.mean(dim=3)  # average-pool the vertex features
-        ], dim=-1)
-
-        # 1. process edge features
-        # reshape the edge point clouds for the 1D encoder
-        edge_ncs = edge_ncs.reshape(B*F*E, -1, 3).transpose(1, 2)  # [B*max_face*max_edge, 3, num_edge_points]
-        edge_embeds = self.edgez_embed(edge_ncs)  # [B*max_face*max_edge, embed_dim, num_edge_points]
-        edge_embeds = edge_embeds.mean(dim=-1)  # [B*max_face*max_edge, embed_dim]
-        edge_embeds = edge_embeds.reshape(B, F, E, -1)  # [B, max_face, max_edge, embed_dim]
-
-        # 2. process surface features
-        surf_ncs = surf_ncs.reshape(B*F, -1, 3).transpose(1, 2)  # [B*max_face, 3, num_surf_points]
-        surf_embeds = self.surfz_embed(surf_ncs)  # [B*max_face, embed_dim, num_surf_points]
-        surf_embeds = surf_embeds.mean(dim=-1)  # [B*max_face, embed_dim]
-        surf_embeds = surf_embeds.reshape(B, F, -1)  # [B, max_face, embed_dim]
-
-        # 3. process positional encodings
-        # edge positional encoding
-        edge_pos = edge_pos.reshape(B*F*E, -1)  # [B*max_face*max_edge, 6]
-        edge_p_embeds = self.edgep_embed(edge_pos)  # [B*max_face*max_edge, embed_dim]
-        edge_p_embeds = edge_p_embeds.reshape(B, F, E, -1)  # [B, max_face, max_edge, embed_dim]
-
-        # surface positional encoding
-        surf_p_embeds = self.surfp_embed(surf_pos)  # [B, max_face, embed_dim]
+        """
+        Args:
+            edge_ncs: [B, F, E, num_edge_points, 3] - edge point clouds
+            edge_pos: [B, F, E, 6] - edge positional encodings
+            edge_mask: [B, F, E] - edge mask
+            surf_ncs: [B, F, num_surf_points, 3] - surface point clouds
+            surf_pos: [B, F, 6] - surface positional encodings
+            vertex_pos: [B, F, E, 2, 3] - vertex positions
+        """
+        B, F, E = edge_pos.shape[:3]
+
+        # 1. process vertex features
+        vertex_embed = self.vertp_embed(vertex_pos[..., :3])  # [B, F, E, 2, embed_dim]
+        vertex_embed = self.vertex_proj(vertex_embed)  # [B, F, E, 2, embed_dim]
+        vertex_embed = vertex_embed.mean(dim=3)  # [B, F, E, embed_dim]
+
+        # 2. process edge features
+        logger.info(f"edge_ncs shape: {edge_ncs.shape}")
+        edge_embeds = self.edgez_embed(edge_ncs)  # [B, F, E, embed_dim]
+        edge_p_embeds = self.edgep_embed(edge_pos)  # [B, F, E, embed_dim]
+
+        # 3. process surface features
+        surf_embeds = self.surfz_embed(surf_ncs)  # [B, F, embed_dim]
+        surf_p_embeds = self.surfp_embed(surf_pos)  # [B, F, embed_dim]
 
         # 4. combine features
         if self.use_cf:
-            # edge features
-            edge_features = edge_embeds + edge_p_embeds  # [B, max_face, max_edge, embed_dim]
-            edge_features = edge_features.reshape(B, F*E, -1)  # [B, max_face*max_edge, embed_dim]
-            # surface features
-            surf_features = surf_embeds + surf_p_embeds  # [B, max_face, embed_dim]
-            # combine all features
+            # combine edge features
+            edge_features = edge_embeds + edge_p_embeds + vertex_embed  # [B, F, E, embed_dim]
+            edge_features = edge_features.reshape(B, F*E, -1)  # [B, F*E, embed_dim]
+            # combine surface features
+            surf_features = surf_embeds + surf_p_embeds  # [B, F, embed_dim]
+            # concatenate all features
             embeds = torch.cat([
-                edge_features,  # [B, max_face*max_edge, embed_dim]
-                surf_features  # [B, max_face, embed_dim]
-            ], dim=1)  # [B, max_face*(max_edge+1), embed_dim]
+                edge_features,  # [B, F*E, embed_dim]
+                surf_features  # [B, F, embed_dim]
+            ], dim=1)  # [B, F*E+F, embed_dim]
         else:
             # use only the positional encodings
-            edge_features = edge_p_embeds.reshape(B, F*E, -1)  # [B, max_face*max_edge, embed_dim]
+            edge_features = edge_p_embeds.reshape(B, F*E, -1)  # [B, F*E, embed_dim]
             embeds = torch.cat([
-                edge_features,  # [B, max_face*max_edge, embed_dim]
-                surf_p_embeds  # [B, max_face, embed_dim]
-            ], dim=1)  # [B, max_face*(max_edge+1), embed_dim]
+                edge_features,  # [B, F*E, embed_dim]
+                surf_p_embeds  # [B, F, embed_dim]
+            ], dim=1)  # [B, F*E+F, embed_dim]
 
         # 5. handle the mask
         if edge_mask is not None:
-            # expand the mask to match the feature dimensions
-            edge_mask = edge_mask.reshape(B, -1)  # [B, max_face*max_edge]
-            surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool)  # [B, max_face]
-            mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, max_face*(max_edge+1)]
+            edge_mask = edge_mask.reshape(B, -1)  # [B, F*E]
+            surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool)  # [B, F]
+            mask = torch.cat([edge_mask, surf_mask], dim=1)  # [B, F*E+F]
         else:
             mask = None
 
         # 6. transformer processing
         output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
-        return output.transpose(0, 1)  # ensure the output shape is [B, seq_len, embed_dim]
+        return output.transpose(0, 1)  # [B, F*E+F, embed_dim]
 
 class SDFTransformer(nn.Module):
     """SDF Transformer encoder"""
