#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    """MLP decoder that maps a concatenated (latent, xyz) input of width
    latent_size + 3 to a single scalar, with optional weight/layer norm,
    dropout, and latent/xyz re-injection at chosen layers."""

    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        super().__init__()
        def make_sequence():  # currently unused
            return []

        # Prepend the (latent + xyz) input width, append the scalar output.
        dims = [latent_size + 3] + dims + [1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm
        for layer in range(0, self.num_layers - 1):
            # Layers feeding a latent re-injection point shrink their output
            # so that concatenating the input again restores the target width.
            if layer + 1 in latent_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != self.num_layers - 2:
                    out_dim -= 3

            if weight_norm and layer in self.norm_layers:
                setattr(
                    self,
                    "lin" + str(layer),
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
                )
            else:
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))

            # Without weight norm, normalized layers use LayerNorm instead.
            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))
        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()
    # input: N x (L+3)
    def forward(self, input):
        xyz = input[:, -3:]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, :-3]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input
        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))
            # Re-inject the full input (skip connection) at designated layers.
            if layer in self.latent_in:
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)
            # last layer Tanh
            if layer == self.num_layers - 2 and self.use_tanh:
                x = self.tanh(x)
            if layer < self.num_layers - 2:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    bn = getattr(self, "bn" + str(layer))
                    x = bn(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)
if hasattr(self, "th"):
|
|
x = self.th(x)
|
|
|
|
return x
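

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): the latent
# size, layer widths, and dropout schedule below are assumptions chosen to
# exercise the constructor options, not values this file prescribes.
if __name__ == "__main__":
    latent_size = 256
    decoder = Decoder(
        latent_size,
        dims=[512, 512, 512, 512],
        dropout=[0, 1, 2, 3],  # apply dropout after these layers
        dropout_prob=0.2,
        norm_layers=[0, 1, 2, 3],  # LayerNorm here, since weight_norm=False
        latent_in=[2],  # re-inject the full input at layer 2
        latent_dropout=True,
    )
    decoder.eval()  # disable dropout for deterministic queries

    # One latent code broadcast over a batch of query points, concatenated
    # with xyz as forward() expects: input is N x (latent_size + 3).
    n_points = 8
    latent = torch.randn(1, latent_size).expand(n_points, -1)
    xyz = torch.rand(n_points, 3) * 2 - 1
    sdf = decoder(torch.cat([latent, xyz], dim=1))
    print(sdf.shape)  # torch.Size([8, 1])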