You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

167 lines
6.0 KiB

3 months ago
import numpy as np
import torch.nn as nn
import torch
from torch.autograd import grad
# borrowed from siren paper
class Sine(nn.Module):
    """Sinusoidal activation ``sin(w0 * x)`` (SIREN-style).

    Args:
        w0: frequency multiplier applied before ``sin``. Defaults to 30, the
            factor this module has always used.
    """

    def __init__(self, w0=30):
        super().__init__()
        self.w0 = w0

    def forward(self, input):
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(self.w0 * input)
def gradient(inputs, outputs):
    """Differentiate ``outputs`` w.r.t. ``inputs`` and keep only the last three columns.

    The upstream gradient is a tensor of ones shaped like ``outputs``, i.e. this
    is the gradient of ``outputs.sum()``. The autograd graph is kept
    (``create_graph=True``) so the returned gradient can itself be
    differentiated, e.g. inside a loss term.

    Args:
        inputs: tensor (or sequence of tensors) that ``outputs`` depends on; only
            the gradient for the first one is returned.
        outputs: tensor to differentiate.

    Returns:
        The gradient for the first input, sliced to ``[:, -3:]`` (presumably the
        xyz coordinate columns -- confirm against callers).
    """
    upstream = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    all_grads = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=upstream,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    first_grad = all_grads[0]
    return first_grad[:, -3:]
