Support torch.dtype as parameter in pybind11 cpp extension. (#126865)

Support torch.dtype as parameter in pybind11 cpp extension.
Example:
`
cpp_extension.my_ops(self, other, torch.dtype)
`

@ezyang @bdhirsh
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126865
Approved by: https://github.com/ezyang
This commit is contained in:
Shan19900305
2024-05-29 23:19:30 +00:00
committed by PyTorch MergeBot
parent 8ea1dc8748
commit 7931eee5c5
5 changed files with 46 additions and 12 deletions

View File

@ -2,6 +2,7 @@
// test include_dirs in setuptools.setup with relative path
#include <tmp.h>
#include <ATen/OpMathType.h>
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
return x.sigmoid() + y.sigmoid();
@ -31,6 +32,10 @@ torch::Tensor random_tensor() {
return torch::randn({1});
}
at::ScalarType get_math_type(at::ScalarType other) {
return at::toOpMathType(other);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
m.def(
@ -52,4 +57,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_symint", []() { return c10::SymInt(1); });
m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); });
m.def("get_tensor", []() { return random_tensor(); });
m.def("get_math_type", &get_math_type);
}

View File

@ -55,6 +55,10 @@ class TestCppExtensionAOT(common.TestCase):
y = torch.randn(4, 4)
z = cpp_extension.sigmoid_add(x, y)
self.assertEqual(z, x.sigmoid() + y.sigmoid())
# test pybind support torch.dtype cast.
self.assertEqual(
str(torch.float32), str(cpp_extension.get_math_type(torch.half))
)
def test_extension_module(self):
mm = cpp_extension.MatrixMultiplier(4, 8)

View File

@ -10,6 +10,7 @@
#include <pybind11/stl.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/MemoryFormat.h>
@ -189,6 +190,35 @@ struct type_caster<at::Device> {
}
};
template <>
struct type_caster<at::ScalarType> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
// PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
// cannot be default-initialized, we provide this constructor to explicitly
// initialize that field. The value doesn't matter as it will be overwritten
// after a successful call to load.
type_caster() : value(at::kFloat) {}
bool load(handle src, bool) {
PyObject* obj = src.ptr();
if (THPDtype_Check(obj)) {
value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
return true;
}
return false;
}
static handle cast(
const at::ScalarType& src,
return_value_policy /* policy */,
handle /* parent */) {
return Py_NewRef(torch::getTHPDtype(src));
}
};
template <>
struct type_caster<c10::Stream> {
public:
@ -206,7 +236,7 @@ struct type_caster<c10::Stream> {
if (THPStream_Check(obj)) {
value = c10::Stream::unpack3(
((THPStream*)obj)->stream_id,
((THPStream*)obj)->device_index,
static_cast<c10::DeviceIndex>(((THPStream*)obj)->device_index),
static_cast<c10::DeviceType>(((THPStream*)obj)->device_type));
return true;
}

View File

@ -1,14 +1,11 @@
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <torch/csrc/utils/tensor_types.h>
namespace torch {
namespace utils {
namespace torch::utils {
std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
switch (scalarType) {
@ -125,5 +122,4 @@ void initializeDtypes() {
}
}
} // namespace utils
} // namespace torch
} // namespace torch::utils

View File

@ -1,15 +1,13 @@
#pragma once
#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include <string>
#include <tuple>
namespace torch {
namespace utils {
namespace torch::utils {
std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType);
void initializeDtypes();
} // namespace utils
} // namespace torch
} // namespace torch::utils