Browse Source

fix: grad escape

main
mckay 4 months ago
parent
commit
b0be9a26f9
  1. 263
      brep2sdf/networks/encoder.py
  2. 117
      brep2sdf/networks/network.py

263
brep2sdf/networks/encoder.py

@ -180,19 +180,37 @@ class BRepFeatureEmbedder(nn.Module):
# Transformer编码器层
layer = nn.TransformerEncoderLayer(
d_model=self.embed_dim,
nhead=12,
norm_first=False,
dim_feedforward=1024,
dropout=0.1
d_model=self.embed_dim,
nhead=8, # 从12减少到8,使每个head的维度更大
norm_first=True, # 改为True,先进行归一化
dim_feedforward=self.embed_dim * 4, # 增大FFN维度
dropout=0.1,
activation=F.gelu # 使用GELU激活函数
)
# 添加初始化方法
def _init_weights(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight, gain=1/math.sqrt(2))
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.ones_(module.weight)
torch.nn.init.zeros_(module.bias)
self.transformer = nn.TransformerEncoder(
layer,
num_layers=12,
layer,
num_layers=6, # 从12减少到6层
norm=nn.LayerNorm(self.embed_dim),
enable_nested_tensor=False
)
# 应用初始化
self.transformer.apply(_init_weights)
# 添加位置编码
self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02) # 最大序列长度设为1000
# 修改为处理[num_points, 3]形状的输入
self.surfz_embed = Encoder1D(
in_channels=3,
@ -223,117 +241,162 @@ class BRepFeatureEmbedder(nn.Module):
nn.Linear(self.embed_dim, self.embed_dim),
)
# 修改vertp_embed的结构
self.vertp_embed = nn.Sequential(
nn.Linear(6, self.embed_dim),
nn.LayerNorm(self.embed_dim),
nn.SiLU(),
nn.Linear(self.embed_dim, self.embed_dim),
nn.Linear(3, self.embed_dim // 2),
nn.LayerNorm(self.embed_dim // 2),
nn.ReLU(),
nn.Linear(self.embed_dim // 2, self.embed_dim)
)
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None):
"""B-rep特征嵌入器的前向传播
Args:
edge_ncs: 边归一化特征 [B, max_face, max_edge, num_edge_points, 3]
edge_pos: 边位置 [B, max_face, max_edge, 6]
edge_mask: 边掩码 [B, max_face, max_edge]
surf_ncs: 面归一化特征 [B, max_face, num_surf_points, 3]
surf_pos: 面位置 [B, max_face, 6]
vertex_pos: 顶点位置 [B, max_face, max_edge, 2, 3]
# 添加一个额外的投影层
self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)
Returns:
embeds: [B, max_face*(max_edge+1), embed_dim]
"""
B = self.config.train.batch_size
max_face = self.config.data.max_face
max_edge = self.config.data.max_edge
def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
B, F, E, _, _ = edge_ncs.shape
try:
# 1. 处理边特征
# 重塑边点云以适应1D编码器
edge_ncs = edge_ncs.reshape(B*max_face*max_edge, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points]
edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points]
edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim]
edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim]
# 2. 处理面特征
surf_ncs = surf_ncs.reshape(B*max_face, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points]
surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points]
surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim]
surf_embeds = surf_embeds.reshape(B, max_face, -1) # [B, max_face, embed_dim]
# 3. 处理位置编码
# 边位置编码
edge_pos = edge_pos.reshape(B*max_face*max_edge, -1) # [B*max_face*max_edge, 6]
edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim]
edge_p_embeds = edge_p_embeds.reshape(B, max_face, max_edge, -1) # [B, max_face, max_edge, embed_dim]
# 面位置编码
surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim]
# 4. 组合特征
if self.use_cf:
# 边特征
edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim]
edge_features = edge_features.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim]
# 面特征
surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim]
# 组合所有特征
embeds = torch.cat([
edge_features, # [B, max_face*max_edge, embed_dim]
surf_features # [B, max_face, embed_dim]
], dim=1) # [B, max_face*(max_edge+1), embed_dim]
else:
# 只使用位置编码
edge_features = edge_p_embeds.reshape(B, max_face*max_edge, -1) # [B, max_face*max_edge, embed_dim]
embeds = torch.cat([
edge_features, # [B, max_face*max_edge, embed_dim]
surf_p_embeds # [B, max_face, embed_dim]
], dim=1) # [B, max_face*(max_edge+1), embed_dim]
# 5. 处理掩码
if edge_mask is not None:
# 扩展掩码以匹配特征维度
edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge]
surf_mask = torch.ones(B, max_face, device=edge_mask.device, dtype=torch.bool) # [B, max_face]
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)]
else:
mask = None
# 处理顶点特征
vertex_features = vertex_pos.view(B*F*E*2, -1) # 展平顶点坐标
vertex_embed = self.vertp_embed(vertex_features)
vertex_embed = self.vertex_proj(vertex_embed) # 添加投影
vertex_embed = vertex_embed.view(B, F, E, 2, -1) # 恢复形状
# 确保顶点特征参与后续计算
edge_features = torch.cat([
self.edgep_embed(edge_pos.view(B*F*E, -1)).view(B, F, E, -1),
vertex_embed.mean(dim=3) # 将顶点特征平均池化
], dim=-1)
# 1. 处理边特征
# 重塑边点云以适应1D编码器
edge_ncs = edge_ncs.reshape(B*F*E, -1, 3).transpose(1, 2) # [B*max_face*max_edge, 3, num_edge_points]
edge_embeds = self.edgez_embed(edge_ncs) # [B*max_face*max_edge, embed_dim, num_edge_points]
edge_embeds = edge_embeds.mean(dim=-1) # [B*max_face*max_edge, embed_dim]
edge_embeds = edge_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim]
# 2. 处理面特征
surf_ncs = surf_ncs.reshape(B*F, -1, 3).transpose(1, 2) # [B*max_face, 3, num_surf_points]
surf_embeds = self.surfz_embed(surf_ncs) # [B*max_face, embed_dim, num_surf_points]
surf_embeds = surf_embeds.mean(dim=-1) # [B*max_face, embed_dim]
surf_embeds = surf_embeds.reshape(B, F, -1) # [B, max_face, embed_dim]
# 3. 处理位置编码
# 边位置编码
edge_pos = edge_pos.reshape(B*F*E, -1) # [B*max_face*max_edge, 6]
edge_p_embeds = self.edgep_embed(edge_pos) # [B*max_face*max_edge, embed_dim]
edge_p_embeds = edge_p_embeds.reshape(B, F, E, -1) # [B, max_face, max_edge, embed_dim]
# 面位置编码
surf_p_embeds = self.surfp_embed(surf_pos) # [B, max_face, embed_dim]
# 4. 组合特征
if self.use_cf:
# 边特征
edge_features = edge_embeds + edge_p_embeds # [B, max_face, max_edge, embed_dim]
edge_features = edge_features.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim]
# 6. Transformer处理
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim]
# 面特征
surf_features = surf_embeds + surf_p_embeds # [B, max_face, embed_dim]
except Exception as e:
logger.error(f"Error in BRepFeatureEmbedder forward pass:")
logger.error(f" Error message: {str(e)}")
logger.error(f" Input shapes:")
logger.error(f" edge_ncs: {edge_ncs.shape}")
logger.error(f" edge_pos: {edge_pos.shape}")
logger.error(f" edge_mask: {edge_mask.shape}")
logger.error(f" surf_ncs: {surf_ncs.shape}")
logger.error(f" surf_pos: {surf_pos.shape}")
logger.error(f" vertex_pos: {vertex_pos.shape}")
raise
# 组合所有特征
embeds = torch.cat([
edge_features, # [B, max_face*max_edge, embed_dim]
surf_features # [B, max_face, embed_dim]
], dim=1) # [B, max_face*(max_edge+1), embed_dim]
else:
# 只使用位置编码
edge_features = edge_p_embeds.reshape(B, F*E, -1) # [B, max_face*max_edge, embed_dim]
embeds = torch.cat([
edge_features, # [B, max_face*max_edge, embed_dim]
surf_p_embeds # [B, max_face, embed_dim]
], dim=1) # [B, max_face*(max_edge+1), embed_dim]
# 5. 处理掩码
if edge_mask is not None:
# 扩展掩码以匹配特征维度
edge_mask = edge_mask.reshape(B, -1) # [B, max_face*max_edge]
surf_mask = torch.ones(B, F, device=edge_mask.device, dtype=torch.bool) # [B, max_face]
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, max_face*(max_edge+1)]
else:
mask = None
# 6. Transformer处理
output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
return output.transpose(0, 1) # 确保输出维度为 [B, seq_len, embed_dim]
class SDFTransformer(nn.Module):
"""SDF Transformer编码器"""
def __init__(self, embed_dim: int = 768, num_layers: int = 6):
def __init__(self, embed_dim: int = 192, num_layers: int = 6): # 改为192以匹配BRepFeatureEmbedder
super().__init__()
# 1. 添加位置编码
self.pos_embedding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.02)
# 2. 修改Transformer层配置
layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=8,
dim_feedforward=1024,
nhead=4, # 减少头数,使每个头的维度更大
dim_feedforward=embed_dim * 2, # 减小FFN维度
dropout=0.1,
batch_first=True,
norm_first=False # 修改这里:设置为False
norm_first=True, # 使用Pre-LN结构
activation=F.gelu
)
# 3. 添加梯度缩放因子
self.attention_scale = math.sqrt(embed_dim)
# 4. 自定义初始化
def _init_weights(module):
if isinstance(module, nn.Linear):
# 使用较小的初始化范围
nn.init.xavier_uniform_(module.weight, gain=0.1)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
# 特别处理注意力层
in_proj_weight = module.in_proj_weight
out_proj_weight = module.out_proj.weight
nn.init.xavier_uniform_(in_proj_weight, gain=0.1)
nn.init.xavier_uniform_(out_proj_weight, gain=0.1)
if module.in_proj_bias is not None:
nn.init.zeros_(module.in_proj_bias)
if module.out_proj.bias is not None:
nn.init.zeros_(module.out_proj.bias)
self.transformer = nn.TransformerEncoder(
layer,
num_layers,
norm=nn.LayerNorm(embed_dim)
)
self.transformer = nn.TransformerEncoder(layer, num_layers)
# 应用初始化
self.transformer.apply(_init_weights)
# 5. 添加残差缩放
self.residual_scale = nn.Parameter(torch.ones(1) * 0.1)
def forward(self, x, mask=None):
return self.transformer(x, src_key_padding_mask=mask)
# 添加位置编码
seq_len = x.size(1)
x = x + self.pos_embedding[:, :seq_len, :]
# 缩放注意力分数
x = x * self.attention_scale
# 前向传播
for layer in self.transformer.layers:
# 手动添加残差连接和缩放
identity = x
x = layer(x, src_key_padding_mask=mask)
x = identity + x * self.residual_scale
return self.transformer.norm(x)
class SDFHead(nn.Module):
"""SDF预测头"""

