import torch
import torch.nn as nn
import torch.nn.functional as F


class ANN_Model(nn.Module):
    """Fully connected network: eight ReLU hidden layers and a linear output.

    Only the first ``in_dim`` features along the last axis of the input are
    fed to the network; any trailing features are ignored.  (An earlier
    revision bound those trailing features to an unused local called
    ``displace`` -- presumably displacement values, TODO confirm with the
    data pipeline.)

    Args:
        in_dim: Number of leading input features consumed by the network.
        l1, l2, l3, l4, l5, l6, l7, l8: Widths of the eight hidden layers.
        out_dim: Width of the output layer (576 with the default value).
    """

    def __init__(
        self,
        in_dim=25,
        l1=36,
        l2=36 * 2,
        l3=36 * 2,
        l4=36 * 4,
        l5=36 * 4,
        l6=36 * 8,
        l7=36 * 8,
        l8=36 * 16,
        out_dim=36 * 2 * 4 * 2,
    ):
        super().__init__()
        # Attribute names and creation order are kept as-is so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.fc1 = nn.Linear(in_dim, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, l3)
        self.fc4 = nn.Linear(l3, l4)
        self.fc5 = nn.Linear(l4, l5)
        self.fc6 = nn.Linear(l5, l6)
        self.fc7 = nn.Linear(l6, l7)
        self.fc8 = nn.Linear(l7, l8)
        self.out = nn.Linear(l8, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Tensor whose last axis holds at least ``in_dim`` features.
                Any number of leading (batch) axes is accepted, including
                none (a plain 1-D feature vector).

        Returns:
            Tensor with the same leading axes as ``x`` and a last axis of
            size ``out_dim``.  No activation is applied to the output layer.
        """
        # Slice the *feature* axis (last), not the first axis: `x[:25]` would
        # select rows of a batched input instead of columns.  The width comes
        # from fc1 so that a non-default `in_dim` is honoured.
        x = x[..., : self.fc1.in_features]

        hidden_layers = (
            self.fc1,
            self.fc2,
            self.fc3,
            self.fc4,
            self.fc5,
            self.fc6,
            self.fc7,
            self.fc8,
        )
        for layer in hidden_layers:
            x = F.relu(layer(x))

        return self.out(x)