Browse Source

style: 注释了不必要log

main
mckay 6 months ago
parent
commit
8f0e108696
  1. 5
      brep2sdf/data/data.py
  2. 28
      brep2sdf/networks/encoder.py
  3. 4
      brep2sdf/networks/network.py
  4. 7
      brep2sdf/train.py

5
brep2sdf/data/data.py

@ -99,6 +99,7 @@ class BRepSDFDataset(Dataset):
max_edge=self.max_edge, max_edge=self.max_edge,
bbox_scaled=self.bbox_scaled bbox_scaled=self.bbox_scaled
) )
'''
# 打印数据形状 # 打印数据形状
logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:") logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:")
for value in brep_features: for value in brep_features:
@ -118,7 +119,7 @@ class BRepSDFDataset(Dataset):
else: else:
logger.error(f" {i}: {type(feat)}") logger.error(f" {i}: {type(feat)}")
raise ValueError(f"Incorrect number of features: {len(brep_features)}") raise ValueError(f"Incorrect number of features: {len(brep_features)}")
'''
# 解包处理后的特征 # 解包处理后的特征
edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features
sdf_points = sdf_data[:, :3] sdf_points = sdf_data[:, :3]
@ -206,7 +207,7 @@ class BRepSDFDataset(Dataset):
indices = np.random.choice(sdf_np.shape[0], max_points, replace=False) indices = np.random.choice(sdf_np.shape[0], max_points, replace=False)
sdf_np = sdf_np[indices] sdf_np = sdf_np[indices]
logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})") #logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})")
return torch.from_numpy(sdf_np.astype(np.float32)) return torch.from_numpy(sdf_np.astype(np.float32))
except Exception as e: except Exception as e:

28
brep2sdf/networks/encoder.py

@ -246,39 +246,39 @@ class BRepFeatureEmbedder(nn.Module):
'vertex_pos': vertex_pos 'vertex_pos': vertex_pos
} }
logger.info("\n=== 输入张量检查 ===") #logger.info("\n=== 输入张量检查 ===")
for name, tensor in input_tensors.items(): #for name, tensor in input_tensors.items():
logger.print_tensor_stats(name, tensor) #logger.print_tensor_stats(name, tensor)
# 1. 处理顶点特征 # 1. 处理顶点特征
vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim]
logger.print_tensor_stats('vertex_embed', vertex_embed) # logger.print_tensor_stats('vertex_embed', vertex_embed)
vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim]
logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) # logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed)
vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim]
# 2. 处理边特征 # 2. 处理边特征
edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim]
logger.print_tensor_stats('edge_embeds', edge_embeds) # logger.print_tensor_stats('edge_embeds', edge_embeds)
edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim]
logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) # logger.print_tensor_stats('edge_p_embeds', edge_p_embeds)
# 3. 处理面特征 # 3. 处理面特征
surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim]
logger.print_tensor_stats('surf_embeds', surf_embeds) # logger.print_tensor_stats('surf_embeds', surf_embeds)
surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim]
logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) # logger.print_tensor_stats('surf_p_embeds', surf_p_embeds)
# 4. 组合特征 # 4. 组合特征
if self.use_cf: if self.use_cf:
# 组合边特征 # 组合边特征
edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim] edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim]
edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim] edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim]
logger.print_tensor_stats('edge_features', edge_features) # logger.print_tensor_stats('edge_features', edge_features)
# 组合面特征 # 组合面特征
surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim] surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim]
logger.print_tensor_stats('surf_features', surf_features) # logger.print_tensor_stats('surf_features', surf_features)
# 拼接所有特征 # 拼接所有特征
embeds = torch.cat([ embeds = torch.cat([
@ -300,11 +300,11 @@ class BRepFeatureEmbedder(nn.Module):
mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F] mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F]
else: else:
mask = None mask = None
logger.debug(f"embeds shape: {embeds.shape}") #logger.debug(f"embeds shape: {embeds.shape}")
# 6. Transformer处理 # 6. Transformer处理
output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask)
logger.print_tensor_stats('output', output) # logger.print_tensor_stats('output', output)
logger.debug(f"output shape: {output.shape}") #logger.debug(f"output shape: {output.shape}")
return output.transpose(0, 1) # [B, F*E+F, embed_dim] return output.transpose(0, 1) # [B, F*E+F, embed_dim]
class SDFTransformer(nn.Module): class SDFTransformer(nn.Module):

4
brep2sdf/networks/network.py

@ -101,8 +101,8 @@ class BRepToSDF(nn.Module):
# 6. SDF预测 # 6. SDF预测
sdf = self.sdf_head(combined_features) # [B, Q, 1] sdf = self.sdf_head(combined_features) # [B, Q, 1]
if not sdf.requires_grad: #if not sdf.requires_grad:
logger.warning("SDF output does not require grad!") #logger.warning("SDF output does not require grad!")
return sdf return sdf

7
brep2sdf/train.py

@ -114,6 +114,7 @@ class Trainer:
self.optimizer.zero_grad() self.optimizer.zero_grad()
# 获取数据并移动到设备,同时设置梯度 # 获取数据并移动到设备,同时设置梯度
# 获取数据并移动到设备,同时保留计算图
surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True) surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True)
edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True) edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True)
surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True) surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True)
@ -121,9 +122,11 @@ class Trainer:
vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True) vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True)
points = batch['points'].to(self.device).requires_grad_(True) points = batch['points'].to(self.device).requires_grad_(True)
#logger.print_tensor_stats("batch surf_ncs",surf_ncs)
# 这些不需要梯度 # 这些不需要梯度
edge_mask = batch['edge_mask'].to(self.device) edge_mask = batch['edge_mask'].to(self.device)
points = batch['points'].to(self.device)
gt_sdf = batch['sdf'].to(self.device) gt_sdf = batch['sdf'].to(self.device)
# 前向传播 # 前向传播
@ -143,6 +146,8 @@ class Trainer:
gt_sdf=gt_sdf, gt_sdf=gt_sdf,
) )
#logger.print_tensor_stats("after loss batch surf_ncs",surf_ncs)
# 反向传播和优化 # 反向传播和优化
loss.backward() loss.backward()

Loading…
Cancel
Save