import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_Model(nn.Module):
    """Map a 33-value input vector to ``out_features`` non-negative outputs.

    Input layout, as sliced in ``forward``:
      * columns ``[0, 25)`` are reshaped to a 5x5 grid (called ``density``);
      * columns ``[25, 33)`` are reshaped to ``(2, 2, 2)`` with the last axis
        treated as channels (called ``displace``). They are then bilinearly
        upsampled by 2.5 to two 5x5 channels.

    Each displacement channel is multiplied element-wise by the density grid.
    The two 25-value products are concatenated into 50 features, which pass
    through a 7-layer MLP (50-70-90-110-130-100-80-out_features) with ReLU
    after every layer.

    Args:
        input_features: Accepted for backward compatibility but not used; the
            displacement slice is hard-coded to 8 values (2 x 2 x 2).
        out_features: Output width of the final linear layer (default 72).
    """

    def __init__(self, input_features=8, out_features=72):
        super().__init__()
        # align_corners=False is the value PyTorch substitutes when it is left
        # as None; stating it explicitly keeps results identical and avoids
        # the default-behaviour warning emitted by some PyTorch versions.
        self.upsample25x = nn.Upsample(
            scale_factor=2.5, mode='bilinear', align_corners=False
        )
        # NOTE(review): conv55 is never called in forward(). It is kept so
        # existing checkpoints (state_dict keys) still load and so parameter
        # initialisation order under a fixed seed is unchanged.
        self.conv55 = nn.Conv2d(3, 2, 3, padding=1)  # B 3 5 5 -> B 2 5 5
        self.fc1 = nn.Linear(50, 70)
        self.fc2 = nn.Linear(70, 90)
        self.fc3 = nn.Linear(90, 110)
        self.fc4 = nn.Linear(110, 130)
        self.fc5 = nn.Linear(130, 100)
        self.fc6 = nn.Linear(100, 80)
        # Previously hard-coded to 72, which silently ignored out_features.
        self.fc7 = nn.Linear(80, out_features)

    def forward(self, x):
        """Run the model on a batch.

        Args:
            x: Tensor of shape ``(B, 33)``: 25 density values followed by
                8 displacement values.

        Returns:
            Tensor of shape ``(B, out_features)``. Every entry is >= 0
            because ReLU is applied after the last linear layer.
        """
        B = x.shape[0]
        density = x[:, :25].reshape(B, 1, 5, 5)  # B 1 5 5

        # B 2 2 2 with the channel axis last -> reorder to B C 2 2, then
        # upsample the two spatial axes by 2.5 (2x2 -> 5x5): B 2 5 5.
        displace = x[:, 25:].reshape(B, 2, 2, 2).permute(0, 3, 1, 2)
        displace = self.upsample25x(displace)

        # Weight each displacement channel by the density grid and flatten.
        u = (displace[:, 0, :, :] * density[:, 0, :, :]).reshape(B, 25)
        v = (displace[:, 1, :, :] * density[:, 0, :, :]).reshape(B, 25)
        x = torch.cat((u, v), 1)  # B 50

        # NOTE(review): ReLU is also applied after fc7, so predictions can
        # never be negative. Confirm this is intended for the targets.
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4,
                      self.fc5, self.fc6, self.fc7):
            x = F.relu(layer(x))
        return x