You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

63 lines
1.6 KiB

11 months ago
#include <torch/extension.h>
#include <vector>
std::vector<torch::Tensor> feconv_cuda_forward(
torch::Tensor U,
torch::Tensor H8types,
torch::Tensor nodIdx,
torch::Tensor filters);
std::vector<torch::Tensor> feconv_cuda_backward(
torch::Tensor gradV,
torch::Tensor H8types,
torch::Tensor nodIdx,
torch::Tensor filters);
// Argument-validation helpers used by the Python-facing wrappers.
//
// - TORCH_CHECK is PyTorch's macro for validating user input. The previously
//   used AT_ASSERTM is deprecated and reports a failure as an internal PyTorch
//   bug rather than as a caller error.
// - Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
// - The argument is parenthesized so that any expression can be passed safely.
//   The message still stringizes the argument as written, e.g.
//   "U must be a CUDA tensor".
// - CHECK_INPUT is wrapped in do { ... } while (0) so it expands to exactly
//   one statement. That keeps it correct inside an unbraced if/else.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)
/// Python-facing forward entry point (registered as "forward" on the module).
///
/// Each argument is validated with CHECK_INPUT (CUDA tensor, contiguous) in
/// turn. Validation order is U, H8types, nodIdx, filters. The tensors are then
/// passed unchanged to feconv_cuda_forward, and its result is returned as-is.
std::vector<torch::Tensor> feconv_forward(
    torch::Tensor U,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters)
{
    CHECK_INPUT(U);
    CHECK_INPUT(H8types);
    CHECK_INPUT(nodIdx);
    CHECK_INPUT(filters);

    // NOTE(review): no device guard is installed here (an earlier
    // `torch::DeviceGuard guard(U.device());` was left commented out). If
    // feconv_cuda_forward does not select U's device itself, tensors on a
    // non-current GPU may be processed on the wrong device — confirm against
    // the CUDA implementation.
    std::vector<torch::Tensor> outputs =
        feconv_cuda_forward(U, H8types, nodIdx, filters);
    return outputs;
}
/// Python-facing backward entry point (registered as "backward" on the module).
///
/// Each argument is validated with CHECK_INPUT (CUDA tensor, contiguous) in
/// turn. Validation order is gradV, H8types, nodIdx, filters. The tensors are
/// then passed unchanged to feconv_cuda_backward, and its result is returned
/// as-is.
std::vector<torch::Tensor> feconv_backward(
    torch::Tensor gradV,
    torch::Tensor H8types,
    torch::Tensor nodIdx,
    torch::Tensor filters)
{
    CHECK_INPUT(gradV);
    CHECK_INPUT(H8types);
    CHECK_INPUT(nodIdx);
    CHECK_INPUT(filters);

    // NOTE(review): like the forward wrapper, no device guard is installed
    // here. If feconv_cuda_backward does not select gradV's device itself,
    // tensors on a non-current GPU may be processed on the wrong device —
    // confirm against the CUDA implementation.
    std::vector<torch::Tensor> grads =
        feconv_cuda_backward(gradV, H8types, nodIdx, filters);
    return grads;
}
// Python bindings. The extension module's name comes from the
// TORCH_EXTENSION_NAME macro, which is not defined in this file (it is
// expected to be supplied at build time). The module exposes the validating
// wrappers as `forward` and `backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    // module_::def returns the module itself, so the registrations chain.
    mod.def("forward", &feconv_forward, "FECONV forward (CUDA)")
       .def("backward", &feconv_backward, "FECONV backward (CUDA)");
}