#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// NOTE: the accessor template arguments (dtype, rank, pointer traits) below were
// reconstructed from the indexing patterns in the kernels. U / filters / KU carry
// the dispatched floating dtype; H8types and nodIdx are assumed to be int32 index
// tensors -- switch the accessor dtype to int64_t if they are LongTensors.

template <typename scalar_t>
__global__ void feconv_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> U,
    const torch::PackedTensorAccessor32<int, 5, torch::RestrictPtrTraits> H8types,
    const torch::PackedTensorAccessor32<int, 5, torch::RestrictPtrTraits> nodIdx,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> filters,
    torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> KU) {
  // Grid mapping: blockIdx.x packs the 18 output channels with the first spatial
  // axis of the 41x41x41 node grid; threadIdx.x is the batch index.
  const int outidx = blockIdx.x / 41;  // output channel, 0 - 17
  //const int Idxx = threadIdx.x % 41;
  //const int Idxy = threadIdx.x / 41;
  const int Idxx = blockIdx.x % 41;    // node index along x, 0 - 40
  const int Idxy = blockIdx.y;
  const int Idxz = blockIdx.z;

  // The element type at this node selects one of the precomputed filter banks.
  const int h8type = H8types[threadIdx.x][0][Idxx][Idxy][Idxz];
  const auto fkernels = filters[h8type];

  scalar_t convresult = 0.0;
  // The 18 channels are grouped into triples of 3 spatial components.
  int direction = outidx % 3;
  // Accumulate over the 27-node (3x3x3) neighbourhood of the current node.
  for (int j = 0; j < 27; j++) {
    int uidx1 = nodIdx[Idxx][Idxy][Idxz][j][0];
    int uidx2 = nodIdx[Idxx][Idxy][Idxz][j][1];
    int uidx3 = nodIdx[Idxx][Idxy][Idxz][j][2];
    // An index of -1 marks a neighbour outside the domain; skip it.
    if ((uidx1 + 1) * (uidx2 + 1) * (uidx3 + 1) != 0) {
      for (int ix = 0; ix < 3; ix++) {
        convresult += U[threadIdx.x][outidx - direction + ix][uidx1][uidx2][uidx3]
                      * fkernels[direction][ix][j];
      }
    }
  }
  KU[threadIdx.x][outidx][Idxx][Idxy][Idxz] = convresult;
}

template <typename scalar_t>
__global__ void feconv_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> gradV,
    const torch::PackedTensorAccessor32<int, 5, torch::RestrictPtrTraits> H8types,
    const torch::PackedTensorAccessor32<int, 5, torch::RestrictPtrTraits> nodIdx,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> filters,
    torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> d_U) {
  const int outidx = blockIdx.x / 41;  // output channel, 0 - 17
  const int Idxx = blockIdx.x % 41;    // node index along x, 0 - 40
  const int Idxy = blockIdx.y;
  const int Idxz = blockIdx.z;

  const int h8type = H8types[threadIdx.x][0][Idxx][Idxy][Idxz];
  const auto fkernels = filters[h8type];

  scalar_t convresult = 0.0;
  int direction = outidx % 3;
  for (int j = 0; j < 27; j++) {
    int uidx1 = nodIdx[Idxx][Idxy][Idxz][j][0];
    int uidx2 = nodIdx[Idxx][Idxy][Idxz][j][1];
    int uidx3 = nodIdx[Idxx][Idxy][Idxz][j][2];
    if ((uidx1 + 1) * (uidx2 + 1) * (uidx3 + 1) != 0) {
      for (int ix = 0; ix < 3; ix++) {
        convresult += gradV[threadIdx.x][outidx - direction + ix][uidx1][uidx2][uidx3]
                      * fkernels[direction][ix][j];
      }
    }
  }
  d_U[threadIdx.x][outidx][Idxx][Idxy][Idxz] = convresult;
}

std::vector<torch::Tensor> feconv_cuda_forward(
    torch::Tensor U,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters) {
  const auto batch_size = U.size(0);
  auto KU = torch::zeros_like(U);

  // One block per (output channel, x, y, z) node; one thread per batch sample.
  // const dim3 blocks(41, 41, 41);
  // const dim3 threads(18, batch_size);
  const dim3 blocks(18 * 41, 41, 41);
  const dim3 threads(batch_size);
  //const dim3 blocks(18, batch_size);
  //const dim3 threads(41, 41, 41);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(U.scalar_type(), "feconv_forward_cuda", ([&] {
    feconv_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        U.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
        H8types.packed_accessor32<int, 5, torch::RestrictPtrTraits>(),
        nodIdx.packed_accessor32<int, 5, torch::RestrictPtrTraits>(),
        filters.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        KU.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>());
  }));

  return {KU};
}

std::vector<torch::Tensor> feconv_cuda_backward(
    torch::Tensor gradV,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters) {
  const auto batch_size = gradV.size(0);
  auto d_U = torch::zeros_like(gradV);

  const dim3 blocks(18 * 41, 41, 41);
  const dim3 threads(batch_size);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradV.scalar_type(), "feconv_backward_cuda", ([&] {
    feconv_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        gradV.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
        H8types.packed_accessor32<int, 5, torch::RestrictPtrTraits>(),
        nodIdx.packed_accessor32<int, 5, torch::RestrictPtrTraits>(),
        filters.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        d_U.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>());
  }));

  return {d_U};
}
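
// ---------------------------------------------------------------------------
// Hedged sketch, not part of the original kernel code: a minimal way to expose
// the two entry points above to Python when the extension is built from this
// single source file. In the usual PyTorch extension layout these wrappers and
// the PYBIND11_MODULE block live in a separate .cpp file instead. The wrapper
// names feconv_forward / feconv_backward and the input checks are assumptions,
// not taken from the original.
// ---------------------------------------------------------------------------

#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> feconv_forward(
    torch::Tensor U,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters) {
  CHECK_INPUT(U);
  CHECK_INPUT(H8types);
  CHECK_INPUT(nodIdx);
  CHECK_INPUT(filters);
  return feconv_cuda_forward(U, H8types, nodIdx, filters);
}

std::vector<torch::Tensor> feconv_backward(
    torch::Tensor gradV,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters) {
  CHECK_INPUT(gradV);
  CHECK_INPUT(H8types);
  CHECK_INPUT(nodIdx);
  CHECK_INPUT(filters);
  return feconv_cuda_backward(gradV, H8types, nodIdx, filters);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &feconv_forward, "feconv forward (CUDA)");
  m.def("backward", &feconv_backward, "feconv backward (CUDA)");
}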