#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Forward kernel: one thread per batch sample, one block per
// (output channel, x) pair and per (y, z) grid position. Each output
// entry is the input value scaled by the filter weight selected by the
// element (H8) type at that grid node.
// Assumptions in the reconstruction: H8types is an int32 tensor and
// filters is 2-D [n_types][18], matching the active indexing below.
template <typename scalar_t>
__global__ void feconvR_cuda_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> U,
    const torch::PackedTensorAccessor32<int,5,torch::RestrictPtrTraits> H8types,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> filters,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> KU)
{
    const int outidx = blockIdx.x / 41;  // output channel, 0 - 17
    //const int Idxx = threadIdx.x % 41;
    //const int Idxy = threadIdx.x / 41;
    const int Idxx = blockIdx.x % 41;    // grid x index, 0 - 40
    const int Idxy = blockIdx.y;
    const int Idxz = blockIdx.z;
    const int h8type = H8types[threadIdx.x][0][Idxx][Idxy][Idxz];

    // Original 27-point stencil convolution, kept for reference:
    // const auto fkernels = filters[h8type];
    // scalar_t convresult = 0.0;
    // int direction = outidx % 3;
    // for (int j = 0; j < 27; j++)
    // {
    //     int uidx1 = nodIdx[Idxx][Idxy][Idxz][j][0];
    //     int uidx2 = nodIdx[Idxx][Idxy][Idxz][j][1];
    //     int uidx3 = nodIdx[Idxx][Idxy][Idxz][j][2];
    //     if ((uidx1+1)*(uidx2+1)*(uidx3+1) != 0)
    //     {
    //         for (int ix = 0; ix < 3; ix++)
    //         {
    //             convresult += U[threadIdx.x][outidx - direction + ix][uidx1][uidx2][uidx3] * fkernels[direction][ix][j];
    //         }
    //     }
    // }
    // KU[threadIdx.x][outidx][Idxx][Idxy][Idxz] = convresult;

    KU[threadIdx.x][outidx][Idxx][Idxy][Idxz] =
        U[threadIdx.x][outidx][Idxx][Idxy][Idxz] * filters[h8type][outidx];
}

// Backward kernel: the forward pass is an element-wise scaling, so the
// gradient w.r.t. U is the incoming gradient scaled by the same weight.
template <typename scalar_t>
__global__ void feconvR_cuda_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> gradV,
    const torch::PackedTensorAccessor32<int,5,torch::RestrictPtrTraits> H8types,
    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> filters,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> d_U)
{
    const int outidx = blockIdx.x / 41;  // output channel, 0 - 17
    const int Idxx = blockIdx.x % 41;    // grid x index, 0 - 40
    const int Idxy = blockIdx.y;
    const int Idxz = blockIdx.z;
    const int h8type = H8types[threadIdx.x][0][Idxx][Idxy][Idxz];

    // Original 27-point stencil backward pass, kept for reference:
    // const auto fkernels = filters[h8type];
    // scalar_t convresult = 0.0;
    // int direction = outidx % 3;
    // for (int j = 0; j < 27; j++)
    // {
    //     int uidx1 = nodIdx[Idxx][Idxy][Idxz][j][0];
    //     int uidx2 = nodIdx[Idxx][Idxy][Idxz][j][1];
    //     int uidx3 = nodIdx[Idxx][Idxy][Idxz][j][2];
    //     if ((uidx1+1)*(uidx2+1)*(uidx3+1) != 0)
    //     {
    //         for (int ix = 0; ix < 3; ix++)
    //         {
    //             convresult += gradV[threadIdx.x][outidx - direction + ix][uidx1][uidx2][uidx3] * fkernels[direction][ix][j];
    //         }
    //     }
    // }
    // d_U[threadIdx.x][outidx][Idxx][Idxy][Idxz] = convresult;

    d_U[threadIdx.x][outidx][Idxx][Idxy][Idxz] =
        gradV[threadIdx.x][outidx][Idxx][Idxy][Idxz] * filters[h8type][outidx];
}

std::vector<torch::Tensor> feconvR_cuda_forward(
    torch::Tensor U,
    torch::Tensor H8types,
    torch::Tensor filters)
{
    const auto batch_size = U.size(0);
    auto KU = torch::zeros_like(U);

    // One block per (channel, x, y, z) position, one thread per batch sample.
    // const dim3 blocks(41, 41, 41);
    // const dim3 threads(18, batch_size);
    const dim3 blocks(18 * 41, 41, 41);
    const dim3 threads(batch_size);
    //const dim3 blocks(18, batch_size);
    //const dim3 threads(41, 41, 41);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(U.scalar_type(), "feconvR_forward_cuda", ([&] {
        feconvR_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
            U.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
            H8types.packed_accessor32<int,5,torch::RestrictPtrTraits>(),
            filters.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
            KU.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
    }));

    return {KU};
}

std::vector<torch::Tensor> feconvR_cuda_backward(
    torch::Tensor gradV,
    torch::Tensor H8types,
    torch::Tensor filters)
{
    const auto batch_size = gradV.size(0);
    auto d_U = torch::zeros_like(gradV);

    const dim3 blocks(18 * 41, 41, 41);
    const dim3 threads(batch_size);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradV.scalar_type(), "feconvR_backward_cuda", ([&] {
        feconvR_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
            gradV.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
            H8types.packed_accessor32<int,5,torch::RestrictPtrTraits>(),
            filters.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
            d_U.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
    }));

    return {d_U};
}
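// ---------------------------------------------------------------------------
// Companion binding sketch (not part of this file's original source).
// A PyTorch CUDA extension like this is normally paired with a host-side
// .cpp that validates inputs and registers the functions with pybind11.
// The sketch below shows a minimal version of that file, assuming the
// hypothetical names feconvR_cuda.cpp, feconvR_forward, feconvR_backward
// and the CHECK_* macros; adapt them to the actual project layout.
// ---------------------------------------------------------------------------
//
// // feconvR_cuda.cpp -- hypothetical companion binding file
// #include <torch/extension.h>
// #include <vector>
//
// // Declarations of the CUDA launchers defined in this .cu file.
// std::vector<torch::Tensor> feconvR_cuda_forward(
//     torch::Tensor U, torch::Tensor H8types, torch::Tensor filters);
// std::vector<torch::Tensor> feconvR_cuda_backward(
//     torch::Tensor gradV, torch::Tensor H8types, torch::Tensor filters);
//
// #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
// #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
//
// std::vector<torch::Tensor> feconvR_forward(
//     torch::Tensor U, torch::Tensor H8types, torch::Tensor filters) {
//   CHECK_INPUT(U); CHECK_INPUT(H8types); CHECK_INPUT(filters);
//   return feconvR_cuda_forward(U, H8types, filters);
// }
//
// std::vector<torch::Tensor> feconvR_backward(
//     torch::Tensor gradV, torch::Tensor H8types, torch::Tensor filters) {
//   CHECK_INPUT(gradV); CHECK_INPUT(H8types); CHECK_INPUT(filters);
//   return feconvR_cuda_backward(gradV, H8types, filters);
// }
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("forward", &feconvR_forward, "feconvR forward (CUDA)");
//   m.def("backward", &feconvR_backward, "feconvR backward (CUDA)");
// }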