You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

109 lines
3.3 KiB

#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import torch.nn as nn
import torch
import torch.nn.functional as F
class Decoder(nn.Module):
    """Fully connected decoder mapping ``[latent code, xyz]`` rows to one scalar.

    The network input has ``latent_size + 3`` columns; the trailing three
    columns are treated as a point (``xyz``). The output has a single column
    and is squashed by a final ``tanh``, so it lies in ``[-1, 1]``
    (presumably a truncated signed-distance value -- confirm with callers).

    Linear layers are registered as ``lin0``, ``lin1``, ...; optional
    ``LayerNorm`` modules as ``bn<layer>``.

    Args:
        latent_size: Width of the latent code part of the input.
        dims: Sequence (list, tuple, ...) of hidden layer widths.
        dropout: Collection of layer indices after which dropout is applied,
            or ``None`` for no dropout on hidden activations.
        dropout_prob: Drop probability used for ``dropout`` layers.
        norm_layers: Collection of layer indices that are normalised: with
            weight normalisation when ``weight_norm`` is true, otherwise with
            a ``LayerNorm`` on the layer output. ``None`` means no layers.
        latent_in: Collection of layer indices whose input is the previous
            activation concatenated with the full network input. The layer
            feeding such a layer is narrowed by the input width so the
            concatenation matches ``dims``. ``None`` means no layers.
        weight_norm: Select weight normalisation instead of ``LayerNorm``
            for ``norm_layers``.
        xyz_in_all: If truthy, ``xyz`` is re-concatenated to the input of
            every layer except the first (and except ``latent_in`` layers,
            which already receive the whole input).
        use_tanh: Apply an extra ``tanh`` directly after the last linear
            layer (the final ``th`` tanh is applied regardless).
        latent_dropout: Apply dropout with ``p=0.2`` to the latent part of
            the input while training.
    """

    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        super().__init__()

        # Build a fresh list so any sequence type is accepted for ``dims``
        # and the caller's object is never mutated.
        dims = [latent_size + 3, *dims, 1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        # ``forward`` does membership tests on this without a None check.
        self.latent_in = () if latent_in is None else latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            # Kept so existing code that inspects ``lat_dp`` keeps working;
            # ``forward`` applies the same dropout through ``F.dropout``.
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        normed_layers = () if norm_layers is None else norm_layers
        last_layer = self.num_layers - 2

        for layer in range(self.num_layers - 1):
            if layer + 1 in self.latent_in:
                # The next layer gets the whole input concatenated back in.
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != last_layer:
                    # The next layer gets xyz concatenated back in.
                    out_dim -= 3

            linear = nn.Linear(dims[layer], out_dim)
            is_normed = layer in normed_layers
            if weight_norm and is_normed:
                linear = nn.utils.weight_norm(linear)
            # Register the linear layer before its LayerNorm so the module
            # and parameter ordering stays stable.
            setattr(self, f"lin{layer}", linear)
            if is_normed and not weight_norm:
                setattr(self, f"bn{layer}", nn.LayerNorm(out_dim))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()

    def forward(self, input):
        """Run the decoder.

        Args:
            input: Tensor of shape ``N x (L + 3)``; the last three columns
                are the point coordinates, the rest is the latent code.

        Returns:
            Tensor with one column per input row, passed through ``tanh``.
        """
        xyz = input[:, -3:]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = F.dropout(input[:, :-3], p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input

        last_layer = self.num_layers - 2
        for layer in range(self.num_layers - 1):
            lin = getattr(self, f"lin{layer}")
            if layer in self.latent_in:
                # Skip connection uses the raw input, not the dropped-out one.
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)

            if layer == last_layer:
                # Optional tanh directly on the last linear output.
                if self.use_tanh:
                    x = self.tanh(x)
                continue

            # Hidden layers: optional LayerNorm, ReLU, optional dropout.
            if (
                self.norm_layers is not None
                and layer in self.norm_layers
                and not self.weight_norm
            ):
                x = getattr(self, f"bn{layer}")(x)
            x = self.relu(x)
            if self.dropout is not None and layer in self.dropout:
                x = F.dropout(x, p=self.dropout_prob, training=self.training)

        # ``th`` is always created by ``__init__``; the check only matters if
        # the attribute was removed afterwards.
        if hasattr(self, "th"):
            x = self.th(x)

        return x