From 8f0e108696b97ac5847a1226c1aa59a539e186b8 Mon Sep 17 00:00:00 2001 From: mckay Date: Tue, 3 Dec 2024 21:18:15 +0800 Subject: [PATCH] =?UTF-8?q?style:=20=E6=B3=A8=E9=87=8A=E4=BA=86=E4=B8=8D?= =?UTF-8?q?=E5=BF=85=E8=A6=81log?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/data/data.py | 5 +++-- brep2sdf/networks/encoder.py | 28 ++++++++++++++-------------- brep2sdf/networks/network.py | 4 ++-- brep2sdf/train.py | 7 ++++++- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py index 5651aa4..5e6fe84 100644 --- a/brep2sdf/data/data.py +++ b/brep2sdf/data/data.py @@ -99,6 +99,7 @@ class BRepSDFDataset(Dataset): max_edge=self.max_edge, bbox_scaled=self.bbox_scaled ) + ''' # 打印数据形状 logger.debug(f"Processed data shapes for {os.path.basename(brep_path)}:") for value in brep_features: @@ -118,7 +119,7 @@ class BRepSDFDataset(Dataset): else: logger.error(f" {i}: {type(feat)}") raise ValueError(f"Incorrect number of features: {len(brep_features)}") - + ''' # 解包处理后的特征 edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = brep_features sdf_points = sdf_data[:, :3] @@ -206,7 +207,7 @@ class BRepSDFDataset(Dataset): indices = np.random.choice(sdf_np.shape[0], max_points, replace=False) sdf_np = sdf_np[indices] - logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})") + #logger.debug(f"Sampled SDF points: {sdf_np.shape[0]} (max: {max_points})") return torch.from_numpy(sdf_np.astype(np.float32)) except Exception as e: diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index b4fdb91..f828f61 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -246,39 +246,39 @@ class BRepFeatureEmbedder(nn.Module): 'vertex_pos': vertex_pos } - logger.info("\n=== 输入张量检查 ===") - for name, tensor in input_tensors.items(): - logger.print_tensor_stats(name, tensor) + #logger.info("\n=== 输入张量检查 ===") + #for name, tensor in input_tensors.items(): + #logger.print_tensor_stats(name, tensor) # 1. 处理顶点特征 vertex_embed = self.vertp_embed(vertex_pos[..., :3]) # [B, F, E, 2, embed_dim] - logger.print_tensor_stats('vertex_embed', vertex_embed) + # logger.print_tensor_stats('vertex_embed', vertex_embed) vertex_embed = self.vertex_proj(vertex_embed) # [B, F, E, 2, embed_dim] - logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) + # logger.print_tensor_stats('vertex_embed(after proj)', vertex_embed) vertex_embed = vertex_embed.mean(dim=3) # [B, F, E, embed_dim] # 2. 处理边特征 edge_embeds = self.edgez_embed(edge_ncs) # [B, F, E, embed_dim] - logger.print_tensor_stats('edge_embeds', edge_embeds) + # logger.print_tensor_stats('edge_embeds', edge_embeds) edge_p_embeds = self.edgep_embed(edge_pos) # [B, F, E, embed_dim] - logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) + # logger.print_tensor_stats('edge_p_embeds', edge_p_embeds) # 3. 处理面特征 surf_embeds = self.surfz_embed(surf_ncs) # [B, F, embed_dim] - logger.print_tensor_stats('surf_embeds', surf_embeds) + # logger.print_tensor_stats('surf_embeds', surf_embeds) surf_p_embeds = self.surfp_embed(surf_pos) # [B, F, embed_dim] - logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) + # logger.print_tensor_stats('surf_p_embeds', surf_p_embeds) # 4. 组合特征 if self.use_cf: # 组合边特征 edge_features = edge_embeds + edge_p_embeds + vertex_embed # [B, F, E, embed_dim] edge_features = edge_features.reshape(B, F*E, -1) # [B, F*E, embed_dim] - logger.print_tensor_stats('edge_features', edge_features) + # logger.print_tensor_stats('edge_features', edge_features) # 组合面特征 surf_features = surf_embeds + surf_p_embeds # [B, F, embed_dim] - logger.print_tensor_stats('surf_features', surf_features) + # logger.print_tensor_stats('surf_features', surf_features) # 拼接所有特征 embeds = torch.cat([ @@ -300,11 +300,11 @@ class BRepFeatureEmbedder(nn.Module): mask = torch.cat([edge_mask, surf_mask], dim=1) # [B, F*E+F] else: mask = None - logger.debug(f"embeds shape: {embeds.shape}") + #logger.debug(f"embeds shape: {embeds.shape}") # 6. Transformer处理 output = self.net(embeds.transpose(0, 1), src_key_padding_mask=mask) - logger.print_tensor_stats('output', output) - logger.debug(f"output shape: {output.shape}") + # logger.print_tensor_stats('output', output) + #logger.debug(f"output shape: {output.shape}") return output.transpose(0, 1) # [B, F*E+F, embed_dim] class SDFTransformer(nn.Module): diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 254ed87..5c36c91 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -101,8 +101,8 @@ class BRepToSDF(nn.Module): # 6. SDF预测 sdf = self.sdf_head(combined_features) # [B, Q, 1] - if not sdf.requires_grad: - logger.warning("SDF output does not require grad!") + #if not sdf.requires_grad: + #logger.warning("SDF output does not require grad!") return sdf diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 0023c8f..ceaccf2 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -114,16 +114,19 @@ class Trainer: self.optimizer.zero_grad() # 获取数据并移动到设备,同时设置梯度 + # 获取数据并移动到设备,同时保留计算图 surf_ncs = batch['surf_ncs'].to(self.device).requires_grad_(True) edge_ncs = batch['edge_ncs'].to(self.device).requires_grad_(True) surf_pos = batch['surf_pos'].to(self.device).requires_grad_(True) edge_pos = batch['edge_pos'].to(self.device).requires_grad_(True) vertex_pos = batch['vertex_pos'].to(self.device).requires_grad_(True) points = batch['points'].to(self.device).requires_grad_(True) + + + #logger.print_tensor_stats("batch surf_ncs",surf_ncs) # 这些不需要梯度 edge_mask = batch['edge_mask'].to(self.device) - points = batch['points'].to(self.device) gt_sdf = batch['sdf'].to(self.device) # 前向传播 @@ -143,6 +146,8 @@ class Trainer: gt_sdf=gt_sdf, ) + #logger.print_tensor_stats("after loss batch surf_ncs",surf_ncs) + # 反向传播和优化 loss.backward()