You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
2.1 KiB
63 lines
2.1 KiB
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class CNN_Model(nn.Module):
    """Maps a 33-feature vector (25 density values + 8 displacement values)
    to an ``out_features``-dimensional output through a stack of fully
    connected layers.

    The 25 density inputs are viewed as a 5x5 grid; the 8 displacement
    inputs are viewed as a 2-channel 2x2 grid, bilinearly upsampled to
    5x5, and gated (element-wise multiplied) by the density grid before
    being flattened into the MLP.

    Args:
        input_features: Currently unused — retained for interface
            compatibility. NOTE(review): the forward pass hard-codes the
            25 + 8 input split; confirm intended meaning before wiring
            this parameter in.
        out_features: Size of the output vector (default 72). Previously
            this argument was silently ignored and the last layer always
            produced 72 outputs; it is now honored.
    """

    def __init__(self, input_features=8, out_features=72):
        super().__init__()

        # 2x2 -> 5x5 (scale_factor 2.5, bilinear interpolation).
        self.upsample25x = nn.Upsample(scale_factor=2.5, mode='bilinear')

        # NOTE(review): conv55 is never used in forward(). It is kept
        # registered so existing checkpoints (state_dict keys) still load.
        self.conv55 = nn.Conv2d(3, 2, 3, padding=1)  # B 3 5 5 -> B 2 5 5

        # MLP head: 50 = 2 channels * 5 * 5 flattened gated displacements.
        self.fc1 = nn.Linear(50, 70)
        self.fc2 = nn.Linear(70, 90)
        self.fc3 = nn.Linear(90, 110)
        self.fc4 = nn.Linear(110, 130)
        self.fc5 = nn.Linear(130, 100)
        self.fc6 = nn.Linear(100, 80)
        # Bug fix: out_features was previously ignored (hard-coded 72).
        self.fc7 = nn.Linear(80, out_features)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Tensor of shape (B, 33) — first 25 entries are density
               values, the remaining 8 are displacement values.

        Returns:
            Tensor of shape (B, out_features), non-negative (final ReLU).
        """
        B = x.shape[0]

        # Density: first 25 features as a single-channel 5x5 grid.
        density = x[:, :25]
        density = density.reshape(B, 1, 5, 5)  # B 1 5 5

        # Displacement: last 8 features as a 2-channel 2x2 grid.
        displace = x[:, 25:]
        displace = displace.reshape(B, 2, 2, 2)   # B 2 2 2(C)
        # Reorder tensor dimensions to B C W H.
        displace = displace.permute(0, 3, 1, 2)
        # Upsample -> B 2 5 5.
        displace = self.upsample25x(displace)

        # Gate each displacement channel by the density grid, then flatten.
        u = torch.mul(displace[:, 0, :, :], density[:, 0, :, :]).reshape(B, 25)
        v = torch.mul(displace[:, 1, :, :], density[:, 0, :, :]).reshape(B, 25)
        x = torch.cat((u, v), 1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x