#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

// Forward kernel, variant 2: one thread per batch sample, one block per
// (output channel, x-index) pair on a 41 x 41 x 41 node grid.
// Launched with blocks(18 * 41, 41, 41) and threads(batch_size).
// NOTE: H8types and nodIdx are read through accessors of the same floating
// dtype as U and cast to int on use; if they are stored as integer tensors,
// switch those parameters to integer accessors.
template <typename scalar_t>
__global__ void feconv_cuda_forward_kernel2(
    const torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> U,
    const torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> H8types,
    const torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> nodIdx,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> filters,
    torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> KU) {

  const int outidx = blockIdx.x / 41;  // output channel, 0 - 17
  const int Idxx   = blockIdx.x % 41;  // node index along x, 0 - 40
  //const int Idxx = threadIdx.x % 41;
  //const int Idxy = threadIdx.x / 41;
  // const int Idxx = blockIdx.x;
  const int Idxy = blockIdx.y;
  const int Idxz = blockIdx.z;

  // Element type at this node selects which filter bank to apply.
  const int h8type = H8types[threadIdx.x][0][Idxx][Idxy][Idxz];
  //const auto fkernels = filters[h8type];

  scalar_t convresult = 0.0;
  int direction = outidx % 3;
  for (int ix = 0; ix < 3; ix++) {
    for (int j = 0; j < 27; j++) {
      // Global indices of the j-th neighbour node in the 3x3x3 stencil.
      int uidx1 = nodIdx[Idxx][Idxy][Idxz][j][0];
      int uidx2 = nodIdx[Idxx][Idxy][Idxz][j][1];
      int uidx3 = nodIdx[Idxx][Idxy][Idxz][j][2];
      // Skip neighbours flagged with index -1 (outside the mesh).
      if ((uidx1 + 1) * (uidx2 + 1) * (uidx3 + 1) != 0) {
        convresult += U[threadIdx.x][outidx - direction + ix][uidx1][uidx2][uidx3]
                      * filters[h8type][3 * direction + ix][j];
      }
    }
  }
  KU[threadIdx.x][outidx][Idxx][Idxy][Idxz] = convresult;
}

// Forward kernel, variant 1 (currently unused): one block per node,
// one thread per (output channel, batch sample) pair.
// Intended launch: blocks(41, 41, 41), threads(18, batch_size).
template <typename scalar_t>
__global__ void feconv_cuda_forward_kernel(
    const torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> U,
    const torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> H8types,
    const torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> nodIdx,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> filters,
    torch::PackedTensorAccessor<scalar_t, 5, torch::RestrictPtrTraits, size_t> KU) {

  //const int Idxx = threadIdx.x % 41;
  //const int Idxy = threadIdx.x / 41;
  const int Idxx = blockIdx.x;
  const int Idxy = blockIdx.y;
  const int Idxz = blockIdx.z;

  const int h8type = H8types[threadIdx.y][0][Idxx][Idxy][Idxz];
  const auto fkernels = filters[h8type];

  scalar_t convresult = 0.0;
  int direction = threadIdx.x % 3;
  for (int ix = 0; ix < 3; ix++) {
    for (int j = 0; j < 27; j++) {
      int uidx1 = nodIdx[Idxx][Idxy][Idxz][j][0];
      int uidx2 = nodIdx[Idxx][Idxy][Idxz][j][1];
      int uidx3 = nodIdx[Idxx][Idxy][Idxz][j][2];
      if ((uidx1 + 1) * (uidx2 + 1) * (uidx3 + 1) != 0) {
        convresult += U[threadIdx.y][threadIdx.x - direction + ix][uidx1][uidx2][uidx3]
                      * filters[h8type][3 * direction + ix][j];
        // convresult += U[threadIdx.y][threadIdx.x - direction + ix][uidx1][uidx2][uidx3] * filters[255][3 * direction + ix][j];
        // convresult += filters[h8type][3 * direction + ix][j];
      }
    }
  }
  //KU[blockIdx.x][blockIdx.y][threadIdx.x][threadIdx.y][threadIdx.z] = 1.0;
  KU[threadIdx.y][threadIdx.x][Idxx][Idxy][Idxz] = convresult;

  /*
  const int h8type = H8types[blockIdx.y][blockIdx.x][threadIdx.x][threadIdx.y][threadIdx.z];
  const auto fkernels = filters[h8type];
  scalar_t convresult = 0.0;
  //blockIdx.y % 3 == 0: filters[0:3]
  //blockIdx.y % 3 == 1: filters[3:6]
  //blockIdx.y % 3 == 2: filters[6:9]
  int direction = blockIdx.x % 3;
  for (int ix = 0; ix < 3; ix++) {
    for (int j = 0; j < 27; j++) {
      int uidx1 = nodIdx[threadIdx.x][threadIdx.y][threadIdx.z][j][0];
      int uidx2 = nodIdx[threadIdx.x][threadIdx.y][threadIdx.z][j][1];
      int uidx3 = nodIdx[threadIdx.x][threadIdx.y][threadIdx.z][j][2];
      convresult += 1.0; //U[blockIdx.x][blockIdx.y - direction + ix][uidx1][uidx2][uidx3] * filters[h8type][3 * direction + ix][j];
    }
  }
  // scalar_t convresult = h8type*1.0;
  //KU[blockIdx.x][blockIdx.y][threadIdx.x][threadIdx.y][threadIdx.z] = 1.0;
  KU[blockIdx.y][blockIdx.x][threadIdx.x][threadIdx.y][threadIdx.z] = convresult;
  */
}

std::vector<torch::Tensor>
//torch::Tensor
feconv_cuda_forward(
    torch::Tensor U,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters) {

  const auto batch_size = U.size(0);
  auto KU = torch::zeros_like(U);

  // Launch configuration for feconv_cuda_forward_kernel2.
  // const dim3 blocks(41, 41, 41);
  // const dim3 threads(18, batch_size);
  const dim3 blocks(18 * 41, 41, 41);
  const dim3 threads(batch_size);
  //const dim3 blocks(18, batch_size);
  //const dim3 threads(41, 41, 41);
  // const dim3 blocks(batch_size, 3);
  // const dim3 threads(11, 11, 11);

  AT_DISPATCH_FLOATING_TYPES(U.type(), "feconv_forward_cuda", ([&] {
    feconv_cuda_forward_kernel2<scalar_t><<<blocks, threads>>>(
        U.packed_accessor<scalar_t, 5, torch::RestrictPtrTraits, size_t>(),
        H8types.packed_accessor<scalar_t, 5, torch::RestrictPtrTraits, size_t>(),
        nodIdx.packed_accessor<scalar_t, 5, torch::RestrictPtrTraits, size_t>(),
        filters.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        KU.packed_accessor<scalar_t, 5, torch::RestrictPtrTraits, size_t>());
  }));

  return {KU};
  //return {U};
}

std::vector<torch::Tensor>
// torch::Tensor
feconv_cuda_backward(
    torch::Tensor gradU) {
  // Backward pass is not implemented yet; return a zero gradient of the same shape.
  auto dU = torch::zeros_like(gradU);
  return {dU};
}
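
// ---------------------------------------------------------------------------
// Usage sketch (assumption, not part of this file): a CUDA extension like this
// is normally paired with a small C++ binding file, e.g. "feconv.cpp", that
// validates the inputs and exposes the functions to Python via pybind11. The
// file name and the wrapper names below are hypothetical; only the two
// feconv_cuda_* declarations come from this source.
//
//   #include <torch/extension.h>
//   #include <vector>
//
//   std::vector<torch::Tensor> feconv_cuda_forward(
//       torch::Tensor U, torch::Tensor H8types,
//       torch::Tensor nodIdx, torch::Tensor filters);
//   std::vector<torch::Tensor> feconv_cuda_backward(torch::Tensor gradU);
//
//   #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
//   #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
//
//   std::vector<torch::Tensor> feconv_forward(
//       torch::Tensor U, torch::Tensor H8types,
//       torch::Tensor nodIdx, torch::Tensor filters) {
//     CHECK_CUDA(U);       CHECK_CONTIGUOUS(U);
//     CHECK_CUDA(H8types); CHECK_CONTIGUOUS(H8types);
//     CHECK_CUDA(nodIdx);  CHECK_CONTIGUOUS(nodIdx);
//     CHECK_CUDA(filters); CHECK_CONTIGUOUS(filters);
//     return feconv_cuda_forward(U, H8types, nodIdx, filters);
//   }
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("forward", &feconv_forward, "feconv forward (CUDA)");
//     m.def("backward", &feconv_cuda_backward, "feconv backward (CUDA)");
//   }
// ---------------------------------------------------------------------------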