mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Support torch.dtype as parameter in pybind11 cpp extension. (#126865)
Support torch.dtype as parameter in pybind11 cpp extension. Example: ` cpp_extension.my_ops(self, other, torch.dtype) ` @ezyang @bdhirsh Co-authored-by: Edward Z. Yang <ezyang@mit.edu> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126865 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
8ea1dc8748
commit
7931eee5c5
@ -2,6 +2,7 @@
|
||||
|
||||
// test include_dirs in setuptools.setup with relative path
|
||||
#include <tmp.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
|
||||
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
|
||||
return x.sigmoid() + y.sigmoid();
|
||||
@ -31,6 +32,10 @@ torch::Tensor random_tensor() {
|
||||
return torch::randn({1});
|
||||
}
|
||||
|
||||
at::ScalarType get_math_type(at::ScalarType other) {
|
||||
return at::toOpMathType(other);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
|
||||
m.def(
|
||||
@ -52,4 +57,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("get_symint", []() { return c10::SymInt(1); });
|
||||
m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); });
|
||||
m.def("get_tensor", []() { return random_tensor(); });
|
||||
m.def("get_math_type", &get_math_type);
|
||||
}
|
||||
|
@ -55,6 +55,10 @@ class TestCppExtensionAOT(common.TestCase):
|
||||
y = torch.randn(4, 4)
|
||||
z = cpp_extension.sigmoid_add(x, y)
|
||||
self.assertEqual(z, x.sigmoid() + y.sigmoid())
|
||||
# test pybind support torch.dtype cast.
|
||||
self.assertEqual(
|
||||
str(torch.float32), str(cpp_extension.get_math_type(torch.half))
|
||||
)
|
||||
|
||||
def test_extension_module(self):
|
||||
mm = cpp_extension.MatrixMultiplier(4, 8)
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <torch/csrc/Device.h>
|
||||
#include <torch/csrc/Dtype.h>
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
@ -189,6 +190,35 @@ struct type_caster<at::Device> {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<at::ScalarType> {
|
||||
public:
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
|
||||
|
||||
// PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
|
||||
// cannot be default-initialized, we provide this constructor to explicitly
|
||||
// initialize that field. The value doesn't matter as it will be overwritten
|
||||
// after a successful call to load.
|
||||
type_caster() : value(at::kFloat) {}
|
||||
|
||||
bool load(handle src, bool) {
|
||||
PyObject* obj = src.ptr();
|
||||
if (THPDtype_Check(obj)) {
|
||||
value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static handle cast(
|
||||
const at::ScalarType& src,
|
||||
return_value_policy /* policy */,
|
||||
handle /* parent */) {
|
||||
return Py_NewRef(torch::getTHPDtype(src));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<c10::Stream> {
|
||||
public:
|
||||
@ -206,7 +236,7 @@ struct type_caster<c10::Stream> {
|
||||
if (THPStream_Check(obj)) {
|
||||
value = c10::Stream::unpack3(
|
||||
((THPStream*)obj)->stream_id,
|
||||
((THPStream*)obj)->device_index,
|
||||
static_cast<c10::DeviceIndex>(((THPStream*)obj)->device_index),
|
||||
static_cast<c10::DeviceType>(((THPStream*)obj)->device_type));
|
||||
return true;
|
||||
}
|
||||
|
@ -1,14 +1,11 @@
|
||||
#include <torch/csrc/Dtype.h>
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/autograd/generated/VariableType.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/tensor_dtypes.h>
|
||||
#include <torch/csrc/utils/tensor_types.h>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
namespace torch::utils {
|
||||
|
||||
std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
|
||||
switch (scalarType) {
|
||||
@ -125,5 +122,4 @@ void initializeDtypes() {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
} // namespace torch::utils
|
||||
|
@ -1,15 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
namespace torch::utils {
|
||||
|
||||
std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType);
|
||||
|
||||
void initializeDtypes();
|
||||
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
} // namespace torch::utils
|
||||
|
Reference in New Issue
Block a user