Browse Source

参数调整

final
mckay 1 month ago
parent
commit
f7d79eaf52
  1. 8
      brep2sdf/networks/decoder.py
  2. 2
      brep2sdf/networks/patch_graph.py

8
brep2sdf/networks/decoder.py

@ -20,7 +20,7 @@ class Decoder(nn.Module):
skip_in: Tuple[int, ...] = (),
flag_convex: bool = True,
geometric_init: bool = True,
radius_init: float = 1,
radius_init: float = 0.5,
beta: float = 100,
) -> None:
super().__init__()
@ -66,7 +66,7 @@ class Decoder(nn.Module):
else:
#siren
self.activation = Sine()
self.final_activation = nn.ReLU()
self.final_activation = nn.Tanh()
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
'''
@ -86,7 +86,7 @@ class Decoder(nn.Module):
for layer, lin in enumerate(self.sdf_modules):
if layer in self.skip_in:
x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt
#logger.print_tensor_stats(f"layer-{layer}>x", x)
x = lin(x)
if layer < self.sdf_layers - 2:
x = self.activation(x)
@ -112,7 +112,7 @@ class Decoder(nn.Module):
for layer, lin in enumerate(self.sdf_modules):
if layer in self.skip_in:
x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt
#logger.print_tensor_stats(f"layer-{layer}>x", x)
x = lin(x)
if layer < self.sdf_layers - 2:
x = self.activation(x)

2
brep2sdf/networks/patch_graph.py

@ -95,6 +95,8 @@ class PatchGraph(nn.Module):
# 返回 0: 凹边, 1: 凸边,
node_faces = node_faces.flatten().to(self.device)
num_faces = node_faces.numel()
if num_faces < 1:
print(f"num_faces:{num_faces}")
if num_faces == 1:
# 这里设置凸边是因为 后续会补一个 f2 = inf, h = min(f1, f2)
# 因为 f2 = inf, 所以 h = f1

Loading…
Cancel
Save