import torch
import torch.nn as nn
import torch.nn.functional as F


class ANN_Model(nn.Module):
    """Two-branch fully connected network.

    Each input row is split column-wise:

    * the first ``density_features`` columns are passed through unchanged;
    * the remaining ``input_features`` columns go through a small ReLU encoder
      (fc1 -> fc4, ending in 25 units).

    The passthrough columns and the encoder output are concatenated and fed
    through a deeper ReLU trunk (fc5 -> fc11) and a final linear layer ``out``
    (no activation).

    NOTE(review): the names ``density`` / ``displace`` come from the original
    code; what the columns physically represent is not visible here.

    Args:
        input_features: Width of the trailing column group fed to ``fc1``.
        out_features: Width of the network output.
        density_features: Width of the leading passthrough column group.
            Defaults to 25, the value previously hard-coded; at the default,
            all layer names and shapes (and thus ``state_dict``) are unchanged.
    """

    # Width of the encoder output (fc4), concatenated after the passthrough
    # columns before fc5.
    _ENCODER_OUT = 25

    def __init__(self, input_features=8, out_features=72, density_features=25):
        super().__init__()
        # Plain int attribute: not a parameter/buffer, so it does not appear
        # in state_dict and does not affect checkpoint compatibility.
        self.density_features = density_features

        # Encoder for the trailing ``input_features`` columns.
        self.fc1 = nn.Linear(input_features, 12)
        self.fc2 = nn.Linear(12, 16)
        self.fc3 = nn.Linear(16, 20)
        self.fc4 = nn.Linear(20, self._ENCODER_OUT)

        # Trunk: input is [passthrough columns | encoder output]. Deriving the
        # width here keeps it consistent with the split done in forward().
        self.fc5 = nn.Linear(density_features + self._ENCODER_OUT, 60)
        self.fc6 = nn.Linear(60, 70)
        self.fc7 = nn.Linear(70, 80)
        self.fc8 = nn.Linear(80, 90)
        self.fc9 = nn.Linear(90, 100)
        self.fc10 = nn.Linear(100, 90)
        self.fc11 = nn.Linear(90, 80)
        self.out = nn.Linear(80, out_features)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor whose second dimension holds
                ``density_features + input_features`` columns. Assumed 2-D
                ``(batch, columns)`` -- TODO confirm with callers.

        Returns:
            Tensor of shape ``(batch, out_features)`` (raw linear output).
        """
        split = self.density_features
        density = x[:, :split]
        displace = x[:, split:]

        hidden = displace
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = F.relu(layer(hidden))

        # Passthrough columns first, then encoder output (same order as the
        # original torch.hstack call).
        hidden = torch.cat((density, hidden), dim=1)
        for layer in (
            self.fc5,
            self.fc6,
            self.fc7,
            self.fc8,
            self.fc9,
            self.fc10,
            self.fc11,
        ):
            hidden = F.relu(layer(hidden))

        return self.out(hidden)