#include <torch/extension.h>

#include <vector>

// Declarations of the CUDA implementations (defined in the .cu file).
std::vector<torch::Tensor> feconv_cuda_forward(
    torch::Tensor U,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters);

std::vector<torch::Tensor> feconv_cuda_backward(
    torch::Tensor gradV,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters);

/*
// NOTE: AT_ASSERTM has become TORCH_CHECK on master after 0.4.
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
*/
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Validate the inputs, then dispatch to the CUDA forward kernel.
std::vector<torch::Tensor> feconv_forward(
    torch::Tensor U,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters) {
  CHECK_INPUT(U);
  CHECK_INPUT(H8types);
  CHECK_INPUT(nodIdx);
  CHECK_INPUT(filters);
  //torch::DeviceGuard guard(U.device());
  return feconv_cuda_forward(U, H8types, nodIdx, filters);
}

// Validate the inputs, then dispatch to the CUDA backward kernel.
std::vector<torch::Tensor> feconv_backward(
    torch::Tensor gradV,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters) {
  CHECK_INPUT(gradV);
  CHECK_INPUT(H8types);
  CHECK_INPUT(nodIdx);
  CHECK_INPUT(filters);
  return feconv_cuda_backward(gradV, H8types, nodIdx, filters);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &feconv_forward, "FECONV forward (CUDA)");
  m.def("backward", &feconv_backward, "FECONV backward (CUDA)");
}
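/*
Usage sketch (not part of the original source): one way these bindings might be
consumed from Python. The file names ("feconv_cuda.cpp", "feconv_cuda_kernel.cu"),
the layout of the returned tensor lists, and the FEConvFunction wrapper are all
assumptions for illustration, not confirmed by this file.

    import torch
    from torch.utils.cpp_extension import load

    # JIT-compile the bindings above together with the CUDA kernel file.
    feconv = load(name="feconv",
                  sources=["feconv_cuda.cpp", "feconv_cuda_kernel.cu"])

    class FEConvFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, U, H8types, nodIdx, filters):
            # CHECK_INPUT on the C++ side requires contiguous CUDA tensors.
            outputs = feconv.forward(U, H8types, nodIdx, filters)
            ctx.save_for_backward(H8types, nodIdx, filters)
            return outputs[0]  # assumed: first element is the output field V

        @staticmethod
        def backward(ctx, gradV):
            H8types, nodIdx, filters = ctx.saved_tensors
            grads = feconv.backward(gradV.contiguous(), H8types, nodIdx, filters)
            # Assumed gradient layout; here only U receives a gradient.
            return grads[0], None, None, None
*/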