Enable Python bindings for UntypedStorage (#68945)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68945

This PR enables the Python conversion functions for `Storage` (specifically `UntypedStorage`) and also cleans up some remnants of the deprecated typed storages from `DynamicTypes.cpp`.
ghstack-source-id: 147245110

Test Plan: Run the existing unit and integration tests.

Reviewed By: albanD

Differential Revision: D32676505

fbshipit-source-id: 3a3f6db4fb0da5c78dd406c96ab70bdc37015521
(cherry picked from commit d6427b94cf88b078bd228d43cd2afbabf0773b39)
This commit is contained in:
Can Balioglu
2022-01-19 18:07:36 -08:00
committed by PyTorch MergeBot
parent f5b19ba683
commit 80b19c4c8c
8 changed files with 21 additions and 22 deletions

View File

@ -522,5 +522,12 @@ $6 = torch._ops.aten.add_($1, $5)''')
with self.assertRaisesRegex(RuntimeError, "but got None"):
out.backward()
def test_storage_can_be_converted_to_python_object(self):
with enable_python_mode(LoggingTensor):
s = torch.Storage()
z = LoggingTensor(torch.empty([]))
z.set_(s)
if __name__ == '__main__':
run_tests()

View File

@ -921,7 +921,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
return handle_torch_function(self, "storage");
}
auto& self_ = THPVariable_Unpack(self);
return createPyObject(self_.storage(), self_.dtype());
return createPyObject(self_.storage());
END_HANDLE_TH_ERRORS
}

View File

@ -52,9 +52,7 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala
return &at::getDeprecatedTypeProperties(backend, scalarType);
}
PyTypeObject* getPyTypeObject(
const at::Storage& storage,
const caffe2::TypeMeta dtype) {
PyTypeObject* getPyTypeObject(const at::Storage& storage) {
// TODO: https://github.com/pytorch/pytorch/issues/47442
if (storage.device_type() == at::DeviceType::Meta) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
@ -67,10 +65,9 @@ PyTypeObject* getPyTypeObject(
at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())),
scalarType);
auto it = attype_to_py_storage_type.find(attype);
if (it != attype_to_py_storage_type.end()) {
return it->second;
}
throw std::invalid_argument("unsupported Storage type");
TORCH_INTERNAL_ASSERT(it != attype_to_py_storage_type.end(),
"Failed to get the Python type of `UntypedStorage`.");
return it->second;
}
} // namespace
@ -106,10 +103,8 @@ THPLayout* getTHPLayout(at::Layout layout) {
return thp_layout;
}
PyObject* createPyObject(
const at::Storage& storage,
const caffe2::TypeMeta data_type) {
auto type = getPyTypeObject(storage, data_type);
PyObject* createPyObject(const at::Storage& storage) {
auto type = getPyTypeObject(storage);
auto obj = THPObjectPtr(type->tp_alloc(type, 0));
if (!obj) throw python_error();
((THPVoidStorage*)obj.get())->cdata = at::Storage(/* copy */ storage).unsafeReleaseStorageImpl();

View File

@ -28,9 +28,7 @@ void registerStoragePyTypeObject(
void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType);
void registerLayoutObject(THPLayout *thp_layout, at::Layout layout);
PyObject* createPyObject(
const at::Storage& storage,
const caffe2::TypeMeta data_type);
PyObject* createPyObject(const at::Storage& storage);
at::Storage createStorage(PyObject* obj);
at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, bool& is_typed_storage);
bool isStorage(PyObject* obj);

View File

@ -300,7 +300,7 @@ struct __attribute__((visibility("hidden"))) ConcreteInterpreterSessionImpl
for (size_t i = 0, N = obj.storages_.size(); i < N; ++i) {
py::object new_storage =
py::reinterpret_steal<py::object>(torch::createPyObject(
obj.storages_[i], scalarTypeToTypeMeta(obj.types_[i])));
obj.storages_[i]));
storages[i] = std::move(new_storage);
}
py::tuple dtypes(obj.types_.size());

View File

@ -41,6 +41,8 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
guardAgainstNamedTensor<autograd::Variable>(var);
return var;
}
case TypeKind::StorageType:
return py::cast<at::Storage>(obj);
case TypeKind::FloatType:
return py::cast<double>(obj);
case TypeKind::ComplexType: {
@ -331,7 +333,6 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
case TypeKind::DynamicType:
case TypeKind::FunctionType:
case TypeKind::GeneratorType:
case TypeKind::StorageType:
case TypeKind::QuantizerType:
case TypeKind::VarType:
case TypeKind::QSchemeType:

View File

@ -692,6 +692,8 @@ inline py::object toPyObject(IValue ivalue) {
}
guardAgainstNamedTensor<at::Tensor>(tensor);
return py::cast(autograd::Variable(std::move(tensor)));
} else if (ivalue.isStorage()) {
return py::cast(ivalue.toStorage());
} else if (ivalue.isDouble()) {
return py::cast(std::move(ivalue).toDouble());
} else if (ivalue.isComplexDouble()) {

View File

@ -69,11 +69,7 @@ struct type_caster<at::Storage> {
static handle
cast(const at::Storage& src, return_value_policy /* policy */, handle /* parent */) {
TORCH_CHECK(
false,
"NotImplementedError: pybind conversion of at::Storages from C++ to python not supported.");
// Storages are untyped, see: https://github.com/pytorch/pytorch/issues/47442
return handle(torch::createPyObject(src, caffe2::TypeMeta()));
return handle(torch::createPyObject(src));
}
};