38 lines
990 B
38 lines
990 B
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class ANN_b_spline_Model(nn.Module):
    """Fully connected network mapping ``in_dim`` input features to ``out_dim`` outputs.

    Architecture: eight hidden ``nn.Linear`` layers, each followed by ReLU,
    then a linear output layer with no activation. With the default widths
    the layer sizes are 33 -> 36 -> 72 -> 108 -> 144 -> 180 -> 216 -> 180
    -> 144 -> 72.

    The submodule attribute names (``fc1`` ... ``fc8``, ``out``) and the
    order in which they are created are kept fixed. This keeps existing
    ``state_dict`` checkpoints loadable and seeded initialisation
    reproducible.

    Args:
        in_dim: Number of input features. The default ``25 + 8`` presumably
            means 25 density values followed by 8 displacement values
            (TODO confirm against the data pipeline).
        l1, l2, l3, l4, l5, l6, l7, l8: Widths of the eight hidden layers.
        out_dim: Number of outputs. The default ``36 * 2`` equals 6*6*2,
            presumably 2 coordinates for a 6x6 grid of B-spline control
            points (TODO confirm).
    """

    def __init__(
        self,
        in_dim: int = 25 + 8,
        l1: int = 36,
        l2: int = 36 * 2,
        l3: int = 36 * 3,
        l4: int = 36 * 4,
        l5: int = 36 * 5,
        l6: int = 36 * 6,
        l7: int = 36 * 5,
        l8: int = 36 * 4,
        out_dim: int = 36 * 2,
    ) -> None:
        super().__init__()
        # Explicit attributes (rather than a ModuleList) keep the state_dict
        # keys "fc1.weight", ..., "out.bias" unchanged.
        self.fc1 = nn.Linear(in_dim, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, l3)
        self.fc4 = nn.Linear(l3, l4)
        self.fc5 = nn.Linear(l4, l5)
        self.fc6 = nn.Linear(l5, l6)
        self.fc7 = nn.Linear(l6, l7)
        self.fc8 = nn.Linear(l7, l8)
        self.out = nn.Linear(l8, out_dim)  # -> 6*6*2 with the default out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_dim``. Any leading
                (batch) dimensions are preserved, because ``nn.Linear`` acts
                on the last axis.

        Returns:
            Tensor of shape ``(*leading_dims, out_dim)`` holding the raw
            output-layer values (no activation applied).
        """
        # NOTE: if the input must be split into its parts (presumably 25
        # density + 8 displacement features; TODO confirm), slice the LAST
        # axis: x[..., :25] and x[..., 25:]. Slicing x[:25] would cut the
        # batch axis of a batched input instead.
        hidden_layers = (
            self.fc1, self.fc2, self.fc3, self.fc4,
            self.fc5, self.fc6, self.fc7, self.fc8,
        )
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return self.out(x)
|