@@ -20,7 +20,7 @@ class Decoder(nn.Module):
         skip_in: Tuple[int, ...] = (),
         flag_convex: bool = True,
         geometric_init: bool = True,
-        radius_init: float = 1,
+        radius_init: float = 0.5,
         beta: float = 100,
     ) -> None:
         super().__init__()
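
For context on the radius_init default changed above: in SAL/IGR-style SDF decoders, geometric initialization sets up the last linear layer so that the untrained network already approximates a sphere of radius radius_init. A minimal sketch of that scheme follows; the helper name geometric_init_ and its signature are assumptions for illustration, not this repository's code.

import math
import torch.nn as nn

def geometric_init_(lin: nn.Linear, is_last: bool, radius_init: float = 0.5) -> None:
    # Hypothetical helper (SAL-style geometric init): make the freshly
    # initialized network behave roughly like f(x) = ||x|| - radius_init.
    if is_last:
        nn.init.normal_(lin.weight, mean=math.sqrt(math.pi) / math.sqrt(lin.in_features), std=1e-4)
        nn.init.constant_(lin.bias, -radius_init)
    else:
        nn.init.constant_(lin.bias, 0.0)
        nn.init.normal_(lin.weight, 0.0, math.sqrt(2.0) / math.sqrt(lin.out_features))

Under that scheme, a smaller radius_init (0.5 instead of 1) simply starts the zero level set on a smaller sphere, which can help when the target shapes are normalized to a smaller bounding volume.
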
@@ -66,7 +66,7 @@ class Decoder(nn.Module):
         else:
             #siren
             self.activation = Sine()
-            self.final_activation = nn.ReLU()
+            self.final_activation = nn.Tanh()

     def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
         '''
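
A brief note on the activation change above (a hedged reading, not stated in the diff): with Sine hidden activations (SIREN), a final ReLU clamps every negative prediction to zero, whereas Tanh is odd and bounded in (-1, 1), so a signed distance output keeps its inside/outside sign. A tiny standalone comparison:

import torch
import torch.nn as nn

x = torch.linspace(-2.0, 2.0, 5)
print(nn.ReLU()(x))  # tensor([0., 0., 0., 1., 2.]) -- the negative side is lost
print(nn.Tanh()(x))  # values in (-1, 1), sign preserved on both sides of zero
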
@@ -86,7 +86,7 @@ class Decoder(nn.Module):
         for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
                 x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt

             #logger.print_tensor_stats(f"layer-{layer}>x", x)
             x = lin(x)
             if layer < self.sdf_layers - 2:
                 x = self.activation(x)
@@ -112,7 +112,7 @@ class Decoder(nn.Module):
         for layer, lin in enumerate(self.sdf_modules):
             if layer in self.skip_in:
                 x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0))  # use torch.sqrt

             #logger.print_tensor_stats(f"layer-{layer}>x", x)
             x = lin(x)
             if layer < self.sdf_layers - 2:
                 x = self.activation(x)
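
One more note on the skip-connection line repeated in the last two hunks: the 1/sqrt(2) factor keeps the feature scale roughly unchanged after concatenation, and since it is a constant it can be precomputed once as a plain Python float rather than building a tensor on every forward pass. A minimal sketch of that idea follows; the helper name skip_concat and the separate skip argument are assumptions for illustration (the diff itself concatenates x with itself), not this repository's API.

import math
import torch

SQRT2 = math.sqrt(2.0)  # precomputed once; no per-call torch.tensor allocation

def skip_concat(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    # Concatenate along the feature dimension and rescale so the overall
    # magnitude stays comparable to a single branch.
    return torch.cat([x, skip], dim=-1) / SQRT2
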