import torch
import torch.nn as nn
import torch.nn.functional as F


class ANN_Model(nn.Module):
    def __init__(self, input_features=8, out_features=72):
        super().__init__()
        # Branch that processes the displacement part of the input
        self.fc1 = nn.Linear(input_features, 12)
        self.fc2 = nn.Linear(12, 16)
        self.fc3 = nn.Linear(16, 20)
        self.fc4 = nn.Linear(20, 25)

        # Trunk that processes the concatenation of the 25 density
        # features and the 25-dimensional branch output (50 in total)
        self.fc5 = nn.Linear(50, 60)
        self.fc6 = nn.Linear(60, 70)
        self.fc7 = nn.Linear(70, 80)
        self.fc8 = nn.Linear(80, 90)
        self.fc9 = nn.Linear(90, 100)
        self.fc10 = nn.Linear(100, 90)
        self.fc11 = nn.Linear(90, 80)

        self.out = nn.Linear(80, out_features)

    def forward(self, x):
        # Split the input: the first 25 columns are density values,
        # the remaining input_features columns are displacements
        density = x[:, :25].reshape(x.shape[0], 25)
        displace = x[:, 25:]

        # Pass the displacement features through the small branch
        x = F.relu(self.fc1(displace))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))

        # Concatenate the raw density features with the branch output
        x = torch.hstack((density, x))

        # Deep fully connected trunk
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        x = F.relu(self.fc8(x))
        x = F.relu(self.fc9(x))
        x = F.relu(self.fc10(x))
        x = F.relu(self.fc11(x))

        # Linear output layer (no activation)
        x = self.out(x)

        return x
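

# Minimal usage sketch (not part of the original file): it assumes an input
# batch laid out as 25 density columns followed by `input_features`
# displacement columns, which matches the slicing in forward(). The batch
# size and random data below are purely illustrative.
if __name__ == "__main__":
    model = ANN_Model(input_features=8, out_features=72)
    batch = torch.randn(4, 25 + 8)   # 4 samples: 25 density + 8 displacement features
    preds = model(batch)             # expected shape: (4, 72)
    print(preds.shape)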