#### ENCODER_Model
import torch
import torch.nn as nn
import torch.nn.functional as F


class AutoEncoder_Model(nn.Module):
    """Fully connected auto-encoder over a fused 5x5 three-channel grid.

    The forward pass expects a flat 33-column input per sample and splits it
    into three named pieces (names taken from the code):

    * columns 0-24  -> ``density``, reshaped to a 1-channel 5x5 grid
    * columns 25-28 -> ``u``, reshaped to a 1-channel 2x2 grid
    * columns 29-32 -> ``v``, reshaped to a 1-channel 2x2 grid

    ``u`` and ``v`` are stacked, bilinearly upsampled from 2x2 to 5x5,
    concatenated with ``density`` and flattened to 75 values, which are then
    pushed through a 7-layer encoder (75 -> 576) and a 7-layer decoder
    (576 -> ``out_features``), each layer followed by ReLU.

    Args:
        input_features: Kept for interface compatibility. It is not used:
            the slicing in ``forward`` assumes the fixed 33-column layout
            described above.
        out_features: Width of the final decoder layer (default 72).
    """

    def __init__(self, input_features=33, out_features=72):
        super().__init__()
        # Scale 2.5 turns the 2x2 grids into 5x5 (floor(2 * 2.5) == 5).
        self.upsample25x = nn.Upsample(scale_factor=2.5, mode='bilinear')

        # Encoder: 75 -> 576.  Attribute names fc1..fc14 are kept as-is so
        # that existing state_dict checkpoints continue to load.
        self.fc1 = nn.Linear(75, 100)
        self.fc2 = nn.Linear(100, 150)
        self.fc3 = nn.Linear(150, 200)
        self.fc4 = nn.Linear(200, 300)
        self.fc5 = nn.Linear(300, 400)
        self.fc6 = nn.Linear(400, 500)
        self.fc7 = nn.Linear(500, 576)

        # Decoder: 576 -> out_features.
        self.fc8 = nn.Linear(576, 500)
        self.fc9 = nn.Linear(500, 400)
        self.fc10 = nn.Linear(400, 300)
        self.fc11 = nn.Linear(300, 200)
        self.fc12 = nn.Linear(200, 150)
        self.fc13 = nn.Linear(150, 100)
        self.fc14 = nn.Linear(100, out_features)

    def forward(self, x):
        """Encode and decode a batch of flat 33-column samples.

        Args:
            x: Tensor of shape ``(B, 33)`` laid out as described on the class.

        Returns:
            A tuple ``(decoded, shape_func)`` where ``decoded`` has shape
            ``(B, out_features)`` and ``shape_func`` is a copy of the 576-wide
            encoder output, shape ``(B, 576)``.
        """
        B = x.shape[0]

        density = x[:, :25].reshape(B, 1, 5, 5)   # B 1(C) 5 5
        u = x[:, 25:29].reshape(B, 1, 2, 2)       # B 1(C) 2 2
        v = x[:, 29:].reshape(B, 1, 2, 2)         # B 1(C) 2 2

        displace = torch.cat((u, v), 1)           # B 2(C) 2 2
        displace = self.upsample25x(displace)     # upsample -> B 2 5 5

        # Coupling of `displace` and `density`.  Two alternatives were tried
        # previously and left disabled: (1) element-wise multiplication of
        # each displace channel with density, giving B x 50 features, and
        # (2) a 3x3 Conv2d mapping 3 -> 2 channels.  The active choice is
        # (3): plain channel-wise concatenation.
        x = torch.cat((displace, density), 1)     # B 3 5 5
        x = x.reshape(B, 75)

        # NOTE: the former `torch.autograd.Variable(x, requires_grad=True)`
        # wrapper was removed.  Variable is deprecated, and re-wrapping
        # detached `x` from the graph so gradients could never reach the
        # caller's input tensor.  Parameter gradients are unaffected.

        # Encode
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4,
                      self.fc5, self.fc6, self.fc7):
            x = F.relu(layer(x))
        shape_func = x.clone()

        # Decode (ReLU is applied to the final layer as well).
        for layer in (self.fc8, self.fc9, self.fc10, self.fc11,
                      self.fc12, self.fc13, self.fc14):
            x = F.relu(layer(x))

        return x, shape_func