From f7d79eaf525b7d8eca2a7c5bc265783f9b05bc8c Mon Sep 17 00:00:00 2001
From: mckay
Date: Wed, 30 Apr 2025 00:44:08 +0800
Subject: [PATCH] Parameter adjustment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 brep2sdf/networks/decoder.py     | 8 ++++----
 brep2sdf/networks/patch_graph.py | 2 ++
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py
index 0624f7b..f48f19a 100644
--- a/brep2sdf/networks/decoder.py
+++ b/brep2sdf/networks/decoder.py
@@ -20,7 +20,7 @@ class Decoder(nn.Module):
         skip_in: Tuple[int, ...] = (),
         flag_convex: bool = True,
         geometric_init: bool = True,
-        radius_init: float = 1,
+        radius_init: float = 0.5,
         beta: float = 100,
     ) -> None:
         super().__init__()
@@ -66,7 +66,7 @@ class Decoder(nn.Module):
         else:
             #siren
             self.activation = Sine()
-            self.final_activation = nn.ReLU()
+            self.final_activation = nn.Tanh()
 
     def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
         '''
@@ -86,7 +86,7 @@ class Decoder(nn.Module):
         for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
                 x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
-
+            #logger.print_tensor_stats(f"layer-{layer}>x", x)
             x = lin(x)
             if layer < self.sdf_layers - 2:
                 x = self.activation(x)
@@ -112,7 +112,7 @@ class Decoder(nn.Module):
         for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
                 x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt
-
+            #logger.print_tensor_stats(f"layer-{layer}>x", x)
             x = lin(x)
             if layer < self.sdf_layers - 2:
                 x = self.activation(x)
diff --git a/brep2sdf/networks/patch_graph.py b/brep2sdf/networks/patch_graph.py
index c9e47a8..5518389 100644
--- a/brep2sdf/networks/patch_graph.py
+++ b/brep2sdf/networks/patch_graph.py
@@ -95,6 +95,8 @@ class PatchGraph(nn.Module):
         # returns 0: concave edge, 1: convex edge
         node_faces = node_faces.flatten().to(self.device)
         num_faces = node_faces.numel()
+        if num_faces < 1:
+            print(f"num_faces:{num_faces}")
         if num_faces == 1:
             # mark as convex here, because an f2 = inf is appended later and h = min(f1, f2)
             # since f2 = inf, h = f1