Rename _Typed/_UntypedStorage to Typed/UntypedStorage and update docs (#82438)

### Description

Since the major changes for `_TypedStorage` and `_UntypedStorage` are now complete, they can be renamed to be public.

`TypedStorage._untyped()` is renamed to `TypedStorage.untyped()`.

Documentation for storages is improved as well.

### Issue
Fixes #82436

### Testing
N/A

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82438
Approved by: https://github.com/ezyang
This commit is contained in:
Kurt Mohler
2022-07-30 19:37:08 +00:00
committed by PyTorch MergeBot
parent 28304dd494
commit 14d0296e5c
26 changed files with 210 additions and 202 deletions

View File

@ -1,28 +1,33 @@
torch.Storage
===================================
A :class:`torch._TypedStorage` is a contiguous, one-dimensional array of
elements of a particular :class:`torch.dtype`. It can be given any
:class:`torch.dtype`, and the internal data will be interpretted appropriately.
Every strided :class:`torch.Tensor` contains a :class:`torch._TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.
For backward compatibility, there are also :class:`torch.<type>Storage` classes
(like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc). These
classes are not actually instantiated, and calling their constructors creates
a :class:`torch._TypedStorage` with the appropriate :class:`torch.dtype`.
:class:`torch.<type>Storage` classes have all of the same class methods that
:class:`torch._TypedStorage` has.
Also for backward compatibility, :class:`torch.Storage` is an alias for the
storage class that corresponds with the default data type
(:func:`torch.get_default_dtype()`). For instance, if the default data type is
:attr:`torch.float`, :class:`torch.Storage` resolves to
:class:`torch.Storage` is an alias for the storage class that corresponds with
the default data type (:func:`torch.get_default_dtype()`). For instance, if the
default data type is :attr:`torch.float`, :class:`torch.Storage` resolves to
:class:`torch.FloatStorage`.
The :class:`torch.<type>Storage` and :class:`torch.cuda.<type>Storage` classes,
like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc., are not
actually ever instantiated. Calling their constructors creates
a :class:`torch.TypedStorage` with the appropriate :class:`torch.dtype` and
:class:`torch.device`. :class:`torch.<type>Storage` classes have all of the
same class methods that :class:`torch.TypedStorage` has.
.. autoclass:: torch._TypedStorage
A :class:`torch.TypedStorage` is a contiguous, one-dimensional array of
elements of a particular :class:`torch.dtype`. It can be given any
:class:`torch.dtype`, and the internal data will be interpretted appropriately.
:class:`torch.TypedStorage` contains a :class:`torch.UntypedStorage` which
holds the data as an untyped array of bytes.
Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.
.. autoclass:: torch.TypedStorage
:members:
:undoc-members:
:inherited-members:
.. autoclass:: torch.UntypedStorage
:members:
:undoc-members:
:inherited-members:

View File

@ -1852,7 +1852,7 @@
"QUInt4x2Storage",
"QUInt8Storage",
"Storage",
"_TypedStorage",
"TypedStorage",
"_adaptive_avg_pool2d",
"_adaptive_avg_pool3d",
"_add_batch_dim",

View File

@ -569,8 +569,8 @@ class TestCuda(TestCase):
self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage._TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch._UntypedStorage))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

View File

@ -97,7 +97,7 @@ class SerializationMixin(object):
self.assertTrue(isinstance(c[1], torch.FloatTensor))
self.assertTrue(isinstance(c[2], torch.FloatTensor))
self.assertTrue(isinstance(c[3], torch.FloatTensor))
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
self.assertEqual(c[4].dtype, torch.float)
c[0].fill_(10)
self.assertEqual(c[0], c[2], atol=0, rtol=0)
@ -370,7 +370,7 @@ class SerializationMixin(object):
self.assertTrue(isinstance(c[1], torch.FloatTensor))
self.assertTrue(isinstance(c[2], torch.FloatTensor))
self.assertTrue(isinstance(c[3], torch.FloatTensor))
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
self.assertEqual(c[4].dtype, torch.float32)
c[0].fill_(10)
self.assertEqual(c[0], c[2], atol=0, rtol=0)
@ -621,8 +621,8 @@ class SerializationMixin(object):
a = torch.tensor([], dtype=dtype, device=device)
for other_dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
s = torch._TypedStorage(
wrap_storage=a.storage()._untyped(),
s = torch.TypedStorage(
wrap_storage=a.storage().untyped(),
dtype=other_dtype)
save_load_check(a, s)
save_load_check(a.storage(), s)
@ -653,8 +653,8 @@ class SerializationMixin(object):
torch.save([a.storage(), a.imag.storage()], f)
a = torch.randn(10, device=device)
s_bytes = torch._TypedStorage(
wrap_storage=a.storage()._untyped(),
s_bytes = torch.TypedStorage(
wrap_storage=a.storage().untyped(),
dtype=torch.uint8)
with self.assertRaisesRegex(RuntimeError, error_msg):

View File

@ -157,7 +157,7 @@ class TestTorchDeviceType(TestCase):
for i in range(10):
bytes_list = [rand_byte() for _ in range(element_size)]
scalar = bytes_to_scalar(bytes_list, dtype, device)
self.assertEqual(scalar.storage()._untyped().tolist(), bytes_list)
self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)
@dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
torch.bool, torch.float32, torch.complex64, torch.float64,
@ -175,7 +175,7 @@ class TestTorchDeviceType(TestCase):
v_s[el_num],
v[dim0][dim1])
v_s_byte = v.storage()._untyped()
v_s_byte = v.storage().untyped()
el_size = v.element_size()
for el_num in range(v.numel()):
@ -238,7 +238,7 @@ class TestTorchDeviceType(TestCase):
a_s = a.storage()
b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size())
self.assertEqual(a, b)
c = torch.tensor(a_s._untyped(), device=device, dtype=dtype).reshape(a.size())
c = torch.tensor(a_s.untyped(), device=device, dtype=dtype).reshape(a.size())
self.assertEqual(a, c)
for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
@ -255,7 +255,7 @@ class TestTorchDeviceType(TestCase):
a_s = a.storage()
b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size())
self.assertEqual(a, b)
c = torch.tensor([], device=device, dtype=dtype).set_(a_s._untyped()).reshape(a.size())
c = torch.tensor([], device=device, dtype=dtype).set_(a_s.untyped()).reshape(a.size())
self.assertEqual(a, c)
for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
@ -267,11 +267,11 @@ class TestTorchDeviceType(TestCase):
def _check_storage_meta(self, s, s_check):
self.assertTrue(
isinstance(s, (torch._UntypedStorage, torch._TypedStorage)) and
isinstance(s, (torch.UntypedStorage, torch.TypedStorage)) and
isinstance(s_check, type(s)),
(
's and s_check must both be one of _UntypedStorage or '
'_TypedStorage, but got'
's and s_check must both be one of UntypedStorage or '
'TypedStorage, but got'
f' {type(s).__name__} and {type(s_check).__name__}'))
self.assertEqual(s.device.type, 'meta')
@ -282,9 +282,9 @@ class TestTorchDeviceType(TestCase):
with self.assertRaisesRegex(NotImplementedError, r'Not available'):
s[0]
if isinstance(s, torch._TypedStorage):
if isinstance(s, torch.TypedStorage):
self.assertEqual(s.dtype, s_check.dtype)
self._check_storage_meta(s._untyped(), s_check._untyped())
self._check_storage_meta(s.untyped(), s_check.untyped())
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@ -296,8 +296,8 @@ class TestTorchDeviceType(TestCase):
[[1, 2, 3, 4, 5, 6]],
]
for args in args_list:
s_check = torch._TypedStorage(*args, dtype=dtype, device=device)
s = torch._TypedStorage(*args, dtype=dtype, device='meta')
s_check = torch.TypedStorage(*args, dtype=dtype, device=device)
s = torch.TypedStorage(*args, dtype=dtype, device='meta')
self._check_storage_meta(s, s_check)
@onlyNativeDeviceTypes
@ -309,8 +309,8 @@ class TestTorchDeviceType(TestCase):
[[1, 2, 3, 4, 5, 6]],
]
for args in args_list:
s_check = torch._UntypedStorage(*args, device=device)
s = torch._UntypedStorage(*args, device='meta')
s_check = torch.UntypedStorage(*args, device=device)
s = torch.UntypedStorage(*args, device='meta')
self._check_storage_meta(s, s_check)
@onlyNativeDeviceTypes
@ -326,7 +326,7 @@ class TestTorchDeviceType(TestCase):
@onlyCPU
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_storage_meta_errors(self, device, dtype):
s0 = torch._TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
s0.cpu()
@ -361,7 +361,7 @@ class TestTorchDeviceType(TestCase):
s0._write_file(f, True, True, s0.element_size())
for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
s1 = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
s1 = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
s1.copy_(s0)
@ -6444,7 +6444,7 @@ class TestTorch(TestCase):
torch.storage._LegacyStorage()
for storage_class in torch._storage_classes:
if storage_class in [torch._UntypedStorage, torch._TypedStorage]:
if storage_class in [torch.UntypedStorage, torch.TypedStorage]:
continue
device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
@ -6475,9 +6475,9 @@ class TestTorch(TestCase):
s = storage_class()
with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
storage_class(0, wrap_storage=s._untyped())
storage_class(0, wrap_storage=s.untyped())
with self.assertRaisesRegex(TypeError, r"must be _UntypedStorage"):
with self.assertRaisesRegex(TypeError, r"must be UntypedStorage"):
storage_class(wrap_storage=s)
if torch.cuda.is_available():
@ -6493,40 +6493,40 @@ class TestTorch(TestCase):
s_other_device = s.cuda()
with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
storage_class(wrap_storage=s_other_device._untyped())
storage_class(wrap_storage=s_other_device.untyped())
# _TypedStorage constructor errors
# TypedStorage constructor errors
with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
torch._TypedStorage(0, wrap_storage=s._untyped(), dtype=dtype)
torch.TypedStorage(0, wrap_storage=s.untyped(), dtype=dtype)
with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
torch._TypedStorage(wrap_storage=s._untyped())
torch.TypedStorage(wrap_storage=s.untyped())
with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
torch._TypedStorage(wrap_storage=s._untyped(), dtype=0)
torch.TypedStorage(wrap_storage=s.untyped(), dtype=0)
with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
torch._TypedStorage(wrap_storage=s._untyped(), dtype=dtype, device=device)
torch.TypedStorage(wrap_storage=s.untyped(), dtype=dtype, device=device)
with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be _UntypedStorage"):
torch._TypedStorage(wrap_storage=s, dtype=dtype)
with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be UntypedStorage"):
torch.TypedStorage(wrap_storage=s, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
torch._TypedStorage(dtype=dtype, device='xla')
torch.TypedStorage(dtype=dtype, device='xla')
if torch.cuda.is_available():
if storage_class in quantized_storages:
with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
torch._TypedStorage(dtype=dtype, device='cuda')
torch.TypedStorage(dtype=dtype, device='cuda')
with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
torch._TypedStorage(torch.tensor([]), dtype=dtype, device=device)
torch.TypedStorage(torch.tensor([]), dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
torch._TypedStorage(0, 0, dtype=dtype, device=device)
torch.TypedStorage(0, 0, dtype=dtype, device=device)
if isinstance(s, torch._TypedStorage):
s_other = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
if isinstance(s, torch.TypedStorage):
s_other = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, r'cannot set item'):
s.fill_(s_other)

View File

@ -1146,7 +1146,7 @@ static PyObject* THPVariable_set_(
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
"Expected a Storage of type ", self.dtype(),
" or an _UntypedStorage, but got type ", storage_scalar_type,
" or an UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor {
pybind11::gil_scoped_release no_gil;
@ -1162,7 +1162,7 @@ static PyObject* THPVariable_set_(
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
"Expected a Storage of type ", self.dtype(),
" or an _UntypedStorage, but got type ", storage_scalar_type,
" or an UntypedStorage, but got type ", storage_scalar_type,
" for argument 1 'storage'");
auto dispatch_set_ = [](const Tensor& self,
Storage source,

View File

@ -696,8 +696,8 @@ def gen_pyi(
"def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
],
"set_": [
"def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
"def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...",
],
"split": [
"def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",

View File

@ -13,7 +13,7 @@ from typing_extensions import Literal
from torch._six import inf
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt
from torch.storage import _TypedStorage
from torch.storage import TypedStorage
import builtins

View File

@ -40,7 +40,7 @@ __all__ = [
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
'_TypedStorage',
'TypedStorage', 'UntypedStorage',
'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
'lobpcg', 'use_deterministic_algorithms',
@ -656,10 +656,10 @@ __all__.extend(['e', 'pi', 'nan', 'inf'])
################################################################################
from ._tensor import Tensor
from .storage import _StorageBase, _TypedStorage, _LegacyStorage, _UntypedStorage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage._TypedStorage directly.
# dtype, use torch.storage.TypedStorage directly.
class ByteStorage(_LegacyStorage):
@classproperty
@ -747,11 +747,11 @@ class QUInt2x4Storage(_LegacyStorage):
return torch.quint2x4
_storage_classes = {
_UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
_TypedStorage
TypedStorage
}
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()

View File

@ -19,8 +19,8 @@ def _save_storages(importer, obj):
importers = sys_importer
def persistent_id(obj):
if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
if isinstance(obj, torch.storage._TypedStorage):
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
@ -66,11 +66,11 @@ def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dt
if typename == "storage":
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
# stop wrapping with TypedStorage
storage = serialized_storages[data[0]]
dtype = serialized_dtypes[data[0]]
return torch.storage._TypedStorage(
wrap_storage=storage._untyped(), dtype=dtype
return torch.storage.TypedStorage(
wrap_storage=storage.untyped(), dtype=dtype
)
if typename == "reduce_deploy":

View File

@ -11,7 +11,7 @@ import torch
import torch._prims_common as utils
import torch.library
from torch import _TypedStorage, Tensor
from torch import Tensor, TypedStorage
from torch._C import _get_default_device
from torch._prims_common import (
check,

View File

@ -1303,7 +1303,7 @@ def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...
def check_in_bounds_for_storage(
a: torch._TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):
"""
Determines if the given shape, strides, and offset are valid for the given storage.

View File

@ -158,10 +158,10 @@ class Tensor(torch._C._TensorBase):
f"Unsupported qscheme {self.qscheme()} in deepcopy"
)
# TODO: Once we decide to break serialization FC, no longer
# need to wrap with _TypedStorage
# need to wrap with TypedStorage
new_tensor = torch._utils._rebuild_qtensor(
torch.storage._TypedStorage(
wrap_storage=new_storage._untyped(), dtype=self.dtype
torch.storage.TypedStorage(
wrap_storage=new_storage.untyped(), dtype=self.dtype
),
self.storage_offset(),
self.size(),
@ -255,7 +255,7 @@ class Tensor(torch._C._TensorBase):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.storage, (self,), self)
return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
def _reduce_ex_internal(self, proto):
check_serializing_named_tensor(self)
@ -324,10 +324,10 @@ class Tensor(torch._C._TensorBase):
f"Serialization is not supported for tensors of type {self.qscheme()}"
)
# TODO: Once we decide to break serialization FC, no longer
# need to wrap with _TypedStorage
# need to wrap with TypedStorage
args_qtensor = (
torch.storage._TypedStorage(
wrap_storage=self.storage()._untyped(), dtype=self.dtype
torch.storage.TypedStorage(
wrap_storage=self.storage().untyped(), dtype=self.dtype
),
self.storage_offset(),
tuple(self.size()),
@ -382,10 +382,10 @@ class Tensor(torch._C._TensorBase):
return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
else:
# TODO: Once we decide to break serialization FC, no longer
# need to wrap with _TypedStorage
# need to wrap with TypedStorage
args = (
torch.storage._TypedStorage(
wrap_storage=self.storage()._untyped(), dtype=self.dtype
torch.storage.TypedStorage(
wrap_storage=self.storage().untyped(), dtype=self.dtype
),
self.storage_offset(),
tuple(self.size()),

View File

@ -77,9 +77,11 @@ def _cuda(self, device=None, non_blocking=False, **kwargs):
values = torch.Tensor._values(self).cuda(device, non_blocking)
return new_type(indices, values, self.size())
else:
return torch._UntypedStorage(
untyped_storage = torch.UntypedStorage(
self.size(), device=torch.device("cuda")
).copy_(self, non_blocking)
)
untyped_storage.copy_(self, non_blocking)
return untyped_storage
def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
@ -138,11 +140,11 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a _TypedStorage
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
# first construct a tensor with the correct dtype/device
t = torch.tensor([], dtype=storage.dtype, device=storage._untyped().device)
return t.set_(storage._untyped(), storage_offset, size, stride)
t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
return t.set_(storage.untyped(), storage_offset, size, stride)
def _rebuild_tensor_v2(
@ -251,7 +253,7 @@ def _rebuild_wrapper_subclass(
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a _TypedStorage
# be a TypedStorage
def _rebuild_qtensor(
storage,
storage_offset,

View File

@ -94,10 +94,10 @@ PyTypeObject* loadTypedStorageTypeObject() {
TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module));
PyObject* typed_storage_obj =
PyObject_GetAttrString(storage_module, "_TypedStorage");
PyObject_GetAttrString(storage_module, "TypedStorage");
TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj));
return reinterpret_cast<PyTypeObject*>(
PyObject_GetAttrString(storage_module, "_TypedStorage"));
PyObject_GetAttrString(storage_module, "TypedStorage"));
}
PyTypeObject* getTypedStorageTypeObject() {
@ -125,7 +125,7 @@ at::Storage createStorageGetType(
if (is_typed_storage) {
// NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and
// `_storage`, so we must decrement them. The refcounts will still stay
// nonzero since the `_TypedStorage` maintains a reference.
// nonzero since the `TypedStorage` maintains a reference.
PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
TORCH_INTERNAL_ASSERT(dtype_obj);
Py_DECREF(dtype_obj);

View File

@ -387,7 +387,7 @@ bool THPStorage_init(PyObject* module) {
}
void THPStorage_postInit(PyObject* module) {
THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage");
THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
if (!THPStorageClass)
throw python_error();
}

View File

@ -3,7 +3,7 @@
#include <torch/csrc/Types.h>
#define THPStorageStr "torch._UntypedStorage"
#define THPStorageStr "torch.UntypedStorage"
#define THPStorageBaseStr "StorageBase"
struct THPStorage {

View File

@ -138,7 +138,7 @@ static PyObject* THPStorage_resize_(PyObject* _self, PyObject* number_arg) {
} else {
TORCH_CHECK(
false,
"_UntypedStorage.resize_: got unexpected device type ",
"UntypedStorage.resize_: got unexpected device type ",
device_type);
}
Py_INCREF(self);

View File

@ -361,7 +361,7 @@ Tensor internal_new_from_data(
!is_typed_storage || storage_scalar_type == scalar_type,
"Expected a Storage of type ",
scalar_type,
" or an _UntypedStorage, but got ",
" or an UntypedStorage, but got ",
storage_scalar_type);
tensor = at::empty(
sizes,
@ -642,7 +642,7 @@ Tensor legacy_tensor_generic_ctor_new(
storage_scalar_type == scalar_type,
"Expected a Storage of type ",
scalar_type,
" or an _UntypedStorage, but got type ",
" or an UntypedStorage, but got type ",
storage_scalar_type,
" for argument 1 'storage'");
}

View File

@ -133,7 +133,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
t = torch._utils._rebuild_tensor(
torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
torch.storage.TypedStorage(wrap_storage=storage.untyped(), dtype=dtype),
tensor_offset, tensor_size, tensor_stride)
if tensor_cls == torch.nn.parameter.Parameter:
@ -298,7 +298,7 @@ def storage_from_cache(cls, key):
storage_ref = shared_cache.get(key)
if storage_ref is None:
return None
return torch._UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
def rebuild_storage_fd(cls, df, size):
@ -315,15 +315,15 @@ def rebuild_storage_fd(cls, df, size):
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
storage: Union[torch._TypedStorage, torch._UntypedStorage] = storage_from_cache(cls, handle)
storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(cls, handle)
if storage is not None:
return storage._shared_decref()
if dtype is None:
storage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, size)
storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
else:
byte_size = size * torch._utils._element_size(dtype)
untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
storage = torch._TypedStorage(
untyped_storage: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
storage = torch.TypedStorage(
wrap_storage=untyped_storage,
dtype=dtype)
shared_cache[handle] = StorageWeakRef(storage)
@ -334,16 +334,16 @@ def rebuild_storage_empty(cls):
return cls()
def rebuild_typed_storage(storage, dtype):
return torch.storage._TypedStorage(wrap_storage=storage, dtype=dtype)
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype)
# Use for torch.storage._TypedStorage
# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
return (rebuild_typed_storage, (storage._storage, storage.dtype))
def rebuild_typed_storage_child(storage, storage_type):
return storage_type(wrap_storage=storage)
# Use for child classes of torch.storage._TypedStorage, like torch.FloatStorage
# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
return (rebuild_typed_storage_child, (storage._storage, type(storage)))
@ -355,7 +355,7 @@ def reduce_storage(storage):
metadata = storage._share_filename_cpu_()
cache_key = metadata[1]
rebuild = rebuild_storage_filename
if isinstance(storage, torch._TypedStorage):
if isinstance(storage, torch.TypedStorage):
metadata += (storage.dtype,)
storage._shared_incref()
elif storage.size() == 0:
@ -377,12 +377,12 @@ def init_reductions():
ForkingPickler.register(torch.cuda.Event, reduce_event)
for t in torch._storage_classes:
if t.__name__ == '_UntypedStorage':
if t.__name__ == 'UntypedStorage':
ForkingPickler.register(t, reduce_storage)
else:
ForkingPickler.register(t, reduce_typed_storage_child)
ForkingPickler.register(torch.storage._TypedStorage, reduce_typed_storage)
ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)
for t in torch._tensor_classes:
ForkingPickler.register(t, reduce_tensor)

View File

@ -35,7 +35,7 @@ class DirectoryReader(object):
def get_storage_from_record(self, name, numel, dtype):
filename = f"{self.directory}/{name}"
nbytes = torch._utils._element_size(dtype) * numel
storage = cast(Storage, torch._UntypedStorage)
storage = cast(Storage, torch.UntypedStorage)
return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
def has_record(self, path):

View File

@ -883,8 +883,8 @@ class PackageExporter:
)
def _persistent_id(self, obj):
if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
if isinstance(obj, torch.storage._TypedStorage):
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
untyped_storage = obj._storage
@ -892,7 +892,7 @@ class PackageExporter:
storage_type = getattr(torch, storage_type_str)
storage_numel = obj.size()
elif isinstance(obj, torch._UntypedStorage):
elif isinstance(obj, torch.UntypedStorage):
untyped_storage = obj
storage_type = normalize_storage_type(type(storage))
storage_numel = storage.nbytes()

View File

@ -234,9 +234,9 @@ class PackageImporter(Importer):
)
storage = loaded_storages[key]
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
return torch.storage._TypedStorage(
wrap_storage=storage._untyped(), dtype=dtype
# stop wrapping with TypedStorage
return torch.storage.TypedStorage(
wrap_storage=storage.untyped(), dtype=dtype
)
elif typename == "reduce_package":
# to fix BC breaking change, objects on this load path

View File

@ -157,7 +157,7 @@ def _cuda_deserialize(obj, location):
device = validate_cuda_device(location)
if getattr(obj, "_torch_load_uninitialized", False):
with torch.cuda.device(device):
return torch._UntypedStorage(obj.nbytes(), device=torch.device(location))
return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
else:
return obj.cuda(device)
@ -171,7 +171,7 @@ register_package(20, _cuda_tag, _cuda_deserialize)
register_package(21, _mps_tag, _mps_deserialize)
def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
@ -423,10 +423,10 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
"for correctness upon loading.")
return ('module', obj, source_file, source)
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
storage: torch._UntypedStorage
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
storage: torch.UntypedStorage
if isinstance(obj, torch.storage._TypedStorage):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
@ -436,7 +436,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
dtype = obj.dtype
storage_numel = obj.size()
elif isinstance(obj, torch._UntypedStorage):
elif isinstance(obj, torch.UntypedStorage):
storage = obj
storage_dtype = torch.uint8
storage_type = normalize_storage_type(type(obj))
@ -476,8 +476,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
# effectively saving nbytes in this case. We'll be able to load it
# and the tensor back up with no problems in _this_ and future
# versions of pytorch, but in older versions, here's the problem:
# the storage will be loaded up as a _UntypedStorage, and then the
# FloatTensor will loaded and the _UntypedStorage will be assigned to
# the storage will be loaded up as a UntypedStorage, and then the
# FloatTensor will loaded and the UntypedStorage will be assigned to
# it. Since the storage dtype does not match the tensor dtype, this
# will cause an error. If we reverse the list, like `[tensor,
# storage]`, then we will save the `tensor.storage()` as a faked
@ -485,7 +485,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
# dtype-specific numel count that old versions expect. `tensor`
# will be able to load up properly in old versions, pointing to
# a FloatStorage. However, `storage` is still being translated to
# a _UntypedStorage, and it will try to resolve to the same
# a UntypedStorage, and it will try to resolve to the same
# FloatStorage that `tensor` contains. This will also cause an
# error. It doesn't seem like there's any way around this.
# Probably, we just cannot maintain FC for the legacy format if the
@ -552,9 +552,9 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
# see
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage._TypedStorage):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
@ -817,11 +817,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type.dtype
obj = cast(Storage, torch._UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
obj = restore_location(obj, location)
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
deserialized_objects[key] = torch.storage._TypedStorage(
# stop wrapping with TypedStorage
deserialized_objects[key] = torch.storage.TypedStorage(
wrap_storage=obj,
dtype=dtype)
@ -831,8 +831,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
element_size = torch._utils._element_size(root.dtype)
offset_bytes = offset * element_size
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
deserialized_objects[target_cdata] = torch.storage._TypedStorage(
# stop wrapping with TypedStorage
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
dtype=root.dtype)
@ -879,11 +879,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
nbytes = numel * torch._utils._element_size(dtype)
if root_key not in deserialized_objects:
obj = cast(Storage, torch._UntypedStorage(nbytes))
obj = cast(Storage, torch.UntypedStorage(nbytes))
obj._torch_load_uninitialized = True
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
deserialized_objects[root_key] = torch.storage._TypedStorage(
# stop wrapping with TypedStorage
deserialized_objects[root_key] = torch.storage.TypedStorage(
wrap_storage=restore_location(obj, location),
dtype=dtype)
@ -894,8 +894,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
view_size_bytes = view_size * torch._utils._element_size(dtype)
if view_key not in deserialized_objects:
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
deserialized_objects[view_key] = torch.storage._TypedStorage(
# stop wrapping with TypedStorage
deserialized_objects[view_key] = torch.storage.TypedStorage(
wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
dtype=dtype)
res = deserialized_objects[view_key]
@ -1005,10 +1005,10 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
def load_tensor(dtype, numel, key, location):
name = f'data/{key}'
storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage().untyped()
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with _TypedStorage
loaded_storages[key] = torch.storage._TypedStorage(
# stop wrapping with TypedStorage
loaded_storages[key] = torch.storage.TypedStorage(
wrap_storage=restore_location(storage, location),
dtype=dtype)
@ -1020,7 +1020,7 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
assert typename == 'storage', \
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, numel = data
if storage_type is torch._UntypedStorage:
if storage_type is torch.UntypedStorage:
dtype = torch.uint8
else:
dtype = storage_type.dtype

View File

@ -13,7 +13,7 @@ try:
except ModuleNotFoundError:
np = None # type: ignore[assignment]
T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')
class _StorageBase(object):
_cdata: Any
is_sparse: bool = False
@ -117,14 +117,14 @@ class _StorageBase(object):
def cpu(self):
"""Returns a CPU copy of this storage if it's not already on the CPU"""
if self.device.type != 'cpu':
return torch._UntypedStorage(self.size()).copy_(self, False)
return torch.UntypedStorage(self.size()).copy_(self, False)
else:
return self
def mps(self):
"""Returns a CPU copy of this storage if it's not already on the CPU"""
if self.device.type != 'mps':
return torch._UntypedStorage(self.size(), device="mps").copy_(self, False)
return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
else:
return self
@ -222,11 +222,11 @@ class _StorageBase(object):
else:
return cls._new_using_fd_cpu(size)
def _untyped(self):
def untyped(self):
return self
class _UntypedStorage(torch._C.StorageBase, _StorageBase):
class UntypedStorage(torch._C.StorageBase, _StorageBase):
def __getitem__(self, *args, **kwargs):
if self.device.type == 'meta':
raise NotImplementedError("Not available for 'meta' device type")
@ -248,9 +248,9 @@ _StorageBase.cuda = _cuda # type: ignore[assignment]
def _dtype_to_storage_type_map():
# NOTE: We should no longer add dtypes to this map. This map
# is only used for BC/FC with older PyTorch versions. Going forward,
# new dtypes of _TypedStorage should not translate to a legacy
# <type>Storage class. Instead, new dtypes of _TypedStorage should
# be serialized as an _UntypedStorage paired with a torch.dtype
# new dtypes of TypedStorage should not translate to a legacy
# <type>Storage class. Instead, new dtypes of TypedStorage should
# be serialized as an UntypedStorage paired with a torch.dtype
return {
torch.double: 'DoubleStorage',
torch.float: 'FloatStorage',
@ -297,7 +297,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
dtype=dtype,
device=device)
return tmp_tensor.storage()._untyped()
return tmp_tensor.storage().untyped()
def _isint(x):
if HAS_NUMPY:
@ -305,7 +305,7 @@ def _isint(x):
else:
return isinstance(x, int)
class _TypedStorage:
class TypedStorage:
is_sparse = False
dtype: torch.dtype
@ -318,7 +318,7 @@ class _TypedStorage:
if cls == torch.storage._LegacyStorage:
raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
if cls == _TypedStorage:
if cls == TypedStorage:
return super().__new__(cls)
else:
@ -328,7 +328,7 @@ class _TypedStorage:
' * no arguments\n'
' * (int size)\n'
' * (Sequence data)\n'
' * (*, _UntypedStorage wrap_storage)')
' * (*, UntypedStorage wrap_storage)')
if device is not None:
raise RuntimeError(
@ -351,7 +351,7 @@ class _TypedStorage:
arg_error_msg +
f"\nArgument type not recognized: {type(args[0])}")
return _TypedStorage(
return TypedStorage(
*args,
dtype=cls.dtype,
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu')
@ -363,10 +363,10 @@ class _TypedStorage:
"\nNo positional arguments should be given when using "
"'wrap_storage'")
if not isinstance(wrap_storage, torch._UntypedStorage):
if not isinstance(wrap_storage, torch.UntypedStorage):
raise TypeError(
arg_error_msg +
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
@ -376,19 +376,19 @@ class _TypedStorage:
f"\nDevice of 'wrap_storage' must be {cls_device}"
f", but got {wrap_storage.device.type}")
return _TypedStorage(
return TypedStorage(
*args,
wrap_storage=wrap_storage,
dtype=cls.dtype)
def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
arg_error_msg = (
'_TypedStorage.__init__ received an invalid combination '
'TypedStorage.__init__ received an invalid combination '
'of arguments. Expected one of:\n'
' * (*, torch.device device, torch.dtype dtype)\n'
' * (int size, *, torch.device device, torch.dtype dtype)\n'
' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)')
' * (*, UntypedStorage wrap_storage, torch.dtype dtype)')
if wrap_storage is not None:
if len(args) != 0:
@ -414,10 +414,10 @@ class _TypedStorage:
self.dtype = dtype
if not isinstance(wrap_storage, torch._UntypedStorage):
if not isinstance(wrap_storage, torch.UntypedStorage):
raise TypeError(
arg_error_msg +
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
self._storage = wrap_storage
@ -430,11 +430,11 @@ class _TypedStorage:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
if len(args) == 0:
self._storage = torch._UntypedStorage(device=device)
self._storage = torch.UntypedStorage(device=device)
elif len(args) == 1:
if _isint(args[0]):
self._storage = torch._UntypedStorage(int(args[0]) * self.element_size(), device=device)
self._storage = torch.UntypedStorage(int(args[0]) * self.element_size(), device=device)
elif isinstance(args[0], collections.abc.Sequence):
self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
else:
@ -452,14 +452,15 @@ class _TypedStorage:
def is_cuda(self):
return self.device.type == 'cuda'
def _untyped(self):
def untyped(self):
"""Returns the internal :class:`torch.UntypedStorage`"""
return self._storage
def _new_wrapped_storage(self, untyped_storage):
assert type(untyped_storage) == torch._UntypedStorage
assert type(untyped_storage) == torch.UntypedStorage
if type(self) == _TypedStorage:
return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
if type(self) == TypedStorage:
return TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
else:
return type(self)(wrap_storage=untyped_storage)
@ -505,7 +506,7 @@ class _TypedStorage:
torch.qint8: torch.int8
}
tmp_dtype = interpret_dtypes[self.dtype]
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage(
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage(
wrap_storage=self._storage,
dtype=tmp_dtype))
else:
@ -517,12 +518,12 @@ class _TypedStorage:
if self.device.type == 'meta':
raise NotImplementedError("Not available for 'meta' device type")
# NOTE: Before _TypedStorage existed, indexing with a slice used to be
# NOTE: Before TypedStorage existed, indexing with a slice used to be
# possible for <type>Storage objects. However, it would return
# a storage view, which would be a hassle to implement in _TypedStorage,
# a storage view, which would be a hassle to implement in TypedStorage,
# so it was disabled
if isinstance(idx, slice):
raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__')
raise RuntimeError('slices are only supported in UntypedStorage.__getitem__')
elif not isinstance(idx, int):
raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
@ -534,7 +535,7 @@ class _TypedStorage:
torch.qint32: torch.int32,
torch.qint8: torch.int8
}
return _TypedStorage(
return TypedStorage(
wrap_storage=self._storage,
dtype=interpret_dtypes[self.dtype])[idx]
@ -543,7 +544,7 @@ class _TypedStorage:
return tmp_tensor[idx_wrapped].item()
def copy_(self, source: T, non_blocking: bool = None):
self._storage.copy_(source._untyped(), non_blocking)
self._storage.copy_(source.untyped(), non_blocking)
return self
def nbytes(self):
@ -564,7 +565,7 @@ class _TypedStorage:
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage: torch._UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
cuda_storage: torch.UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)
def element_size(self):
@ -596,7 +597,7 @@ class _TypedStorage:
return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))
def __sizeof__(self):
return super(_TypedStorage, self).__sizeof__() + self.nbytes()
return super(TypedStorage, self).__sizeof__() + self.nbytes()
def clone(self):
"""Returns a copy of this storage"""
@ -631,8 +632,8 @@ class _TypedStorage:
if device is None:
device = 'cpu'
device = torch.device(device)
untyped_storage = torch._UntypedStorage._new_shared(size * self.element_size(), device=device)
return _TypedStorage(
untyped_storage = torch.UntypedStorage._new_shared(size * self.element_size(), device=device)
return TypedStorage(
wrap_storage=untyped_storage,
dtype=self.dtype)
@ -666,34 +667,34 @@ class _TypedStorage:
@classmethod
def _free_weak_ref(cls, *args, **kwargs):
return _UntypedStorage._free_weak_ref(*args, **kwargs)
return UntypedStorage._free_weak_ref(*args, **kwargs)
def _weak_ref(self, *args, **kwargs):
return self._storage._weak_ref(*args, **kwargs)
@classmethod
def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
if cls == _TypedStorage:
if cls == TypedStorage:
dtype = torch.get_default_dtype() if dtype is None else dtype
device = torch.device('cpu' if device is None else device)
if device.type != 'cpu':
raise RuntimeError(f'_TypedStorage.from_buffer: Not available for device {device.type}')
untyped_storage: torch._UntypedStorage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
raise RuntimeError(f'TypedStorage.from_buffer: Not available for device {device.type}')
untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
else:
if dtype is not None or len(args) == 5:
raise RuntimeError((
"from_buffer: 'dtype' can only be specified in "
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
"UntypedStorage.from_buffer and TypedStorage.from_buffer"))
if device is not None:
raise RuntimeError((
"from_buffer: 'device' can only be specified in "
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
"UntypedStorage.from_buffer and TypedStorage.from_buffer"))
dtype = cls.dtype
untyped_storage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
return TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
def _to(self, dtype):
if not isinstance(dtype, torch.dtype):
@ -770,9 +771,9 @@ class _TypedStorage:
shared (bool): whether to share memory
size (int): number of elements in the storage
"""
if cls == _TypedStorage:
if cls == TypedStorage:
raise RuntimeError('from_file can only be called on derived classes')
untyped_storage: _UntypedStorage = _UntypedStorage.from_file(
untyped_storage: UntypedStorage = UntypedStorage.from_file(
filename,
shared,
size * torch._utils._element_size(cls.dtype))
@ -781,7 +782,7 @@ class _TypedStorage:
@classmethod
def _expired(cls, *args, **kwargs):
return _UntypedStorage._expired(*args, **kwargs)
return UntypedStorage._expired(*args, **kwargs)
def is_pinned(self):
return self._storage.is_pinned()
@ -803,7 +804,7 @@ class _TypedStorage:
@classmethod
def _new_shared_cuda(cls, *args, **kwargs):
return torch._UntypedStorage._new_shared_cuda(*args, **kwargs)
return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
def _share_filename_cpu_(self, *args, **kwargs):
manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
@ -815,7 +816,7 @@ class _TypedStorage:
@classmethod
def _release_ipc_counter(cls, *args, device=None, **kwargs):
return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
def _shared_incref(self, *args, **kwargs):
return self._storage._shared_incref(*args, **kwargs)
@ -840,33 +841,33 @@ class _TypedStorage:
except AttributeError:
return None
_TypedStorage.type.__doc__ = _type.__doc__
_TypedStorage.cuda.__doc__ = _cuda.__doc__
TypedStorage.type.__doc__ = _type.__doc__
TypedStorage.cuda.__doc__ = _cuda.__doc__
class _LegacyStorageMeta(type):
dtype: torch.dtype
def __instancecheck__(cls, instance):
if type(instance) == _TypedStorage:
if type(instance) == TypedStorage:
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
return False
class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
@classmethod
def _new_shared(cls, size):
"""Creates a new storage in shared memory with the same data type"""
untyped_storage = torch._UntypedStorage._new_shared(size * cls().element_size())
untyped_storage = torch.UntypedStorage._new_shared(size * cls().element_size())
return cls(wrap_storage=untyped_storage)
@classmethod
def _release_ipc_counter(cls, *args, **kwargs):
return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
@classmethod
def _new_shared_filename(cls, manager, obj, size):
bytes_size = size * torch._utils._element_size(cls.dtype)
return cls(wrap_storage=torch._UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
return cls(wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
try:

View File

@ -2369,7 +2369,7 @@ class TestCase(expecttest.TestCase):
),
sequence_types=(
Sequence,
torch.storage._TypedStorage,
torch.storage.TypedStorage,
Sequential,
ModuleList,
ParameterList,