From 58e000dad336b594780728d35d409c48edf8054b Mon Sep 17 00:00:00 2001
From: mckay
Date: Fri, 25 Apr 2025 21:33:50 +0800
Subject: [PATCH] =?UTF-8?q?=E5=8F=AF=E4=BB=A5torch=20jit=20script?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/networks/decoder.py | 21 ++++++++++-----------
 brep2sdf/networks/network.py |  7 ++++---
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py
index 660b895..373dab9 100644
--- a/brep2sdf/networks/decoder.py
+++ b/brep2sdf/networks/decoder.py
@@ -30,6 +30,9 @@ class Decoder(nn.Module):
 
         dims_sdf = [d_in] + dims_sdf + [1] #[self.n_branch]
         self.sdf_layers = len(dims_sdf)
+
+        # 使用 ModuleList 存储 sdf 层
+        self.sdf_modules = nn.ModuleList()
         for layer in range(0, len(dims_sdf) - 1):
             if layer + 1 in skip_in:
                 out_dim = dims_sdf[layer + 1] - d_in
@@ -43,7 +46,8 @@ class Decoder(nn.Module):
             else:
                 torch.nn.init.constant_(lin.bias, 0.0)
                 torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
-            setattr(self, "sdf_"+str(layer), lin)
+            self.sdf_modules.append(lin)
+
         if geometric_init:
             if beta > 0:
                 self.activation = nn.Softplus(beta=beta)
@@ -55,10 +59,6 @@ class Decoder(nn.Module):
             self.activation = Sine()
         self.final_activation = nn.ReLU()
 
-        # composite f_i to h
-
-
-
     def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
         '''
         :param feature_matrix: 形状为 (B, P, D) 的特征矩阵
@@ -73,10 +73,10 @@ class Decoder(nn.Module):
         # 展平处理 (B*P, D)
         x = feature_matrix.view(-1, D)
 
-        for layer in range(0, self.sdf_layers - 1):
-            lin = getattr(self, "sdf_" + str(layer))
+        # 使用枚举遍历 sdf_modules
+        for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
-                x = torch.cat([x, x], -1) / np.sqrt(2) # Fix undefined 'input'
+                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt
             x = lin(x)
 
             if layer < self.sdf_layers - 2:
@@ -100,10 +100,9 @@ class Decoder(nn.Module):
         # 直接使用输入的特征矩阵,因为形状已经是 (S, D)
         x = feature_matrix
 
-        for layer in range(0, self.sdf_layers - 1):
-            lin = getattr(self, "sdf_" + str(layer))
+        for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
-                x = torch.cat([x, x], -1) / np.sqrt(2)
+                x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt
             x = lin(x)
 
             if layer < self.sdf_layers - 2:
diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py
index e2a45b6..cba4484 100644
--- a/brep2sdf/networks/network.py
+++ b/brep2sdf/networks/network.py
@@ -100,9 +100,9 @@ class Net(nn.Module):
         feature_vectors = self.encoder.forward(query_points,face_indices_mask)
         print("feature_vector:", feature_vectors.shape)
         # 解码
-        logger.gpu_memory_stats("encoder farward后")
+        #logger.gpu_memory_stats("encoder farward后")
         f_i = self.decoder(feature_vectors) # (B, P)
-        logger.gpu_memory_stats("decoder farward后")
+        #logger.gpu_memory_stats("decoder farward后")
 
         output = f_i[:, 0]
 
@@ -127,7 +127,7 @@ class Net(nn.Module):
 
         output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
 
-        logger.gpu_memory_stats("combine后")
+        #logger.gpu_memory_stats("combine后")
         return output
 
     @torch.jit.export
@@ -137,6 +137,7 @@ class Net(nn.Module):
         surf_points (P, S):
         return (P, S)
         """
+        logger.debug(surf_points)
         feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id)
         f_i = self.decoder.forward_training_volumes(feature_mat)
         return f_i.squeeze()