class NHRepNet(nn.Module):
    """MLP producing ``n_branch`` values f_i, combined by a nested max/min tree into h.

    Args:
        d_in: input dimension.
        dims_sdf: hidden layer widths (the caller's list is not modified).
        csg_tree: nested list whose integer entries are column indices into the
            ``n_branch`` outputs; each nesting level alternates between max and
            min (see ``nested_cvx_output``).
        skip_in: indices of layers whose input is concatenated with the raw
            network input (the concatenation is scaled by 1/sqrt(2)).
        flag_convex: if True the top level of ``csg_tree`` is reduced with max,
            otherwise with min.
        geometric_init: if True, apply the custom initialisation below and use
            Softplus/ReLU; if False, keep PyTorch's default Linear init and use
            ``Sine`` (set False for SIREN).
        radius_init: under geometric init the last layer's bias is set to
            ``-radius_init``.
        beta: Softplus sharpness; ``beta <= 0`` selects ReLU instead.
        flag_output: 0 -> ``cat(h, f)``, 1 -> ``h``, k >= 2 -> column ``k - 2``
            of f (returned as a 1-D tensor).
        n_branch: number of outputs f_i of the last layer.
    """

    def __init__(
        self,
        d_in,
        dims_sdf,
        csg_tree,
        skip_in=(),
        flag_convex=True,
        geometric_init=True,  # set false for siren
        radius_init=1,
        beta=100,
        flag_output=0,
        n_branch=2,
    ):
        super().__init__()
        self.flag_output = flag_output  # 0: all 1: h, 2: f, 3: g
        self.n_branch = n_branch
        self.csg_tree = csg_tree
        self.flag_convex = flag_convex
        self.skip_in = skip_in
        dims_sdf = [d_in] + dims_sdf + [self.n_branch]
        self.sdf_layers = len(dims_sdf)
        for layer in range(0, len(dims_sdf) - 1):
            if layer + 1 in skip_in:
                # The next layer receives cat(x, input), so leave room for d_in columns.
                out_dim = dims_sdf[layer + 1] - d_in
            else:
                out_dim = dims_sdf[layer + 1]
            lin = nn.Linear(dims_sdf[layer], out_dim)
            if geometric_init:
                if layer == self.sdf_layers - 2:
                    # Final layer: near-constant positive weights and bias -radius_init.
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001)
                    torch.nn.init.constant_(lin.bias, -radius_init)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
            setattr(self, "sdf_" + str(layer), lin)
        if geometric_init:
            if beta > 0:
                self.activation = nn.Softplus(beta=beta)
            else:
                # vanilla relu
                self.activation = nn.ReLU()
        else:
            # siren
            self.activation = Sine()
        self.final_activation = nn.ReLU()

    @staticmethod
    def _leaf_columns(value_matrix, list_operation):
        """Return the columns of ``value_matrix`` named by the non-list entries of ``list_operation``.

        The result has shape (N, k) for k such entries (k may be 0). Plain column
        indexing (rather than a one-hot matmul) keeps the work on
        ``value_matrix``'s own device and avoids ``0 * inf = nan`` leaking in
        from unselected columns.
        """
        leaf_ids = [v for v in list_operation if not isinstance(v, list)]
        index = torch.as_tensor(leaf_ids, dtype=torch.long, device=value_matrix.device)
        return value_matrix[:, index]

    # composite f_i to h
    def nested_cvx_output(self, value_matrix, list_operation, cvx_flag=True):
        """Combine the branch values f_i into h following the nested tree ``list_operation``.

        Integer entries select columns of ``value_matrix``; list entries are
        sub-trees evaluated recursively with the opposite operator. This level
        reduces with max when ``cvx_flag`` is True and with min otherwise.

        Args:
            value_matrix: 2-D tensor (N, n_branch) of f_i values.
            list_operation: nested list of column indices.
            cvx_flag: True for max at this level, False for min.

        Returns:
            (N, 1) tensor.
        """
        parts = [self._leaf_columns(value_matrix, list_operation)]
        parts.extend(
            self.nested_cvx_output(value_matrix, v, not cvx_flag)
            for v in list_operation
            if isinstance(v, list)
        )
        stacked = torch.cat(parts, 1)
        if cvx_flag:
            return torch.max(stacked, dim=1, keepdim=True)[0]
        return torch.min(stacked, dim=1, keepdim=True)[0]

    @staticmethod
    def _r_blend(mat, rho, union):
        """Fold the columns of ``mat`` left to right with the smooth R-function blend.

        Each step computes ``a + b +/- sqrt(a^2 + b^2 + c)`` where the correction
        ``c`` is non-zero only where ``a^2 + b^2 < rho^2``. ``union=True`` uses the
        ``+`` form and ``union=False`` the ``-`` form. Returns an (N, 1) tensor.
        """
        res = mat[:, 0]
        for i in range(1, mat.shape[1]):
            col = mat[:, i]
            sq_sum = res * res + col * col
            srho = sq_sum - rho * rho
            # srho * (srho - |srho|) is 0 when srho >= 0 and 2 * srho**2 otherwise.
            root = torch.sqrt(sq_sum + 1.0 / (8 * rho * rho) * srho * (srho - srho.abs()))
            if union:
                res = res + col + root
            else:
                res = res + col - root
        return res.unsqueeze(1)

    def min_soft_blend(self, mat, rho):
        """Blend the columns of ``mat`` with the ``a + b - sqrt(...)`` form; returns (N, 1)."""
        return self._r_blend(mat, rho, union=False)

    def max_soft_blend(self, mat, rho):
        """Blend the columns of ``mat`` with the ``a + b + sqrt(...)`` form; returns (N, 1)."""
        return self._r_blend(mat, rho, union=True)

    # r-function blending
    def nested_cvx_output_soft_blend(self, value_matrix, list_operation, cvx_flag=True, rho=0.05):
        """Same tree walk as ``nested_cvx_output`` but with R-function blends instead of max/min.

        Args:
            value_matrix: 2-D tensor (N, n_branch) of f_i values.
            list_operation: nested list of column indices.
            cvx_flag: True to use ``max_soft_blend`` at this level, False for
                ``min_soft_blend``.
            rho: blending radius passed to the blend functions (default 0.05).

        Returns:
            (N, 1) tensor.
        """
        parts = [self._leaf_columns(value_matrix, list_operation)]
        parts.extend(
            self.nested_cvx_output_soft_blend(value_matrix, v, not cvx_flag, rho)
            for v in list_operation
            if isinstance(v, list)
        )
        stacked = torch.cat(parts, 1)
        if cvx_flag:
            return self.max_soft_blend(stacked, rho)
        return self.min_soft_blend(stacked, rho)

    def forward(self, input):
        """Run the MLP on ``input`` and combine its outputs according to ``flag_output``.

        ``input`` is assumed to be 2-D (N, d_in) -- the tree reduction indexes
        outputs as ``[:, i]``.
        """
        x = input
        last_layer = self.sdf_layers - 2
        for layer in range(0, self.sdf_layers - 1):
            lin = getattr(self, "sdf_" + str(layer))
            if layer in self.skip_in:
                x = torch.cat([x, input], -1) / np.sqrt(2)
            x = lin(x)
            if layer < last_layer:
                x = self.activation(x)
        output_value = x  # all f_i
        h = self.nested_cvx_output(output_value, self.csg_tree, self.flag_convex)
        # r-function blending alternative:
        # h = self.nested_cvx_output_soft_blend(output_value, self.csg_tree, self.flag_convex)
        if self.flag_output == 0:
            return torch.cat((h, output_value), 1)  # return all
        if self.flag_output == 1:
            return h  # return h
        return output_value[:, self.flag_output - 2]  # return f_i