From b0be9a26f9c23c2de884d052e6a902c738b26712 Mon Sep 17 00:00:00 2001
From: mckay
Date: Wed, 27 Nov 2024 00:29:43 +0800
Subject: [PATCH] fix: grad escape

---
 brep2sdf/networks/encoder.py | 263 ++++++++++++++++++++++-------------
 brep2sdf/networks/network.py | 117 +++++++++++++++-
 2 files changed, 278 insertions(+), 102 deletions(-)

diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py
index cda7ab0..a5e978e 100644
--- a/brep2sdf/networks/encoder.py
+++ b/brep2sdf/networks/encoder.py
@@ -180,19 +180,37 @@ class BRepFeatureEmbedder(nn.Module):
 
         # Transformer encoder layer
         layer = nn.TransformerEncoderLayer(
-            d_model=self.embed_dim,
-            nhead=12,
-            norm_first=False,
-            dim_feedforward=1024,
-            dropout=0.1
+            d_model=self.embed_dim,
+            nhead=8,                             # reduced from 12 so each head gets a larger dimension
+            norm_first=True,                     # Pre-LN: normalize before attention/FFN
+            dim_feedforward=self.embed_dim * 4,  # widen the FFN
+            dropout=0.1,
+            activation=F.gelu                    # GELU activation
         )
+
+        # Custom weight initialization
+        def _init_weights(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight, gain=1 / math.sqrt(2))
+                if module.bias is not None:
+                    torch.nn.init.zeros_(module.bias)
+            elif isinstance(module, nn.LayerNorm):
+                torch.nn.init.ones_(module.weight)
+                torch.nn.init.zeros_(module.bias)
+
         self.transformer = nn.TransformerEncoder(
-            layer,
-            num_layers=12,
+            layer,
+            num_layers=6,                        # reduced from 12 layers
             norm=nn.LayerNorm(self.embed_dim),
             enable_nested_tensor=False
         )
+        # Apply the initialization
+        self.transformer.apply(_init_weights)
+
+        # Learned positional encoding (maximum sequence length 1000);
+        # added to the embeddings in forward() so it actually trains
+        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, self.embed_dim) * 0.02)
 
         # Handles input of shape [num_points, 3]
         self.surfz_embed = Encoder1D(
             in_channels=3,
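A plausible motivation for the norm_first switch above: with post-LN encoders the
gradient signal can shrink sharply as it passes back through a deep stack, which
matches the "grad escape" symptom this patch targets. A minimal standalone sketch
(not part of the patch; all names are illustrative) that compares input-gradient
norms under both settings:

    import torch
    import torch.nn as nn

    def input_grad_norm(norm_first: bool, depth: int = 6, d_model: int = 192) -> float:
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=d_model * 4,
            dropout=0.0, norm_first=norm_first, batch_first=True)
        encoder = nn.TransformerEncoder(layer, num_layers=depth)
        x = torch.randn(2, 16, d_model, requires_grad=True)
        encoder(x).sum().backward()   # backprop a scalar through the whole stack
        return x.grad.norm().item()

    print("post-LN:", input_grad_norm(norm_first=False))
    print("pre-LN: ", input_grad_norm(norm_first=True))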
@@ -223,117 +241,162 @@ class BRepFeatureEmbedder(nn.Module):
             nn.Linear(self.embed_dim, self.embed_dim),
         )
 
+        # Restructured vertp_embed: a single vertex is 3-D, not 6-D
         self.vertp_embed = nn.Sequential(
-            nn.Linear(6, self.embed_dim),
-            nn.LayerNorm(self.embed_dim),
-            nn.SiLU(),
-            nn.Linear(self.embed_dim, self.embed_dim),
+            nn.Linear(3, self.embed_dim // 2),
+            nn.LayerNorm(self.embed_dim // 2),
+            nn.ReLU(),
+            nn.Linear(self.embed_dim // 2, self.embed_dim)
         )
+        # Extra projection on top of the vertex embedding
+        self.vertex_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
-    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, data_class=None):
-        """Forward pass of the B-rep feature embedder
-        Args:
-            edge_ncs: normalized edge features [B, max_face, max_edge, num_edge_points, 3]
-            edge_pos: edge positions [B, max_face, max_edge, 6]
-            edge_mask: edge mask [B, max_face, max_edge]
-            surf_ncs: normalized face features [B, max_face, num_surf_points, 3]
-            surf_pos: face positions [B, max_face, 6]
-            vertex_pos: vertex positions [B, max_face, max_edge, 2, 3]
-
-        Returns:
-            embeds: [B, max_face*(max_edge+1), embed_dim]
-        """
-        B = self.config.train.batch_size
-        max_face = self.config.data.max_face
-        max_edge = self.config.data.max_edge
-
-        try:
-            # 1. Process edge features
-            # Reshape the edge point cloud for the 1D encoder
-            edge_ncs = edge_ncs.reshape(B*max_face*max_edge, -1, 3).transpose(1, 2)
-            edge_embeds = self.edgez_embed(edge_ncs)
-            edge_embeds = edge_embeds.mean(dim=-1)
-            edge_embeds = edge_embeds.reshape(B, max_face, max_edge, -1)
-
-            # 2. Process face features
-            surf_ncs = surf_ncs.reshape(B*max_face, -1, 3).transpose(1, 2)
-            surf_embeds = self.surfz_embed(surf_ncs)
-            surf_embeds = surf_embeds.mean(dim=-1)
-            surf_embeds = surf_embeds.reshape(B, max_face, -1)
-
-            # 3. Positional encodings
-            edge_pos = edge_pos.reshape(B*max_face*max_edge, -1)
-            edge_p_embeds = self.edgep_embed(edge_pos)
-            edge_p_embeds = edge_p_embeds.reshape(B, max_face, max_edge, -1)
-            surf_p_embeds = self.surfp_embed(surf_pos)
-
-            # 4. Combine features
-            if self.use_cf:
-                edge_features = edge_embeds + edge_p_embeds
-                edge_features = edge_features.reshape(B, max_face*max_edge, -1)
-                surf_features = surf_embeds + surf_p_embeds
-                embeds = torch.cat([edge_features, surf_features], dim=1)
-            else:
-                # Positional encodings only
-                edge_features = edge_p_embeds.reshape(B, max_face*max_edge, -1)
-                embeds = torch.cat([edge_features, surf_p_embeds], dim=1)
-
-            # 5. Build the mask
-            if edge_mask is not None:
-                edge_mask = edge_mask.reshape(B, -1)
-                surf_mask = torch.ones(B, max_face, device=edge_mask.device, dtype=torch.bool)
-                mask = torch.cat([edge_mask, surf_mask], dim=1)
-            else:
-                mask = None
-
-            # 6. Transformer
-            output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
-            return output.transpose(0, 1)
-
-        except Exception as e:
-            logger.error(f"Error in BRepFeatureEmbedder forward pass:")
-            logger.error(f"  Error message: {str(e)}")
-            logger.error(f"  Input shapes:")
-            logger.error(f"    edge_ncs: {edge_ncs.shape}")
-            logger.error(f"    edge_pos: {edge_pos.shape}")
-            logger.error(f"    edge_mask: {edge_mask.shape}")
-            logger.error(f"    surf_ncs: {surf_ncs.shape}")
-            logger.error(f"    surf_pos: {surf_pos.shape}")
-            logger.error(f"    vertex_pos: {vertex_pos.shape}")
-            raise
+    def forward(self, edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos, **kwargs):
+        # Unpack as num_face/num_edge rather than F/E so that F does not shadow
+        # torch.nn.functional (imported as F and used for F.gelu above)
+        B, num_face, num_edge, _, _ = edge_ncs.shape
+
+        # Vertex features: embed, project, and pool the two endpoints of each
+        # edge so that vertp_embed / vertex_proj stay in the computation graph
+        # (otherwise their gradients "escape", i.e. remain None)
+        vertex_features = vertex_pos.view(B * num_face * num_edge * 2, -1)  # flatten to [N, 3]
+        vertex_embed = self.vertp_embed(vertex_features)
+        vertex_embed = self.vertex_proj(vertex_embed)                       # extra projection
+        vertex_embed = vertex_embed.view(B, num_face, num_edge, 2, -1)      # restore shape
+        vertex_feats = vertex_embed.mean(dim=3)                             # [B, num_face, num_edge, embed_dim]
+
+        # 1. Process edge features
+        # Reshape the edge point cloud for the 1D encoder
+        edge_ncs = edge_ncs.reshape(B * num_face * num_edge, -1, 3).transpose(1, 2)  # [B*num_face*num_edge, 3, num_edge_points]
+        edge_embeds = self.edgez_embed(edge_ncs)       # [B*num_face*num_edge, embed_dim, num_edge_points]
+        edge_embeds = edge_embeds.mean(dim=-1)         # [B*num_face*num_edge, embed_dim]
+        edge_embeds = edge_embeds.reshape(B, num_face, num_edge, -1)
+
+        # 2. Process face features
+        surf_ncs = surf_ncs.reshape(B * num_face, -1, 3).transpose(1, 2)    # [B*num_face, 3, num_surf_points]
+        surf_embeds = self.surfz_embed(surf_ncs)       # [B*num_face, embed_dim, num_surf_points]
+        surf_embeds = surf_embeds.mean(dim=-1)         # [B*num_face, embed_dim]
+        surf_embeds = surf_embeds.reshape(B, num_face, -1)
+
+        # 3. Positional encodings
+        # Edge positions
+        edge_pos = edge_pos.reshape(B * num_face * num_edge, -1)            # [B*num_face*num_edge, 6]
+        edge_p_embeds = self.edgep_embed(edge_pos)
+        edge_p_embeds = edge_p_embeds.reshape(B, num_face, num_edge, -1)
+        # Face positions
+        surf_p_embeds = self.surfp_embed(surf_pos)     # [B, num_face, embed_dim]
+
+        # 4. Combine features
+        if self.use_cf:
+            # Edge tokens: point-cloud embedding + positional embedding
+            # + pooled vertex embedding (keeps the vertex path trainable)
+            edge_features = edge_embeds + edge_p_embeds + vertex_feats      # [B, num_face, num_edge, embed_dim]
+            edge_features = edge_features.reshape(B, num_face * num_edge, -1)
+
+            # Face tokens
+            surf_features = surf_embeds + surf_p_embeds                     # [B, num_face, embed_dim]
+
+            # Concatenate all tokens
+            embeds = torch.cat([
+                edge_features,   # [B, num_face*num_edge, embed_dim]
+                surf_features    # [B, num_face, embed_dim]
+            ], dim=1)            # [B, num_face*(num_edge+1), embed_dim]
+        else:
+            # Positional encodings only
+            edge_features = edge_p_embeds.reshape(B, num_face * num_edge, -1)
+            embeds = torch.cat([
+                edge_features,
+                surf_p_embeds
+            ], dim=1)
+
+        # 5. Build the padding mask
+        if edge_mask is not None:
+            # src_key_padding_mask treats True as a PADDED (ignored) position;
+            # edge_mask is assumed to follow that convention
+            edge_mask = edge_mask.reshape(B, -1)                            # [B, num_face*num_edge]
+            # face tokens are always valid, so never padded
+            surf_mask = torch.zeros(B, num_face, device=edge_mask.device, dtype=torch.bool)
+            mask = torch.cat([edge_mask, surf_mask], dim=1)                 # [B, num_face*(num_edge+1)]
+        else:
+            mask = None
+
+        # 6. Transformer (add the learned positional encoding first so that
+        #    pos_embedding also receives gradients)
+        embeds = embeds + self.pos_embedding[:, :embeds.size(1), :]
+        output = self.transformer(embeds.transpose(0, 1), src_key_padding_mask=mask)
+        return output.transpose(0, 1)  # [B, seq_len, embed_dim]
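The vertex block above is the heart of the "grad escape" fix: a submodule whose
output never reaches the loss silently ends up with grad=None after backward().
A standalone helper sketch (illustrative names, not from this repo) for listing
such parameters after a backward pass:

    import torch
    import torch.nn as nn

    def find_escaped_params(model: nn.Module) -> list:
        """Names of trainable parameters that got no gradient in the last backward."""
        return [name for name, p in model.named_parameters()
                if p.requires_grad and p.grad is None]

    model = nn.Sequential(nn.Linear(3, 8), nn.Linear(8, 1))
    model(torch.randn(4, 3)).sum().backward()
    print(find_escaped_params(model))   # [] -- every parameter participated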
 
 class SDFTransformer(nn.Module):
     """SDF Transformer encoder"""
-    def __init__(self, embed_dim: int = 768, num_layers: int = 6):
+    def __init__(self, embed_dim: int = 192, num_layers: int = 6):  # 192 to match BRepFeatureEmbedder
         super().__init__()
+
+        # 1. Learned positional encoding (maximum sequence length 1000)
+        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.02)
+
+        # 2. Transformer layer configuration
         layer = nn.TransformerEncoderLayer(
             d_model=embed_dim,
-            nhead=8,
-            dim_feedforward=1024,
+            nhead=4,                        # fewer heads, so each head gets a larger dimension
+            dim_feedforward=embed_dim * 2,  # smaller FFN
             dropout=0.1,
             batch_first=True,
-            norm_first=False
+            norm_first=True,                # Pre-LN structure
+            activation=F.gelu
+        )
+
+        # 3. Input scaling factor (the usual sqrt(d_model) embedding scale)
+        self.attention_scale = math.sqrt(embed_dim)
+
+        # 4. Custom initialization
+        def _init_weights(module):
+            if isinstance(module, nn.Linear):
+                # Small initialization range
+                nn.init.xavier_uniform_(module.weight, gain=0.1)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+            elif isinstance(module, nn.LayerNorm):
+                nn.init.ones_(module.weight)
+                nn.init.zeros_(module.bias)
+            elif isinstance(module, nn.MultiheadAttention):
+                # Attention layers need special handling
+                in_proj_weight = module.in_proj_weight
+                out_proj_weight = module.out_proj.weight
+
+                nn.init.xavier_uniform_(in_proj_weight, gain=0.1)
+                nn.init.xavier_uniform_(out_proj_weight, gain=0.1)
+
+                if module.in_proj_bias is not None:
+                    nn.init.zeros_(module.in_proj_bias)
+                if module.out_proj.bias is not None:
+                    nn.init.zeros_(module.out_proj.bias)
+
+        self.transformer = nn.TransformerEncoder(
+            layer,
+            num_layers,
+            norm=nn.LayerNorm(embed_dim)
         )
-        self.transformer = nn.TransformerEncoder(layer, num_layers)
+
+        # Apply the initialization
+        self.transformer.apply(_init_weights)
+
+        # 5. Learnable residual scale (ReZero-style)
+        self.residual_scale = nn.Parameter(torch.ones(1) * 0.1)
 
     def forward(self, x, mask=None):
-        return self.transformer(x, src_key_padding_mask=mask)
+        # Add the positional encoding
+        seq_len = x.size(1)
+        x = x + self.pos_embedding[:, :seq_len, :]
+
+        # Scale the embeddings (note: this scales the inputs, not the
+        # attention scores themselves)
+        x = x * self.attention_scale
+
+        # Run the layers manually to add an extra, learnably scaled skip path;
+        # each encoder layer still applies its own internal residuals
+        for layer in self.transformer.layers:
+            identity = x
+            x = layer(x, src_key_padding_mask=mask)
+            x = identity + x * self.residual_scale
+
+        return self.transformer.norm(x)
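For reference on the nn.MultiheadAttention branch of _init_weights above: PyTorch
packs the query/key/value projections into a single in_proj_weight of shape
(3 * embed_dim, embed_dim), while the output projection lives on out_proj, which
is why the two are initialized separately. A quick standalone check:

    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=192, num_heads=4)
    print(mha.in_proj_weight.shape)    # torch.Size([576, 192]) -- packed Q, K, V
    print(mha.out_proj.weight.shape)   # torch.Size([192, 192])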
 
 class SDFHead(nn.Module):
     """SDF prediction head"""
diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index 889d61d..b18dd42 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -9,6 +9,7 @@
 from brep2sdf.networks.encoder import BRepFeatureEmbedder
 from brep2sdf.networks.decoder import SDFHead, SDFTransformer
 
+
 class BRepToSDF(nn.Module):
     def __init__(self, config=None):
         super().__init__()
@@ -121,7 +122,7 @@ class BRepToSDF(nn.Module):
 
 
 
-def main():
+def _main():
     # Get the configuration
     config = get_default_config()
 
@@ -167,5 +168,117 @@
         logger.error(f"Error during forward pass: {str(e)}")
         raise
 
+def train_step(model, batch, optimizer, criterion):
+    """Single training step"""
+    # Make sure the model is in training mode
+    model.train()
+
+    # Check requires_grad on every parameter and enable it where needed
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            logger.warning(f"Parameter {name} has requires_grad=False; enabling it")
+            param.requires_grad = True
+
+    # The query points need gradients as well
+    batch['query_points'].requires_grad_(True)
+
+    # Zero the gradients
+    optimizer.zero_grad()
+
+    # Forward pass
+    pred_sdf = model(
+        edge_ncs=batch['edge_ncs'],
+        edge_pos=batch['edge_pos'],
+        edge_mask=batch['edge_mask'],
+        surf_ncs=batch['surf_ncs'],
+        surf_pos=batch['surf_pos'],
+        vertex_pos=batch['vertex_pos'],
+        query_points=batch['query_points']
+    )
+
+    # Compute the loss
+    loss = criterion(pred_sdf, batch['gt_sdf'])
+
+    # Abort on a non-finite loss
+    if not torch.isfinite(loss):
+        logger.error(f"Invalid loss value: {loss.item()}")
+        raise ValueError("Invalid loss value")
+
+    # Backward pass
+    loss.backward()
+
+    # Inspect the gradients
+    total_norm = 0.0
+    for name, param in model.named_parameters():
+        if param.grad is not None:
+            param_norm = param.grad.data.norm(2)
+            total_norm += param_norm.item() ** 2
+            logger.info(f"  {name}: grad_norm = {param_norm.item()}")
+        else:
+            logger.warning(f"  {name}: No gradient!")
+    total_norm = total_norm ** 0.5
+    logger.info(f"Total gradient norm: {total_norm}")
+
+    # Gradient clipping (optional)
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+
+    # Update the parameters
+    optimizer.step()
+
+    return loss.item()
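A cross-check for the manual norm accumulation in train_step: clip_grad_norm_
returns the total norm it computed before clipping, so the two values should
match. Minimal standalone sketch:

    import torch
    import torch.nn as nn

    model = nn.Linear(3, 1)
    model(torch.randn(4, 3)).sum().backward()

    manual = sum(p.grad.norm(2).item() ** 2
                 for p in model.parameters() if p.grad is not None) ** 0.5
    reported = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(manual, reported.item())   # equal up to floating-point rounding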
+def train(model, config, num_epochs=10):
+    """Mock training loop"""
+    # Optimizer
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+
+    from brep2sdf.networks.loss import Brep2SDFLoss
+    # Loss function
+    clamping_distance = config.train.clamping_distance
+    criterion = Brep2SDFLoss(
+        enforce_minmax=(clamping_distance > 0),
+        clamping_distance=clamping_distance
+    )
+
+    # Generate mock data
+    batch_size = config.train.batch_size
+    max_face = config.data.max_face
+    max_edge = config.data.max_edge
+    num_surf_points = config.model.num_surf_points
+    num_edge_points = config.model.num_edge_points
+
+    # One batch of mock data
+    batch = {
+        'edge_ncs': torch.randn(batch_size, max_face, max_edge, num_edge_points, 3),
+        'edge_pos': torch.randn(batch_size, max_face, max_edge, 6),
+        # False everywhere: no padded edges in the mock batch
+        # (src_key_padding_mask treats True as padding)
+        'edge_mask': torch.zeros(batch_size, max_face, max_edge, dtype=torch.bool),
+        'surf_ncs': torch.randn(batch_size, max_face, num_surf_points, 3),
+        'surf_pos': torch.randn(batch_size, max_face, 6),
+        'vertex_pos': torch.randn(batch_size, max_face, max_edge, 2, 3),
+        'query_points': torch.randn(batch_size, 1000, 3),   # 1000 query points
+        'gt_sdf': torch.randn(batch_size, 1000, 1)          # mock ground-truth SDF values
+    }
+
+    # Training loop
+    for epoch in range(num_epochs):
+        try:
+            loss = train_step(model, batch, optimizer, criterion)
+            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.6f}")
+
+        except Exception as e:
+            logger.error(f"Error during training:")
+            logger.error(f"  {str(e)}")
+            raise
+
+def main():
+    # Get the configuration
+    config = get_default_config()
+
+    # Build the model
+    model = BRepToSDF(config=config)
+
+    # Run the training
+    train(model, config)
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
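Per-parameter gradient logging on every step, as in train_step, gets heavy once
the model grows. A lighter alternative sketch (illustrative, not from this repo)
that attaches tensor hooks and only speaks up when a gradient is non-finite:

    import torch
    import torch.nn as nn

    def attach_grad_monitors(model: nn.Module):
        """Warn only when a parameter's gradient is non-finite."""
        for name, p in model.named_parameters():
            def check(grad, name=name):
                if not torch.isfinite(grad).all():
                    print(f"non-finite gradient in {name}")
                return grad
            p.register_hook(check)

    model = nn.Linear(3, 1)
    attach_grad_monitors(model)
    model(torch.randn(4, 3)).sum().backward()   # silent unless something is wrong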