#### ENCODER_Model
import torch
import torch.nn as nn
import torch.nn.functional as F
|
class AutoEncoder_Model(nn.Module):
    """MLP autoencoder coupling a 5x5 density field with a 2x2 displacement field.

    The 33-dim input vector is split into:
      * 25 values -> a density map, reshaped to (B, 1, 5, 5);
      * 4 values  -> displacement component u, reshaped to (B, 1, 2, 2);
      * 4 values  -> displacement component v, reshaped to (B, 1, 2, 2).

    The 2-channel displacement is bilinearly upsampled 2x2 -> 5x5, concatenated
    with the density along the channel axis (B, 3, 5, 5), flattened to 75
    features, encoded up to a 576-dim latent code (returned as ``shape_func``),
    then decoded down to 72 output features.

    NOTE: ``input_features``/``out_features`` are kept for interface
    compatibility, but the layer sizes are currently hard-wired
    (75 -> ... -> 576 -> ... -> 72); 33 = 25 + 4 + 4 is assumed in ``forward``.
    """

    def __init__(self, input_features=33, out_features=72):
        super().__init__()

        # align_corners=False is the effective default for bilinear mode;
        # stating it explicitly silences the PyTorch deprecation warning
        # without changing the interpolation result.
        self.upsample25x = nn.Upsample(scale_factor=2.5, mode='bilinear',
                                       align_corners=False)

        # Encoder: 75 -> 576 latent code.
        self.fc1 = nn.Linear(75, 100)
        self.fc2 = nn.Linear(100, 150)
        self.fc3 = nn.Linear(150, 200)
        self.fc4 = nn.Linear(200, 300)
        self.fc5 = nn.Linear(300, 400)
        self.fc6 = nn.Linear(400, 500)
        self.fc7 = nn.Linear(500, 576)

        # Decoder: 576 -> 72 output features.
        self.fc8 = nn.Linear(576, 500)
        self.fc9 = nn.Linear(500, 400)
        self.fc10 = nn.Linear(400, 300)
        self.fc11 = nn.Linear(300, 200)
        self.fc12 = nn.Linear(200, 150)
        self.fc13 = nn.Linear(150, 100)
        self.fc14 = nn.Linear(100, 72)

    def forward(self, x):
        """Run the autoencoder.

        Args:
            x: tensor of shape (B, 33); see class docstring for the layout.

        Returns:
            (decoded, shape_func): the (B, 72) reconstruction and the
            (B, 576) encoder latent code ("shape function").
        """
        B = x.shape[0]

        # First 25 features -> density grid (B, 1, 5, 5).
        density = x[:, :25].reshape(B, 1, 5, 5)

        # Next 4 -> u, last 4 -> v displacement components (B, 1, 2, 2) each.
        u = x[:, 25:29].reshape(B, 1, 2, 2)
        v = x[:, 29:].reshape(B, 1, 2, 2)

        # Stack u/v into a 2-channel displacement field, then upsample
        # 2x2 -> 5x5 so it aligns with the density grid.
        displace = torch.cat((u, v), 1)        # (B, 2, 2, 2)
        displace = self.upsample25x(displace)  # (B, 2, 5, 5)

        # Couple displacement and density by channel concatenation.
        # (Element-wise multiply and a 3->2 channel conv were tried as
        # alternative coupling schemes and abandoned.)
        x = torch.cat((displace, density), 1)  # (B, 3, 5, 5)
        x = x.reshape(B, 75)

        # NOTE(fix): the original wrapped x in the deprecated
        # torch.autograd.Variable(x, requires_grad=True). Since PyTorch 0.4
        # autograd tracks tensors directly, and that wrapper made x a fresh
        # leaf, cutting the graph to any upstream computation — removed.

        # Encode 75 -> 576.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        shape_func = x.clone()  # 576-dim latent code returned to the caller

        # Decode 576 -> 72.
        x = F.relu(self.fc8(x))
        x = F.relu(self.fc9(x))
        x = F.relu(self.fc10(x))
        x = F.relu(self.fc11(x))
        x = F.relu(self.fc12(x))
        x = F.relu(self.fc13(x))
        x = F.relu(self.fc14(x))

        return x, shape_func