mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Close https://github.com/pytorch/pytorch/issues/57543
Doc: check `Relocatable device code linking:` in https://docs-preview.pytorch.org/78225/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78225
Approved by: https://github.com/ezyang, https://github.com/malfet
22 lines
748 B
C++
#include <torch/extension.h>

// Declared in cuda_dlink_extension.cu; the definition is resolved at device link time.
void add_cuda(const float* a, const float* b, float* output, int size);

at::Tensor add(at::Tensor a, at::Tensor b) {
  TORCH_CHECK(a.device().is_cuda(), "a must be a CUDA tensor");
  TORCH_CHECK(b.device().is_cuda(), "b must be a CUDA tensor");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same size");
  // The kernel indexes raw pointers, so both inputs must be densely laid out.
  TORCH_CHECK(a.is_contiguous() && b.is_contiguous(), "a and b must be contiguous");

  at::Tensor output = at::empty_like(a);
  add_cuda(a.data_ptr<float>(), b.data_ptr<float>(), output.data_ptr<float>(), a.numel());

  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add", &add, "a + b");
}
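The `add_cuda` declaration above is defined in a sibling CUDA file that is not shown on this page. Below is a minimal sketch of what that device-side code could look like; the two-file split, the kernel name, and the launch configuration are illustrative assumptions, not copied from the repository. Calling a `__device__` function defined in a different translation unit is precisely the case that requires relocatable device code linking:

#include <cuda_runtime.h>

// --- hypothetical cuda_dlink_extension_add.cu (separate translation unit) ---
// A device function defined in one .cu file...
__device__ float add_value(float a, float b) {
  return a + b;
}

// --- hypothetical cuda_dlink_extension.cu ---
// ...and referenced from another. Crossing translation units with a device
// symbol is what requires -rdc=true plus an extra -dlink step.
__device__ float add_value(float a, float b);

__global__ void add_kernel(const float* a, const float* b, float* output, int size) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    output[index] = add_value(a[index], b[index]);
  }
}

// Host launcher with the signature declared in cuda_dlink_extension.cpp.
void add_cuda(const float* a, const float* b, float* output, int size) {
  const int threads = 256;  // illustrative launch configuration
  const int blocks = (size + threads - 1) / threads;
  add_kernel<<<blocks, threads>>>(a, b, output, size);
}

As the documentation linked in the commit message explains, objects that reference device symbols across compilation units must be compiled with nvcc's -rdc=true and joined by a device-link (-dlink) step before the host link; torch.utils.cpp_extension.CUDAExtension exposes this through its dlink option.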