|
@ -41,7 +41,7 @@ class Decoder(nn.Module): |
|
|
lin = nn.Linear(dims_sdf[layer], out_dim) |
|
|
lin = nn.Linear(dims_sdf[layer], out_dim) |
|
|
if geometric_init: |
|
|
if geometric_init: |
|
|
if layer == self.sdf_layers - 2: |
|
|
if layer == self.sdf_layers - 2: |
|
|
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001) |
|
|
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.1) |
|
|
torch.nn.init.constant_(lin.bias, -radius_init) |
|
|
torch.nn.init.constant_(lin.bias, -radius_init) |
|
|
else: |
|
|
else: |
|
|
torch.nn.init.constant_(lin.bias, 0.0) |
|
|
torch.nn.init.constant_(lin.bias, 0.0) |
|
@ -59,13 +59,13 @@ class Decoder(nn.Module): |
|
|
nn.Softplus(beta=beta) |
|
|
nn.Softplus(beta=beta) |
|
|
) |
|
|
) |
|
|
if beta > 0: |
|
|
if beta > 0: |
|
|
self.activation = nn.Softplus(beta=beta) |
|
|
self.activation = nn.SiLU() |
|
|
# vanilla relu |
|
|
# vanilla relu |
|
|
else: |
|
|
else: |
|
|
self.activation = nn.ReLU() |
|
|
self.activation = nn.ReLU() |
|
|
else: |
|
|
else: |
|
|
#siren |
|
|
#siren |
|
|
self.activation = Sine() |
|
|
self.activation = nn.SiLU() |
|
|
self.final_activation = nn.Tanh() |
|
|
self.final_activation = nn.Tanh() |
|
|
|
|
|
|
|
|
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: |
|
|
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: |
|
|