117
brep2sdf/networks/network.py

@ -9,6 +9,7 @@ from brep2sdf.networks.encoder import BRepFeatureEmbedder
from brep2sdf.networks.decoder import SDFHead, SDFTransformer
class BRepToSDF(nn.Module):
def __init__(self, config=None):
super().__init__()
@ -121,7 +122,7 @@ class BRepToSDF(nn.Module):
def main():
def _main():
# 获取配置
config = get_default_config()
@ -167,5 +168,117 @@ def main():
logger.error(f"Error during forward pass: {str(e)}")
raise
def train_step(model, batch, optimizer, criterion):
"""单步训练"""
# 确保模型处于训练模式
model.train()
# 检查并设置所有参数的requires_grad
for name, param in model.named_parameters():
if not param.requires_grad:
logger.warning(f"参数 {name} 的requires_grad为False,现在设置为True")
param.requires_grad = True
# 将所有输入转为requires_grad=True
batch['query_points'].requires_grad_(True)
# 清零梯度
optimizer.zero_grad()
# 前向传播
pred_sdf = model(
edge_ncs=batch['edge_ncs'],
edge_pos=batch['edge_pos'],
edge_mask=batch['edge_mask'],
surf_ncs=batch['surf_ncs'],
surf_pos=batch['surf_pos'],
vertex_pos=batch['vertex_pos'],
query_points=batch['query_points']
)
# 计算损失
loss = criterion(pred_sdf, batch['gt_sdf'])
# 检查损失是否有效
if not torch.isfinite(loss):
logger.error(f"损失值无效: {loss.item()}")
raise ValueError("损失值无效")
# 反向传播
loss.backward()
# 检查梯度
total_norm = 0
for name, param in model.named_parameters():
if param.grad is not None:
param_norm = param.grad.data.norm(2)
total_norm += param_norm.item() ** 2
logger.info(f" {name}: grad_norm = {param_norm.item()}")
else:
logger.warning(f" {name}: No gradient!")
total_norm = total_norm ** 0.5
logger.info(f"梯度总范数: {total_norm}")
# 梯度裁剪(可选)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 更新参数
optimizer.step()
return loss.item()
def train(model, config, num_epochs=10):
"""模拟训练过程"""
# 初始化优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
from brep2sdf.networks.loss import Brep2SDFLoss
# 初始化损失函数
clamping_distance = config.train.clamping_distance
criterion = Brep2SDFLoss(
enforce_minmax= (clamping_distance > 0),
clamping_distance= clamping_distance
)
# 生成模拟数据
batch_size = config.train.batch_size
max_face = config.data.max_face
max_edge = config.data.max_edge
num_surf_points = config.model.num_surf_points
num_edge_points = config.model.num_edge_points
# 模拟一个batch的数据
batch = {
'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3),
'edge_pos': torch.randn(batch_size, max_face, max_edge, 6),
'edge_mask': torch.ones(batch_size, max_face, max_edge, dtype=torch.bool),
'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3),
'surf_pos': torch.randn(batch_size, max_face, 6),
'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3),
'query_points': torch.randn(batch_size, 1000, 3), # 1000个查询点
'gt_sdf': torch.randn(batch_size, 1000, 1) # 模拟的GT SDF值
}
# 训练循环
for epoch in range(2):
try:
loss = train_step(model, batch, optimizer, criterion)
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.6f}")
except Exception as e:
logger.error(f"Error during training:")
logger.error(f" {str(e)}")
raise
def main():
# 获取配置
config = get_default_config()
# 初始化模型
model = BRepToSDF(config=config)
# 开始训练
train(model, config)
if __name__ == "__main__":
main()
main()

Loading…
Cancel
Save