@ -3,6 +3,7 @@ import torch.nn as nn
import torch
import torch
from torch . autograd import grad
from torch . autograd import grad
from utils . logger import logger
#borrowed from siren paper
#borrowed from siren paper
class Sine ( nn . Module ) :
class Sine ( nn . Module ) :
def __init ( self ) :
def __init ( self ) :
@ -143,10 +144,18 @@ class NHRepNet(nn.Module):
def forward ( self , input ) :
def forward ( self , input ) :
'''
: param input : 2 D array of floats , each row represents a point in 3 D space
e . g . input = torch . tensor ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] )
: return : 2 D array of floats , each row represents a point in 3 D space
e . g . output = torch . tensor ( [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] )
: note : The input tensor should have a shape of ( N , 3 ) , where N is the number of points . usually N = 100000
The output tensor will have a shape of ( N , 5 ) , where each row corresponds to the processed output for the respective input point .
'''
# logger.info(f"input shape: {input.shape}")
x = input
x = input
for layer in range ( 0 , self . sdf_layers - 1 ) :
for layer in range ( 0 , self . sdf_layers - 1 ) :
lin = getattr ( self , " sdf_ " + str ( layer ) )
lin = getattr ( self , " sdf_ " + str ( layer ) )
if layer in self . skip_in :
if layer in self . skip_in :
x = torch . cat ( [ x , input ] , - 1 ) / np . sqrt ( 2 )
x = torch . cat ( [ x , input ] , - 1 ) / np . sqrt ( 2 )
@ -160,8 +169,13 @@ class NHRepNet(nn.Module):
# h = self.nested_cvx_output_soft_blend(output_value, self.csg_tree, self.flag_convex)
# h = self.nested_cvx_output_soft_blend(output_value, self.csg_tree, self.flag_convex)
if self . flag_output == 0 :
if self . flag_output == 0 :
# logger.info(f"h shape: {h.shape}")
# logger.info(f"output_value shape: {output_value.shape}")
return torch . cat ( ( h , output_value ) , 1 ) # return all
return torch . cat ( ( h , output_value ) , 1 ) # return all
elif self . flag_output == 1 :
elif self . flag_output == 1 :
# logger.info(f"h shape: {h.shape}")
return h #return h
return h #return h
else :
else :
# logger.info(f"output_value shape: {output_value.shape}")
# logger.info(f"flag_output: {self.flag_output}")
return output_value [ : , self . flag_output - 2 ] #return f_i
return output_value [ : , self . flag_output - 2 ] #return f_i