You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

34 lines
1.0 KiB

# -*- coding: utf-8 -*-
"""
Created on Sat Feb 6 22:22:39 2021
@author: Liangchao Zhu
"""
import numpy as np
import torch
import torch.nn.functional as F
def kernelForType(datatype = np.float32):
    """Return a 2x2x2 array whose cells hold distinct powers of two.

    Cell ``[i, j, k]`` stores ``2 ** (i + 2*j + 4*k)``, so the eight cells
    hold 1, 2, 4, ..., 128 and every cell maps to its own bit.

    Args:
        datatype: NumPy dtype of the returned array (default ``np.float32``).

    Returns:
        numpy.ndarray of shape ``(2, 2, 2)`` and dtype ``datatype``.
    """
    weights = np.zeros((2, 2, 2), dtype=datatype)
    # np.ndindex walks every (i, j, k) triple of the 2x2x2 cube.
    for i, j, k in np.ndindex(2, 2, 2):
        bit = i + 2 * j + 4 * k
        weights[i, j, k] = 2 ** bit
    return weights
def typeH8(rho):
    """Compute a power-of-two weighted 2x2x2 neighbourhood code for ``rho``.

    ``rho`` is cross-correlated with the kernel from ``kernelForType`` (cell
    ``[i, j, k]`` weighted by ``2 ** (i + 2*j + 4*k)``) using ``padding=1``.
    Each output voxel is therefore the weighted sum of the 2x2x2 input cells
    around it. When ``rho`` is binary (0/1), this is an 8-bit code in
    ``[0, 255]`` recording which of those cells are set.

    Args:
        rho: floating-point tensor accepted by ``F.conv3d`` with one input
            channel, presumably shaped ``(N, 1, D, H, W)`` -- TODO confirm
            against callers.

    Returns:
        int32 tensor; for a ``(N, 1, D, H, W)`` input its shape is
        ``(N, 1, D+1, H+1, W+1)`` (kernel size 2, padding 1).
    """
    # conv3d requires weight and input to share a dtype. The kernel used to be
    # float32 for every non-double input, so e.g. float16 inputs raised a dtype
    # mismatch. All kernel entries are small exact powers of two (1..128), so
    # casting them to rho's own dtype is lossless.
    filterKernel = torch.from_numpy(
        kernelForType(np.float64)[np.newaxis, np.newaxis]
    ).to(device=rho.device, dtype=rho.dtype)
    H8Types = F.conv3d(rho, filterKernel, padding=1)
    return H8Types.int()
def typeH8_(rho, filterKernel):
    """Same neighbourhood encoding as ``typeH8``, but with a caller-supplied kernel.

    Presumably lets callers build the 2x2x2 weight tensor once and reuse it
    across calls. ``filterKernel`` must already match ``rho`` in dtype and
    device, as ``F.conv3d`` requires.

    Args:
        rho: input tensor for ``F.conv3d``.
        filterKernel: conv3d weight tensor, presumably ``(1, 1, 2, 2, 2)``.

    Returns:
        The padded (``padding=1``) convolution result cast to int32.
    """
    return F.conv3d(rho, filterKernel, padding=1).int()
if __name__ == "__main__":
    # Small smoke run on an all-ones input.
    # For a larger run, use shape (8, 1, 40, 40, 40) instead.
    rho = torch.ones((8, 1, 4, 4, 4), dtype=torch.float32)
    print(rho.shape)
    H8Types = typeH8(rho)