Browse Source

style: 注释了不必要log

main
mckay 3 months ago
parent
commit
8f0e108696
  1. 5
      brep2sdf/data/data.py
  2. 28
      brep2sdf/networks/encoder.py
  3. 4
      brep2sdf/networks/network.py
  4. 7
      brep2sdf/train.py

5
brep2sdf/data/data.py

@ -99,6 +99,7 @@ class BRepSDFDataset(Dataset):
max_edge=self.max_edge,
bbox_scaled=self.bbox_scaled
)
'''
# 打印数据形状
logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:")
for value in brep_features:
@ -118,7 +119,7 @@ class BRepSDFDataset(Dataset):
else:
logger.error(f" {i}: {type(feat)}")
raise ValueError(f"Incorrect number of features: {len(brep_features)}")
'''
# 解包处理后的特征
edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
sdf_points = sdf_data[:, :3]
@ -206,7 +207,7 @@ class BRepSDFDataset(Dataset):
indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
sdf_np = sdf_np[indices]
logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})")
#logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})")
return torch.from_numpy(sdf_np.astype(np.float32))
except Exception as e:

28
brep2sdf/networks/encoder.py

@ -246,39 +246,39 @@ class BRepFeatureEmbedder(nn.Module):
'vertex_pos': vertex_pos
}
logger.info("\n=== 输入张量检查 ===")
for name, tensor in input_tensors.items():
logger.print_tensor_stats(name, tensor)
#logger.info("\n=== 输入张量检查 ===")
#for name, tensor in input_tensors.items():
#logger.print_tensor_stats(name, tensor)
# 1. 处理顶点特征
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim]
logger.print_tensor_stats('vertex_embed', vertex_embed)
# logger.print_tensor_stats('vertex_embed', vertex_embed)
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim]
logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed)
# logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed)
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim]
# 2. 处理边特征
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim]
logger.print_tensor_stats('edge_embeds', edge_embeds)
# logger.print_tensor_stats('edge_embeds', edge_embeds)
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim]
logger.print_tensor_stats('edge_p_embeds', edge_p_embeds)
# logger.print_tensor_stats('edge_p_embeds', edge_p_embeds)
# 3. 处理面特征
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim]
logger.print_tensor_stats('surf_embeds', surf_embeds)
# logger.print_tensor_stats('surf_embeds', surf_embeds)
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim]
logger.print_tensor_stats('surf_p_embeds', surf_p_embeds)
# logger.print_tensor_stats('surf_p_embeds', surf_p_embeds)
# 4. 组合特征
if self.use_cf:
# 组合边特征
edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim]
edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim]
logger.print_tensor_stats('edge_features', edge_features)
# logger.print_tensor_stats('edge_features', edge_features)
# 组合面特征
surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim]
logger.print_tensor_stats('surf_features', surf_features)
# logger.print_tensor_stats('surf_features', surf_features)
# 拼接所有特征
embeds = torch.cat([
@ -300,11 +300,11 @@ class BRepFeatureEmbedder(nn.Module):
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F]
else:
mask = None
logger.debug(f"embeds shape: {embeds.shape}")
#logger.debug(f"embeds shape: {embeds.shape}")
# 6. Transformer处理
output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask)
logger.print_tensor_stats('output', output)
logger.debug(f"output shape: {output.shape}")
# logger.print_tensor_stats('output', output)
#logger.debug(f"output shape: {output.shape}")
return output.transpose(0, 1) # [B, F*E+F, embed_dim]
class SDFTransformer(nn.Module):

4
brep2sdf/networks/network.py

@ -101,8 +101,8 @@ class BRepToSDF(nn.Module):
# 6. SDF预测
sdf = self.sdf_head(combined_features) # [B, Q, 1]
if not sdf.requires_grad:
logger.warning("SDF output does not require grad!")
#if not sdf.requires_grad:
#logger.warning("SDF output does not require grad!")
return sdf

7
brep2sdf/train.py

@ -114,16 +114,19 @@ class Trainer:
self.optimizer.zero_grad()
# 获取数据并移动到设备,同时设置梯度
# 获取数据并移动到设备,同时保留计算图
surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True)
edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True)
surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True)
edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True)
vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True)
points = batch['points'].to(self.device).requires_grad_(True)
#logger.print_tensor_stats("batch surf_ncs",surf_ncs)
# 这些不需要梯度
edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device)
# 前向传播
@ -143,6 +146,8 @@ class Trainer:
gt_sdf=gt_sdf,
)
#logger.print_tensor_stats("after loss batch surf_ncs",surf_ncs)
# 反向传播和优化
loss.backward()

Loading…
Cancel
Save