mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enable Python bindings for UntypedStorage (#68945)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68945 This PR enables the Python conversion functions for `Storage` (specifically `UntypedStorage`) and also cleans up some remnants of the deprecated typed storages from `DynamicTypes.cpp`. ghstack-source-id: 147245110 Test Plan: Run the existing unit and integration tests. Reviewed By: albanD Differential Revision: D32676505 fbshipit-source-id: 3a3f6db4fb0da5c78dd406c96ab70bdc37015521 (cherry picked from commit d6427b94cf88b078bd228d43cd2afbabf0773b39)
This commit is contained in:
committed by
PyTorch MergeBot
parent
f5b19ba683
commit
80b19c4c8c
@ -522,5 +522,12 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
||||
with self.assertRaisesRegex(RuntimeError, "but got None"):
|
||||
out.backward()
|
||||
|
||||
def test_storage_can_be_converted_to_python_object(self):
|
||||
with enable_python_mode(LoggingTensor):
|
||||
s = torch.Storage()
|
||||
z = LoggingTensor(torch.empty([]))
|
||||
z.set_(s)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -921,7 +921,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
|
||||
return handle_torch_function(self, "storage");
|
||||
}
|
||||
auto& self_ = THPVariable_Unpack(self);
|
||||
return createPyObject(self_.storage(), self_.dtype());
|
||||
return createPyObject(self_.storage());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
@ -52,9 +52,7 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala
|
||||
return &at::getDeprecatedTypeProperties(backend, scalarType);
|
||||
}
|
||||
|
||||
PyTypeObject* getPyTypeObject(
|
||||
const at::Storage& storage,
|
||||
const caffe2::TypeMeta dtype) {
|
||||
PyTypeObject* getPyTypeObject(const at::Storage& storage) {
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/47442
|
||||
if (storage.device_type() == at::DeviceType::Meta) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
|
||||
@ -67,10 +65,9 @@ PyTypeObject* getPyTypeObject(
|
||||
at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())),
|
||||
scalarType);
|
||||
auto it = attype_to_py_storage_type.find(attype);
|
||||
if (it != attype_to_py_storage_type.end()) {
|
||||
return it->second;
|
||||
}
|
||||
throw std::invalid_argument("unsupported Storage type");
|
||||
TORCH_INTERNAL_ASSERT(it != attype_to_py_storage_type.end(),
|
||||
"Failed to get the Python type of `UntypedStorage`.");
|
||||
return it->second;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -106,10 +103,8 @@ THPLayout* getTHPLayout(at::Layout layout) {
|
||||
return thp_layout;
|
||||
}
|
||||
|
||||
PyObject* createPyObject(
|
||||
const at::Storage& storage,
|
||||
const caffe2::TypeMeta data_type) {
|
||||
auto type = getPyTypeObject(storage, data_type);
|
||||
PyObject* createPyObject(const at::Storage& storage) {
|
||||
auto type = getPyTypeObject(storage);
|
||||
auto obj = THPObjectPtr(type->tp_alloc(type, 0));
|
||||
if (!obj) throw python_error();
|
||||
((THPVoidStorage*)obj.get())->cdata = at::Storage(/* copy */ storage).unsafeReleaseStorageImpl();
|
||||
|
@ -28,9 +28,7 @@ void registerStoragePyTypeObject(
|
||||
void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType);
|
||||
void registerLayoutObject(THPLayout *thp_layout, at::Layout layout);
|
||||
|
||||
PyObject* createPyObject(
|
||||
const at::Storage& storage,
|
||||
const caffe2::TypeMeta data_type);
|
||||
PyObject* createPyObject(const at::Storage& storage);
|
||||
at::Storage createStorage(PyObject* obj);
|
||||
at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, bool& is_typed_storage);
|
||||
bool isStorage(PyObject* obj);
|
||||
|
@ -300,7 +300,7 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
|
||||
for (size_t i = 0, N = obj.storages_.size(); i < N; ++i) {
|
||||
py::object new_storage =
|
||||
py::reinterpret_steal<py::object>(torch::createPyObject(
|
||||
obj.storages_[i], scalarTypeToTypeMeta(obj.types_[i])));
|
||||
obj.storages_[i]));
|
||||
storages[i] = std::move(new_storage);
|
||||
}
|
||||
py::tuple dtypes(obj.types_.size());
|
||||
|
@ -41,6 +41,8 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
guardAgainstNamedTensor<autograd::Variable>(var);
|
||||
return var;
|
||||
}
|
||||
case TypeKind::StorageType:
|
||||
return py::cast<at::Storage>(obj);
|
||||
case TypeKind::FloatType:
|
||||
return py::cast<double>(obj);
|
||||
case TypeKind::ComplexType: {
|
||||
@ -331,7 +333,6 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
case TypeKind::DynamicType:
|
||||
case TypeKind::FunctionType:
|
||||
case TypeKind::GeneratorType:
|
||||
case TypeKind::StorageType:
|
||||
case TypeKind::QuantizerType:
|
||||
case TypeKind::VarType:
|
||||
case TypeKind::QSchemeType:
|
||||
|
@ -692,6 +692,8 @@ inline py::object toPyObject(IValue ivalue) {
|
||||
}
|
||||
guardAgainstNamedTensor<at::Tensor>(tensor);
|
||||
return py::cast(autograd::Variable(std::move(tensor)));
|
||||
} else if (ivalue.isStorage()) {
|
||||
return py::cast(ivalue.toStorage());
|
||||
} else if (ivalue.isDouble()) {
|
||||
return py::cast(std::move(ivalue).toDouble());
|
||||
} else if (ivalue.isComplexDouble()) {
|
||||
|
@ -69,11 +69,7 @@ struct type_caster<at::Storage> {
|
||||
|
||||
static handle
|
||||
cast(const at::Storage& src, return_value_policy /* policy */, handle /* parent */) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"NotImplementedError: pybind conversion of at::Storages from C++ to python not supported.");
|
||||
// Storages are untyped, see: https://github.com/pytorch/pytorch/issues/47442
|
||||
return handle(torch::createPyObject(src, caffe2::TypeMeta()));
|
||||
return handle(torch::createPyObject(src));
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user