|
|
@ -30,6 +30,9 @@ class Decoder(nn.Module): |
|
|
|
|
|
|
|
dims_sdf = [d_in] + dims_sdf + [1] #[self.n_branch] |
|
|
|
self.sdf_layers = len(dims_sdf) |
|
|
|
|
|
|
|
# 使用 ModuleList 存储 sdf 层 |
|
|
|
self.sdf_modules = nn.ModuleList() |
|
|
|
for layer in range(0, len(dims_sdf) - 1): |
|
|
|
if layer + 1 in skip_in: |
|
|
|
out_dim = dims_sdf[layer + 1] - d_in |
|
|
@ -43,7 +46,8 @@ class Decoder(nn.Module): |
|
|
|
else: |
|
|
|
torch.nn.init.constant_(lin.bias, 0.0) |
|
|
|
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) |
|
|
|
setattr(self, "sdf_"+str(layer), lin) |
|
|
|
self.sdf_modules.append(lin) |
|
|
|
|
|
|
|
if geometric_init: |
|
|
|
if beta > 0: |
|
|
|
self.activation = nn.Softplus(beta=beta) |
|
|
@ -55,10 +59,6 @@ class Decoder(nn.Module): |
|
|
|
self.activation = Sine() |
|
|
|
self.final_activation = nn.ReLU() |
|
|
|
|
|
|
|
# composite f_i to h |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: |
|
|
|
''' |
|
|
|
:param feature_matrix: 形状为 (B, P, D) 的特征矩阵 |
|
|
@ -73,10 +73,10 @@ class Decoder(nn.Module): |
|
|
|
# 展平处理 (B*P, D) |
|
|
|
x = feature_matrix.view(-1, D) |
|
|
|
|
|
|
|
for layer in range(0, self.sdf_layers - 1): |
|
|
|
lin = getattr(self, "sdf_" + str(layer)) |
|
|
|
# 使用枚举遍历 sdf_modules |
|
|
|
for layer, lin in enumerate(self.sdf_modules): |
|
|
|
if layer in self.skip_in: |
|
|
|
x = torch.cat([x, x], -1) / np.sqrt(2) # Fix undefined 'input' |
|
|
|
x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt |
|
|
|
|
|
|
|
x = lin(x) |
|
|
|
if layer < self.sdf_layers - 2: |
|
|
@ -100,10 +100,9 @@ class Decoder(nn.Module): |
|
|
|
# 直接使用输入的特征矩阵,因为形状已经是 (S, D) |
|
|
|
x = feature_matrix |
|
|
|
|
|
|
|
for layer in range(0, self.sdf_layers - 1): |
|
|
|
lin = getattr(self, "sdf_" + str(layer)) |
|
|
|
for layer, lin in enumerate(self.sdf_modules): |
|
|
|
if layer in self.skip_in: |
|
|
|
x = torch.cat([x, x], -1) / np.sqrt(2) |
|
|
|
x = torch.cat([x, x], -1) / torch.sqrt(torch.tensor(2.0)) # 使用 torch.sqrt |
|
|
|
|
|
|
|
x = lin(x) |
|
|
|
if layer < self.sdf_layers - 2: |
|
|
|