Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Rename _Typed/_UntypedStorage to Typed/UntypedStorage and update docs (#82438)
### Description

Since the major changes for `_TypedStorage` and `_UntypedStorage` are now complete, they can be renamed to the public `TypedStorage` and `UntypedStorage`. `TypedStorage._untyped()` is renamed to `TypedStorage.untyped()`. Documentation for storages is improved as well.

### Issue

Fixes #82436

### Testing

N/A

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82438
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
parent: 28304dd494
commit: 14d0296e5c
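For orientation before the diff, here is a minimal sketch of the renamed storage API, assuming a PyTorch build that already contains this change (the tensor and sizes are hypothetical):

```python
import torch

# Old (private) name          ->  New (public) name
# torch._TypedStorage         ->  torch.TypedStorage
# torch._UntypedStorage       ->  torch.UntypedStorage
# TypedStorage._untyped()     ->  TypedStorage.untyped()

t = torch.arange(4, dtype=torch.float32)

typed = t.storage()        # a torch.TypedStorage with dtype torch.float32
untyped = typed.untyped()  # the underlying torch.UntypedStorage of raw bytes

assert isinstance(typed, torch.TypedStorage)
assert isinstance(untyped, torch.UntypedStorage)
assert untyped.nbytes() == 4 * t.element_size()  # 16 bytes for four float32s
```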
@@ -1,28 +1,33 @@
 torch.Storage
 ===================================
 
-A :class:`torch._TypedStorage` is a contiguous, one-dimensional array of
-elements of a particular :class:`torch.dtype`. It can be given any
-:class:`torch.dtype`, and the internal data will be interpretted appropriately.
-
-Every strided :class:`torch.Tensor` contains a :class:`torch._TypedStorage`,
-which stores all of the data that the :class:`torch.Tensor` views.
-
-For backward compatibility, there are also :class:`torch.<type>Storage` classes
-(like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc). These
-classes are not actually instantiated, and calling their constructors creates
-a :class:`torch._TypedStorage` with the appropriate :class:`torch.dtype`.
-:class:`torch.<type>Storage` classes have all of the same class methods that
-:class:`torch._TypedStorage` has.
-
-Also for backward compatibility, :class:`torch.Storage` is an alias for the
-storage class that corresponds with the default data type
-(:func:`torch.get_default_dtype()`). For instance, if the default data type is
-:attr:`torch.float`, :class:`torch.Storage` resolves to
+:class:`torch.Storage` is an alias for the storage class that corresponds with
+the default data type (:func:`torch.get_default_dtype()`). For instance, if the
+default data type is :attr:`torch.float`, :class:`torch.Storage` resolves to
 :class:`torch.FloatStorage`.
 
+The :class:`torch.<type>Storage` and :class:`torch.cuda.<type>Storage` classes,
+like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc., are not
+actually ever instantiated. Calling their constructors creates
+a :class:`torch.TypedStorage` with the appropriate :class:`torch.dtype` and
+:class:`torch.device`. :class:`torch.<type>Storage` classes have all of the
+same class methods that :class:`torch.TypedStorage` has.
+
+A :class:`torch.TypedStorage` is a contiguous, one-dimensional array of
+elements of a particular :class:`torch.dtype`. It can be given any
+:class:`torch.dtype`, and the internal data will be interpretted appropriately.
+:class:`torch.TypedStorage` contains a :class:`torch.UntypedStorage` which
+holds the data as an untyped array of bytes.
+
+Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
+which stores all of the data that the :class:`torch.Tensor` views.
+
-.. autoclass:: torch._TypedStorage
+.. autoclass:: torch.TypedStorage
    :members:
    :undoc-members:
    :inherited-members:
 
+.. autoclass:: torch.UntypedStorage
+   :members:
+   :undoc-members:
+   :inherited-members:
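The documentation hunk above describes how tensors, `TypedStorage`, and `UntypedStorage` relate; a small sketch under the same assumptions (CPU tensor, default float32 dtype):

```python
import torch

# Every strided tensor views a TypedStorage, which wraps an UntypedStorage
# of raw bytes.  The values here are hypothetical; any float tensor works.
x = torch.zeros(6)
y = x.view(2, 3)                       # a second view of the same data

assert isinstance(x.storage(), torch.TypedStorage)
assert x.storage().data_ptr() == y.storage().data_ptr()

x[0] = 42.0
assert y[0, 0].item() == 42.0          # both tensors see the write

raw = x.storage().untyped()            # the untyped byte array underneath
assert raw.nbytes() == 6 * x.element_size()
```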
@@ -1852,7 +1852,7 @@
 "QUInt4x2Storage",
 "QUInt8Storage",
 "Storage",
-"_TypedStorage",
+"TypedStorage",
 "_adaptive_avg_pool2d",
 "_adaptive_avg_pool3d",
 "_add_batch_dim",
@@ -569,8 +569,8 @@ class TestCuda(TestCase):
 self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor))
 self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
 self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
-self.assertTrue(isinstance(q_copy[3], torch.storage._TypedStorage))
-self.assertTrue(isinstance(q_copy[3]._storage, torch._UntypedStorage))
+self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
+self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
 q_copy[1].fill_(10)
 self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
@ -97,7 +97,7 @@ class SerializationMixin(object):
|
||||
self.assertTrue(isinstance(c[1], torch.FloatTensor))
|
||||
self.assertTrue(isinstance(c[2], torch.FloatTensor))
|
||||
self.assertTrue(isinstance(c[3], torch.FloatTensor))
|
||||
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
|
||||
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
|
||||
self.assertEqual(c[4].dtype, torch.float)
|
||||
c[0].fill_(10)
|
||||
self.assertEqual(c[0], c[2], atol=0, rtol=0)
|
||||
@ -370,7 +370,7 @@ class SerializationMixin(object):
|
||||
self.assertTrue(isinstance(c[1], torch.FloatTensor))
|
||||
self.assertTrue(isinstance(c[2], torch.FloatTensor))
|
||||
self.assertTrue(isinstance(c[3], torch.FloatTensor))
|
||||
self.assertTrue(isinstance(c[4], torch.storage._TypedStorage))
|
||||
self.assertTrue(isinstance(c[4], torch.storage.TypedStorage))
|
||||
self.assertEqual(c[4].dtype, torch.float32)
|
||||
c[0].fill_(10)
|
||||
self.assertEqual(c[0], c[2], atol=0, rtol=0)
|
||||
@ -621,8 +621,8 @@ class SerializationMixin(object):
|
||||
a = torch.tensor([], dtype=dtype, device=device)
|
||||
|
||||
for other_dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
|
||||
s = torch._TypedStorage(
|
||||
wrap_storage=a.storage()._untyped(),
|
||||
s = torch.TypedStorage(
|
||||
wrap_storage=a.storage().untyped(),
|
||||
dtype=other_dtype)
|
||||
save_load_check(a, s)
|
||||
save_load_check(a.storage(), s)
|
||||
@ -653,8 +653,8 @@ class SerializationMixin(object):
|
||||
torch.save([a.storage(), a.imag.storage()], f)
|
||||
|
||||
a = torch.randn(10, device=device)
|
||||
s_bytes = torch._TypedStorage(
|
||||
wrap_storage=a.storage()._untyped(),
|
||||
s_bytes = torch.TypedStorage(
|
||||
wrap_storage=a.storage().untyped(),
|
||||
dtype=torch.uint8)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
|
||||
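The test above wraps an existing untyped storage with a different dtype; a minimal sketch of that `wrap_storage` pattern (hypothetical tensor, no copy involved):

```python
import torch

# View the same underlying bytes of a float32 tensor's storage as uint8.
a = torch.tensor([1.0, 2.0], dtype=torch.float32)

s_bytes = torch.TypedStorage(
    wrap_storage=a.storage().untyped(),
    dtype=torch.uint8)

assert s_bytes.size() == a.numel() * a.element_size()   # 8 bytes for two float32s
assert s_bytes.untyped().data_ptr() == a.storage().untyped().data_ptr()
```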
@ -157,7 +157,7 @@ class TestTorchDeviceType(TestCase):
|
||||
for i in range(10):
|
||||
bytes_list = [rand_byte() for _ in range(element_size)]
|
||||
scalar = bytes_to_scalar(bytes_list, dtype, device)
|
||||
self.assertEqual(scalar.storage()._untyped().tolist(), bytes_list)
|
||||
self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)
|
||||
|
||||
@dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
|
||||
torch.bool, torch.float32, torch.complex64, torch.float64,
|
||||
@ -175,7 +175,7 @@ class TestTorchDeviceType(TestCase):
|
||||
v_s[el_num],
|
||||
v[dim0][dim1])
|
||||
|
||||
v_s_byte = v.storage()._untyped()
|
||||
v_s_byte = v.storage().untyped()
|
||||
el_size = v.element_size()
|
||||
|
||||
for el_num in range(v.numel()):
|
||||
@ -238,7 +238,7 @@ class TestTorchDeviceType(TestCase):
|
||||
a_s = a.storage()
|
||||
b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size())
|
||||
self.assertEqual(a, b)
|
||||
c = torch.tensor(a_s._untyped(), device=device, dtype=dtype).reshape(a.size())
|
||||
c = torch.tensor(a_s.untyped(), device=device, dtype=dtype).reshape(a.size())
|
||||
self.assertEqual(a, c)
|
||||
|
||||
for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
|
||||
@ -255,7 +255,7 @@ class TestTorchDeviceType(TestCase):
|
||||
a_s = a.storage()
|
||||
b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size())
|
||||
self.assertEqual(a, b)
|
||||
c = torch.tensor([], device=device, dtype=dtype).set_(a_s._untyped()).reshape(a.size())
|
||||
c = torch.tensor([], device=device, dtype=dtype).set_(a_s.untyped()).reshape(a.size())
|
||||
self.assertEqual(a, c)
|
||||
|
||||
for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
|
||||
@ -267,11 +267,11 @@ class TestTorchDeviceType(TestCase):
|
||||
|
||||
def _check_storage_meta(self, s, s_check):
|
||||
self.assertTrue(
|
||||
isinstance(s, (torch._UntypedStorage, torch._TypedStorage)) and
|
||||
isinstance(s, (torch.UntypedStorage, torch.TypedStorage)) and
|
||||
isinstance(s_check, type(s)),
|
||||
(
|
||||
's and s_check must both be one of _UntypedStorage or '
|
||||
'_TypedStorage, but got'
|
||||
's and s_check must both be one of UntypedStorage or '
|
||||
'TypedStorage, but got'
|
||||
f' {type(s).__name__} and {type(s_check).__name__}'))
|
||||
|
||||
self.assertEqual(s.device.type, 'meta')
|
||||
@ -282,9 +282,9 @@ class TestTorchDeviceType(TestCase):
|
||||
with self.assertRaisesRegex(NotImplementedError, r'Not available'):
|
||||
s[0]
|
||||
|
||||
if isinstance(s, torch._TypedStorage):
|
||||
if isinstance(s, torch.TypedStorage):
|
||||
self.assertEqual(s.dtype, s_check.dtype)
|
||||
self._check_storage_meta(s._untyped(), s_check._untyped())
|
||||
self._check_storage_meta(s.untyped(), s_check.untyped())
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@ -296,8 +296,8 @@ class TestTorchDeviceType(TestCase):
|
||||
[[1, 2, 3, 4, 5, 6]],
|
||||
]
|
||||
for args in args_list:
|
||||
s_check = torch._TypedStorage(*args, dtype=dtype, device=device)
|
||||
s = torch._TypedStorage(*args, dtype=dtype, device='meta')
|
||||
s_check = torch.TypedStorage(*args, dtype=dtype, device=device)
|
||||
s = torch.TypedStorage(*args, dtype=dtype, device='meta')
|
||||
self._check_storage_meta(s, s_check)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@ -309,8 +309,8 @@ class TestTorchDeviceType(TestCase):
|
||||
[[1, 2, 3, 4, 5, 6]],
|
||||
]
|
||||
for args in args_list:
|
||||
s_check = torch._UntypedStorage(*args, device=device)
|
||||
s = torch._UntypedStorage(*args, device='meta')
|
||||
s_check = torch.UntypedStorage(*args, device=device)
|
||||
s = torch.UntypedStorage(*args, device='meta')
|
||||
self._check_storage_meta(s, s_check)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
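The two meta-storage tests above exercise the direct constructor forms of `TypedStorage` and `UntypedStorage`; a hedged sketch of those forms on CPU (all values hypothetical):

```python
import torch

# TypedStorage: empty, sized, and from a Python sequence.
s_empty = torch.TypedStorage(dtype=torch.float32, device='cpu')        # size 0
s_sized = torch.TypedStorage(8, dtype=torch.int64, device='cpu')       # 8 elements
s_data  = torch.TypedStorage([1, 2, 3, 4, 5, 6], dtype=torch.int32, device='cpu')

# UntypedStorage: a raw byte buffer, sized or from a sequence of byte values.
u_sized = torch.UntypedStorage(16, device='cpu')                       # 16 bytes
u_data  = torch.UntypedStorage([1, 2, 3], device='cpu')                # 3 bytes

assert s_data.size() == 6 and s_data.dtype == torch.int32
assert u_sized.nbytes() == 16
```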
@ -326,7 +326,7 @@ class TestTorchDeviceType(TestCase):
|
||||
@onlyCPU
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
def test_storage_meta_errors(self, device, dtype):
|
||||
s0 = torch._TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
|
||||
s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
|
||||
s0.cpu()
|
||||
@ -361,7 +361,7 @@ class TestTorchDeviceType(TestCase):
|
||||
s0._write_file(f, True, True, s0.element_size())
|
||||
|
||||
for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
|
||||
s1 = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
|
||||
s1 = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'):
|
||||
s1.copy_(s0)
|
||||
@ -6444,7 +6444,7 @@ class TestTorch(TestCase):
|
||||
torch.storage._LegacyStorage()
|
||||
|
||||
for storage_class in torch._storage_classes:
|
||||
if storage_class in [torch._UntypedStorage, torch._TypedStorage]:
|
||||
if storage_class in [torch.UntypedStorage, torch.TypedStorage]:
|
||||
continue
|
||||
|
||||
device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
|
||||
@ -6475,9 +6475,9 @@ class TestTorch(TestCase):
|
||||
s = storage_class()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
|
||||
storage_class(0, wrap_storage=s._untyped())
|
||||
storage_class(0, wrap_storage=s.untyped())
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r"must be _UntypedStorage"):
|
||||
with self.assertRaisesRegex(TypeError, r"must be UntypedStorage"):
|
||||
storage_class(wrap_storage=s)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
@ -6493,40 +6493,40 @@ class TestTorch(TestCase):
|
||||
s_other_device = s.cuda()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
|
||||
storage_class(wrap_storage=s_other_device._untyped())
|
||||
storage_class(wrap_storage=s_other_device.untyped())
|
||||
|
||||
# _TypedStorage constructor errors
|
||||
# TypedStorage constructor errors
|
||||
with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
|
||||
torch._TypedStorage(0, wrap_storage=s._untyped(), dtype=dtype)
|
||||
torch.TypedStorage(0, wrap_storage=s.untyped(), dtype=dtype)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
|
||||
torch._TypedStorage(wrap_storage=s._untyped())
|
||||
torch.TypedStorage(wrap_storage=s.untyped())
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
|
||||
torch._TypedStorage(wrap_storage=s._untyped(), dtype=0)
|
||||
torch.TypedStorage(wrap_storage=s.untyped(), dtype=0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
|
||||
torch._TypedStorage(wrap_storage=s._untyped(), dtype=dtype, device=device)
|
||||
torch.TypedStorage(wrap_storage=s.untyped(), dtype=dtype, device=device)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be _UntypedStorage"):
|
||||
torch._TypedStorage(wrap_storage=s, dtype=dtype)
|
||||
with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be UntypedStorage"):
|
||||
torch.TypedStorage(wrap_storage=s, dtype=dtype)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
|
||||
torch._TypedStorage(dtype=dtype, device='xla')
|
||||
torch.TypedStorage(dtype=dtype, device='xla')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if storage_class in quantized_storages:
|
||||
with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
|
||||
torch._TypedStorage(dtype=dtype, device='cuda')
|
||||
torch.TypedStorage(dtype=dtype, device='cuda')
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
|
||||
torch._TypedStorage(torch.tensor([]), dtype=dtype, device=device)
|
||||
torch.TypedStorage(torch.tensor([]), dtype=dtype, device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
|
||||
torch._TypedStorage(0, 0, dtype=dtype, device=device)
|
||||
torch.TypedStorage(0, 0, dtype=dtype, device=device)
|
||||
|
||||
if isinstance(s, torch._TypedStorage):
|
||||
s_other = torch._TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
|
||||
if isinstance(s, torch.TypedStorage):
|
||||
s_other = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r'cannot set item'):
|
||||
s.fill_(s_other)
|
||||
|
||||
@ -1146,7 +1146,7 @@ static PyObject* THPVariable_set_(
|
||||
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
|
||||
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
|
||||
"Expected a Storage of type ", self.dtype(),
|
||||
" or an _UntypedStorage, but got type ", storage_scalar_type,
|
||||
" or an UntypedStorage, but got type ", storage_scalar_type,
|
||||
" for argument 1 'storage'");
|
||||
auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor {
|
||||
pybind11::gil_scoped_release no_gil;
|
||||
@ -1162,7 +1162,7 @@ static PyObject* THPVariable_set_(
|
||||
at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
|
||||
TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
|
||||
"Expected a Storage of type ", self.dtype(),
|
||||
" or an _UntypedStorage, but got type ", storage_scalar_type,
|
||||
" or an UntypedStorage, but got type ", storage_scalar_type,
|
||||
" for argument 1 'storage'");
|
||||
auto dispatch_set_ = [](const Tensor& self,
|
||||
Storage source,
|
||||
|
||||
@ -696,8 +696,8 @@ def gen_pyi(
|
||||
"def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
|
||||
],
|
||||
"set_": [
|
||||
"def set_(self, storage: Union[Storage, _TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
|
||||
"def set_(self, storage: Union[Storage, _TypedStorage]) -> Tensor: ...",
|
||||
"def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
|
||||
"def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...",
|
||||
],
|
||||
"split": [
|
||||
"def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
|
||||
|
||||
@ -13,7 +13,7 @@ from typing_extensions import Literal
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt
|
||||
from torch.storage import _TypedStorage
|
||||
from torch.storage import TypedStorage
|
||||
|
||||
import builtins
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ __all__ = [
|
||||
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
|
||||
'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
|
||||
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
|
||||
'_TypedStorage',
|
||||
'TypedStorage', 'UntypedStorage',
|
||||
'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
|
||||
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
|
||||
'lobpcg', 'use_deterministic_algorithms',
|
||||
@ -656,10 +656,10 @@ __all__.extend(['e', 'pi', 'nan', 'inf'])
|
||||
################################################################################
|
||||
|
||||
from ._tensor import Tensor
|
||||
from .storage import _StorageBase, _TypedStorage, _LegacyStorage, _UntypedStorage
|
||||
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
|
||||
|
||||
# NOTE: New <type>Storage classes should never be added. When adding a new
|
||||
# dtype, use torch.storage._TypedStorage directly.
|
||||
# dtype, use torch.storage.TypedStorage directly.
|
||||
|
||||
class ByteStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
@ -747,11 +747,11 @@ class QUInt2x4Storage(_LegacyStorage):
|
||||
return torch.quint2x4
|
||||
|
||||
_storage_classes = {
|
||||
_UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
|
||||
UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
|
||||
ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
|
||||
QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
|
||||
ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
|
||||
_TypedStorage
|
||||
TypedStorage
|
||||
}
|
||||
|
||||
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
|
||||
|
||||
@ -19,8 +19,8 @@ def _save_storages(importer, obj):
|
||||
importers = sys_importer
|
||||
|
||||
def persistent_id(obj):
|
||||
if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
|
||||
if isinstance(obj, torch.storage._TypedStorage):
|
||||
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# remove this case
|
||||
storage = obj._storage
|
||||
@ -66,11 +66,11 @@ def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dt
|
||||
|
||||
if typename == "storage":
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with _TypedStorage
|
||||
# stop wrapping with TypedStorage
|
||||
storage = serialized_storages[data[0]]
|
||||
dtype = serialized_dtypes[data[0]]
|
||||
return torch.storage._TypedStorage(
|
||||
wrap_storage=storage._untyped(), dtype=dtype
|
||||
return torch.storage.TypedStorage(
|
||||
wrap_storage=storage.untyped(), dtype=dtype
|
||||
)
|
||||
|
||||
if typename == "reduce_deploy":
|
||||
|
||||
@ -11,7 +11,7 @@ import torch
|
||||
|
||||
import torch._prims_common as utils
|
||||
import torch.library
|
||||
from torch import _TypedStorage, Tensor
|
||||
from torch import Tensor, TypedStorage
|
||||
from torch._C import _get_default_device
|
||||
from torch._prims_common import (
|
||||
check,
|
||||
|
||||
@ -1303,7 +1303,7 @@ def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...
|
||||
|
||||
|
||||
def check_in_bounds_for_storage(
|
||||
a: torch._TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
|
||||
a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
|
||||
):
|
||||
"""
|
||||
Determines if the given shape, strides, and offset are valid for the given storage.
|
||||
|
||||
@ -158,10 +158,10 @@ class Tensor(torch._C._TensorBase):
|
||||
f"Unsupported qscheme {self.qscheme()} in deepcopy"
|
||||
)
|
||||
# TODO: Once we decide to break serialization FC, no longer
|
||||
# need to wrap with _TypedStorage
|
||||
# need to wrap with TypedStorage
|
||||
new_tensor = torch._utils._rebuild_qtensor(
|
||||
torch.storage._TypedStorage(
|
||||
wrap_storage=new_storage._untyped(), dtype=self.dtype
|
||||
torch.storage.TypedStorage(
|
||||
wrap_storage=new_storage.untyped(), dtype=self.dtype
|
||||
),
|
||||
self.storage_offset(),
|
||||
self.size(),
|
||||
@ -255,7 +255,7 @@ class Tensor(torch._C._TensorBase):
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.storage, (self,), self)
|
||||
|
||||
return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
|
||||
return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
|
||||
|
||||
def _reduce_ex_internal(self, proto):
|
||||
check_serializing_named_tensor(self)
|
||||
@ -324,10 +324,10 @@ class Tensor(torch._C._TensorBase):
|
||||
f"Serialization is not supported for tensors of type {self.qscheme()}"
|
||||
)
|
||||
# TODO: Once we decide to break serialization FC, no longer
|
||||
# need to wrap with _TypedStorage
|
||||
# need to wrap with TypedStorage
|
||||
args_qtensor = (
|
||||
torch.storage._TypedStorage(
|
||||
wrap_storage=self.storage()._untyped(), dtype=self.dtype
|
||||
torch.storage.TypedStorage(
|
||||
wrap_storage=self.storage().untyped(), dtype=self.dtype
|
||||
),
|
||||
self.storage_offset(),
|
||||
tuple(self.size()),
|
||||
@ -382,10 +382,10 @@ class Tensor(torch._C._TensorBase):
|
||||
return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
|
||||
else:
|
||||
# TODO: Once we decide to break serialization FC, no longer
|
||||
# need to wrap with _TypedStorage
|
||||
# need to wrap with TypedStorage
|
||||
args = (
|
||||
torch.storage._TypedStorage(
|
||||
wrap_storage=self.storage()._untyped(), dtype=self.dtype
|
||||
torch.storage.TypedStorage(
|
||||
wrap_storage=self.storage().untyped(), dtype=self.dtype
|
||||
),
|
||||
self.storage_offset(),
|
||||
tuple(self.size()),
|
||||
|
||||
@ -77,9 +77,11 @@ def _cuda(self, device=None, non_blocking=False, **kwargs):
|
||||
values = torch.Tensor._values(self).cuda(device, non_blocking)
|
||||
return new_type(indices, values, self.size())
|
||||
else:
|
||||
return torch._UntypedStorage(
|
||||
untyped_storage = torch.UntypedStorage(
|
||||
self.size(), device=torch.device("cuda")
|
||||
).copy_(self, non_blocking)
|
||||
)
|
||||
untyped_storage.copy_(self, non_blocking)
|
||||
return untyped_storage
|
||||
|
||||
|
||||
def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
|
||||
@ -138,11 +140,11 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
|
||||
|
||||
|
||||
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
|
||||
# be a _TypedStorage
|
||||
# be a TypedStorage
|
||||
def _rebuild_tensor(storage, storage_offset, size, stride):
|
||||
# first construct a tensor with the correct dtype/device
|
||||
t = torch.tensor([], dtype=storage.dtype, device=storage._untyped().device)
|
||||
return t.set_(storage._untyped(), storage_offset, size, stride)
|
||||
t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
|
||||
return t.set_(storage.untyped(), storage_offset, size, stride)
|
||||
|
||||
|
||||
def _rebuild_tensor_v2(
|
||||
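`_rebuild_tensor` above re-attaches a tensor to its storage with `Tensor.set_`; a hedged sketch of the same pattern outside the serialization path (offset, size, and stride are hypothetical):

```python
import torch

# Rebuild a 2x3 view over an existing typed storage, mirroring _rebuild_tensor.
storage = torch.arange(6, dtype=torch.float32).storage()   # TypedStorage of 6 floats

t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
t.set_(storage.untyped(), 0, (2, 3), (3, 1))                # source, offset, size, stride

assert t.shape == (2, 3)
assert t[1, 2].item() == 5.0
```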
@ -251,7 +253,7 @@ def _rebuild_wrapper_subclass(
|
||||
|
||||
|
||||
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
|
||||
# be a _TypedStorage
|
||||
# be a TypedStorage
|
||||
def _rebuild_qtensor(
|
||||
storage,
|
||||
storage_offset,
|
||||
|
||||
@ -94,10 +94,10 @@ PyTypeObject* loadTypedStorageTypeObject() {
|
||||
TORCH_INTERNAL_ASSERT(storage_module && PyModule_Check(storage_module));
|
||||
|
||||
PyObject* typed_storage_obj =
|
||||
PyObject_GetAttrString(storage_module, "_TypedStorage");
|
||||
PyObject_GetAttrString(storage_module, "TypedStorage");
|
||||
TORCH_INTERNAL_ASSERT(typed_storage_obj && PyType_Check(typed_storage_obj));
|
||||
return reinterpret_cast<PyTypeObject*>(
|
||||
PyObject_GetAttrString(storage_module, "_TypedStorage"));
|
||||
PyObject_GetAttrString(storage_module, "TypedStorage"));
|
||||
}
|
||||
|
||||
PyTypeObject* getTypedStorageTypeObject() {
|
||||
@ -125,7 +125,7 @@ at::Storage createStorageGetType(
|
||||
if (is_typed_storage) {
|
||||
// NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and
|
||||
// `_storage`, so we must decrement them. The refcounts will still stay
|
||||
// nonzero since the `_TypedStorage` maintains a reference.
|
||||
// nonzero since the `TypedStorage` maintains a reference.
|
||||
PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
|
||||
TORCH_INTERNAL_ASSERT(dtype_obj);
|
||||
Py_DECREF(dtype_obj);
|
||||
|
||||
@ -387,7 +387,7 @@ bool THPStorage_init(PyObject* module) {
|
||||
}
|
||||
|
||||
void THPStorage_postInit(PyObject* module) {
|
||||
THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage");
|
||||
THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
|
||||
if (!THPStorageClass)
|
||||
throw python_error();
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
#include <torch/csrc/Types.h>
|
||||
|
||||
#define THPStorageStr "torch._UntypedStorage"
|
||||
#define THPStorageStr "torch.UntypedStorage"
|
||||
#define THPStorageBaseStr "StorageBase"
|
||||
|
||||
struct THPStorage {
|
||||
|
||||
@ -138,7 +138,7 @@ static PyObject* THPStorage_resize_(PyObject* _self, PyObject* number_arg) {
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"_UntypedStorage.resize_: got unexpected device type ",
|
||||
"UntypedStorage.resize_: got unexpected device type ",
|
||||
device_type);
|
||||
}
|
||||
Py_INCREF(self);
|
||||
|
||||
@ -361,7 +361,7 @@ Tensor internal_new_from_data(
|
||||
!is_typed_storage || storage_scalar_type == scalar_type,
|
||||
"Expected a Storage of type ",
|
||||
scalar_type,
|
||||
" or an _UntypedStorage, but got ",
|
||||
" or an UntypedStorage, but got ",
|
||||
storage_scalar_type);
|
||||
tensor = at::empty(
|
||||
sizes,
|
||||
@ -642,7 +642,7 @@ Tensor legacy_tensor_generic_ctor_new(
|
||||
storage_scalar_type == scalar_type,
|
||||
"Expected a Storage of type ",
|
||||
scalar_type,
|
||||
" or an _UntypedStorage, but got type ",
|
||||
" or an UntypedStorage, but got type ",
|
||||
storage_scalar_type,
|
||||
" for argument 1 'storage'");
|
||||
}
|
||||
|
||||
@ -133,7 +133,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
|
||||
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
|
||||
|
||||
t = torch._utils._rebuild_tensor(
|
||||
torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
|
||||
torch.storage.TypedStorage(wrap_storage=storage.untyped(), dtype=dtype),
|
||||
tensor_offset, tensor_size, tensor_stride)
|
||||
|
||||
if tensor_cls == torch.nn.parameter.Parameter:
|
||||
@ -298,7 +298,7 @@ def storage_from_cache(cls, key):
|
||||
storage_ref = shared_cache.get(key)
|
||||
if storage_ref is None:
|
||||
return None
|
||||
return torch._UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
|
||||
return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
|
||||
|
||||
|
||||
def rebuild_storage_fd(cls, df, size):
|
||||
@ -315,15 +315,15 @@ def rebuild_storage_fd(cls, df, size):
|
||||
|
||||
|
||||
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
|
||||
storage: Union[torch._TypedStorage, torch._UntypedStorage] = storage_from_cache(cls, handle)
|
||||
storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(cls, handle)
|
||||
if storage is not None:
|
||||
return storage._shared_decref()
|
||||
if dtype is None:
|
||||
storage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, size)
|
||||
storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
|
||||
else:
|
||||
byte_size = size * torch._utils._element_size(dtype)
|
||||
untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
|
||||
storage = torch._TypedStorage(
|
||||
untyped_storage: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
|
||||
storage = torch.TypedStorage(
|
||||
wrap_storage=untyped_storage,
|
||||
dtype=dtype)
|
||||
shared_cache[handle] = StorageWeakRef(storage)
|
||||
@ -334,16 +334,16 @@ def rebuild_storage_empty(cls):
|
||||
return cls()
|
||||
|
||||
def rebuild_typed_storage(storage, dtype):
|
||||
return torch.storage._TypedStorage(wrap_storage=storage, dtype=dtype)
|
||||
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype)
|
||||
|
||||
# Use for torch.storage._TypedStorage
|
||||
# Use for torch.storage.TypedStorage
|
||||
def reduce_typed_storage(storage):
|
||||
return (rebuild_typed_storage, (storage._storage, storage.dtype))
|
||||
|
||||
def rebuild_typed_storage_child(storage, storage_type):
|
||||
return storage_type(wrap_storage=storage)
|
||||
|
||||
# Use for child classes of torch.storage._TypedStorage, like torch.FloatStorage
|
||||
# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
|
||||
def reduce_typed_storage_child(storage):
|
||||
return (rebuild_typed_storage_child, (storage._storage, type(storage)))
|
||||
|
||||
@ -355,7 +355,7 @@ def reduce_storage(storage):
|
||||
metadata = storage._share_filename_cpu_()
|
||||
cache_key = metadata[1]
|
||||
rebuild = rebuild_storage_filename
|
||||
if isinstance(storage, torch._TypedStorage):
|
||||
if isinstance(storage, torch.TypedStorage):
|
||||
metadata += (storage.dtype,)
|
||||
storage._shared_incref()
|
||||
elif storage.size() == 0:
|
||||
@ -377,12 +377,12 @@ def init_reductions():
|
||||
ForkingPickler.register(torch.cuda.Event, reduce_event)
|
||||
|
||||
for t in torch._storage_classes:
|
||||
if t.__name__ == '_UntypedStorage':
|
||||
if t.__name__ == 'UntypedStorage':
|
||||
ForkingPickler.register(t, reduce_storage)
|
||||
else:
|
||||
ForkingPickler.register(t, reduce_typed_storage_child)
|
||||
|
||||
ForkingPickler.register(torch.storage._TypedStorage, reduce_typed_storage)
|
||||
ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)
|
||||
|
||||
for t in torch._tensor_classes:
|
||||
ForkingPickler.register(t, reduce_tensor)
|
||||
|
||||
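The `ForkingPickler` registrations above are what let storages cross process boundaries via `torch.multiprocessing`; a sketch of the usual entry point (the worker function is a hypothetical example):

```python
import torch
import torch.multiprocessing as mp

def worker(shared_t):
    shared_t.add_(1)          # writes are visible to the parent process

if __name__ == "__main__":
    t = torch.zeros(4)
    t.share_memory_()         # moves the underlying UntypedStorage to shared memory
    assert t.is_shared()

    p = mp.Process(target=worker, args=(t,))
    p.start()
    p.join()
    assert t[0].item() == 1.0
```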
@ -35,7 +35,7 @@ class DirectoryReader(object):
|
||||
def get_storage_from_record(self, name, numel, dtype):
|
||||
filename = f"{self.directory}/{name}"
|
||||
nbytes = torch._utils._element_size(dtype) * numel
|
||||
storage = cast(Storage, torch._UntypedStorage)
|
||||
storage = cast(Storage, torch.UntypedStorage)
|
||||
return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
|
||||
|
||||
def has_record(self, path):
|
||||
|
||||
@ -883,8 +883,8 @@ class PackageExporter:
|
||||
)
|
||||
|
||||
def _persistent_id(self, obj):
|
||||
if torch.is_storage(obj) or isinstance(obj, torch.storage._TypedStorage):
|
||||
if isinstance(obj, torch.storage._TypedStorage):
|
||||
if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# remove this case
|
||||
untyped_storage = obj._storage
|
||||
@ -892,7 +892,7 @@ class PackageExporter:
|
||||
storage_type = getattr(torch, storage_type_str)
|
||||
storage_numel = obj.size()
|
||||
|
||||
elif isinstance(obj, torch._UntypedStorage):
|
||||
elif isinstance(obj, torch.UntypedStorage):
|
||||
untyped_storage = obj
|
||||
storage_type = normalize_storage_type(type(storage))
|
||||
storage_numel = storage.nbytes()
|
||||
|
||||
@ -234,9 +234,9 @@ class PackageImporter(Importer):
|
||||
)
|
||||
storage = loaded_storages[key]
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with _TypedStorage
|
||||
return torch.storage._TypedStorage(
|
||||
wrap_storage=storage._untyped(), dtype=dtype
|
||||
# stop wrapping with TypedStorage
|
||||
return torch.storage.TypedStorage(
|
||||
wrap_storage=storage.untyped(), dtype=dtype
|
||||
)
|
||||
elif typename == "reduce_package":
|
||||
# to fix BC breaking change, objects on this load path
|
||||
|
||||
@ -157,7 +157,7 @@ def _cuda_deserialize(obj, location):
|
||||
device = validate_cuda_device(location)
|
||||
if getattr(obj, "_torch_load_uninitialized", False):
|
||||
with torch.cuda.device(device):
|
||||
return torch._UntypedStorage(obj.nbytes(), device=torch.device(location))
|
||||
return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
|
||||
else:
|
||||
return obj.cuda(device)
|
||||
|
||||
@ -171,7 +171,7 @@ register_package(20, _cuda_tag, _cuda_deserialize)
|
||||
register_package(21, _mps_tag, _mps_deserialize)
|
||||
|
||||
|
||||
def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
|
||||
def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
|
||||
for _, tagger, _ in _package_registry:
|
||||
location = tagger(storage)
|
||||
if location:
|
||||
@ -423,10 +423,10 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
"for correctness upon loading.")
|
||||
return ('module', obj, source_file, source)
|
||||
|
||||
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
|
||||
storage: torch._UntypedStorage
|
||||
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
|
||||
storage: torch.UntypedStorage
|
||||
|
||||
if isinstance(obj, torch.storage._TypedStorage):
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, this case
|
||||
# can be deleted
|
||||
storage = obj._storage
|
||||
@ -436,7 +436,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
dtype = obj.dtype
|
||||
storage_numel = obj.size()
|
||||
|
||||
elif isinstance(obj, torch._UntypedStorage):
|
||||
elif isinstance(obj, torch.UntypedStorage):
|
||||
storage = obj
|
||||
storage_dtype = torch.uint8
|
||||
storage_type = normalize_storage_type(type(obj))
|
||||
@ -476,8 +476,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
# effectively saving nbytes in this case. We'll be able to load it
|
||||
# and the tensor back up with no problems in _this_ and future
|
||||
# versions of pytorch, but in older versions, here's the problem:
|
||||
# the storage will be loaded up as a _UntypedStorage, and then the
|
||||
# FloatTensor will loaded and the _UntypedStorage will be assigned to
|
||||
# the storage will be loaded up as a UntypedStorage, and then the
|
||||
# FloatTensor will loaded and the UntypedStorage will be assigned to
|
||||
# it. Since the storage dtype does not match the tensor dtype, this
|
||||
# will cause an error. If we reverse the list, like `[tensor,
|
||||
# storage]`, then we will save the `tensor.storage()` as a faked
|
||||
@ -485,7 +485,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
# dtype-specific numel count that old versions expect. `tensor`
|
||||
# will be able to load up properly in old versions, pointing to
|
||||
# a FloatStorage. However, `storage` is still being translated to
|
||||
# a _UntypedStorage, and it will try to resolve to the same
|
||||
# a UntypedStorage, and it will try to resolve to the same
|
||||
# FloatStorage that `tensor` contains. This will also cause an
|
||||
# error. It doesn't seem like there's any way around this.
|
||||
# Probably, we just cannot maintain FC for the legacy format if the
|
||||
@ -552,9 +552,9 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
|
||||
# see
|
||||
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
|
||||
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
|
||||
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
|
||||
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
|
||||
|
||||
if isinstance(obj, torch.storage._TypedStorage):
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, this case
|
||||
# can be deleted
|
||||
storage = obj._storage
|
||||
@ -817,11 +817,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
args = pickle_module.load(f, **pickle_load_args)
|
||||
key, location, storage_type = args
|
||||
dtype = storage_type.dtype
|
||||
obj = cast(Storage, torch._UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
|
||||
obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
|
||||
obj = restore_location(obj, location)
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with _TypedStorage
|
||||
deserialized_objects[key] = torch.storage._TypedStorage(
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[key] = torch.storage.TypedStorage(
|
||||
wrap_storage=obj,
|
||||
dtype=dtype)
|
||||
|
||||
@ -831,8 +831,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
element_size = torch._utils._element_size(root.dtype)
|
||||
offset_bytes = offset * element_size
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with _TypedStorage
|
||||
deserialized_objects[target_cdata] = torch.storage._TypedStorage(
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
|
||||
wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
|
||||
dtype=root.dtype)
|
||||
|
||||
@ -879,11 +879,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
nbytes = numel * torch._utils._element_size(dtype)
|
||||
|
||||
if root_key not in deserialized_objects:
|
||||
obj = cast(Storage, torch._UntypedStorage(nbytes))
|
||||
obj = cast(Storage, torch.UntypedStorage(nbytes))
|
||||
obj._torch_load_uninitialized = True
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with _TypedStorage
|
||||
deserialized_objects[root_key] = torch.storage._TypedStorage(
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[root_key] = torch.storage.TypedStorage(
|
||||
wrap_storage=restore_location(obj, location),
|
||||
dtype=dtype)
|
||||
|
||||
@ -894,8 +894,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
view_size_bytes = view_size * torch._utils._element_size(dtype)
|
||||
if view_key not in deserialized_objects:
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with _TypedStorage
|
||||
deserialized_objects[view_key] = torch.storage._TypedStorage(
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[view_key] = torch.storage.TypedStorage(
|
||||
wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
|
||||
dtype=dtype)
|
||||
res = deserialized_objects[view_key]
|
||||
@ -1005,10 +1005,10 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
|
||||
def load_tensor(dtype, numel, key, location):
|
||||
name = f'data/{key}'
|
||||
|
||||
storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
|
||||
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage().untyped()
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with _TypedStorage
|
||||
loaded_storages[key] = torch.storage._TypedStorage(
|
||||
# stop wrapping with TypedStorage
|
||||
loaded_storages[key] = torch.storage.TypedStorage(
|
||||
wrap_storage=restore_location(storage, location),
|
||||
dtype=dtype)
|
||||
|
||||
@ -1020,7 +1020,7 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
|
||||
assert typename == 'storage', \
|
||||
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
|
||||
storage_type, key, location, numel = data
|
||||
if storage_type is torch._UntypedStorage:
|
||||
if storage_type is torch.UntypedStorage:
|
||||
dtype = torch.uint8
|
||||
else:
|
||||
dtype = storage_type.dtype
|
||||
|
||||
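A minimal sketch of the round trip the loader code above implements, assuming an in-memory buffer; a saved `TypedStorage` comes back as a `TypedStorage` with the same dtype:

```python
import io
import torch

# Save and reload a storage object directly, as the serialization tests do.
s = torch.arange(4, dtype=torch.int16).storage()

buf = io.BytesIO()
torch.save(s, buf)
buf.seek(0)
s2 = torch.load(buf)

assert isinstance(s2, torch.TypedStorage)
assert s2.dtype == torch.int16
assert s2.tolist() == [0, 1, 2, 3]
```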
torch/storage.py (113 changed lines)
@ -13,7 +13,7 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
|
||||
T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')
|
||||
class _StorageBase(object):
|
||||
_cdata: Any
|
||||
is_sparse: bool = False
|
||||
@ -117,14 +117,14 @@ class _StorageBase(object):
|
||||
def cpu(self):
|
||||
"""Returns a CPU copy of this storage if it's not already on the CPU"""
|
||||
if self.device.type != 'cpu':
|
||||
return torch._UntypedStorage(self.size()).copy_(self, False)
|
||||
return torch.UntypedStorage(self.size()).copy_(self, False)
|
||||
else:
|
||||
return self
|
||||
|
||||
def mps(self):
|
||||
"""Returns a CPU copy of this storage if it's not already on the CPU"""
|
||||
if self.device.type != 'mps':
|
||||
return torch._UntypedStorage(self.size(), device="mps").copy_(self, False)
|
||||
return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
|
||||
else:
|
||||
return self
|
||||
|
||||
@@ -222,11 +222,11 @@
 else:
 return cls._new_using_fd_cpu(size)
 
-def _untyped(self):
+def untyped(self):
 return self
 
 
-class _UntypedStorage(torch._C.StorageBase, _StorageBase):
+class UntypedStorage(torch._C.StorageBase, _StorageBase):
 def __getitem__(self, *args, **kwargs):
 if self.device.type == 'meta':
 raise NotImplementedError("Not available for 'meta' device type")
@ -248,9 +248,9 @@ _StorageBase.cuda = _cuda # type: ignore[assignment]
|
||||
def _dtype_to_storage_type_map():
|
||||
# NOTE: We should no longer add dtypes to this map. This map
|
||||
# is only used for BC/FC with older PyTorch versions. Going forward,
|
||||
# new dtypes of _TypedStorage should not translate to a legacy
|
||||
# <type>Storage class. Instead, new dtypes of _TypedStorage should
|
||||
# be serialized as an _UntypedStorage paired with a torch.dtype
|
||||
# new dtypes of TypedStorage should not translate to a legacy
|
||||
# <type>Storage class. Instead, new dtypes of TypedStorage should
|
||||
# be serialized as an UntypedStorage paired with a torch.dtype
|
||||
return {
|
||||
torch.double: 'DoubleStorage',
|
||||
torch.float: 'FloatStorage',
|
||||
@ -297,7 +297,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
return tmp_tensor.storage()._untyped()
|
||||
return tmp_tensor.storage().untyped()
|
||||
|
||||
def _isint(x):
|
||||
if HAS_NUMPY:
|
||||
@@ -305,7 +305,7 @@ def _isint(x):
 else:
 return isinstance(x, int)
 
-class _TypedStorage:
+class TypedStorage:
 is_sparse = False
 
 dtype: torch.dtype
|
||||
if cls == torch.storage._LegacyStorage:
|
||||
raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
|
||||
|
||||
if cls == _TypedStorage:
|
||||
if cls == TypedStorage:
|
||||
return super().__new__(cls)
|
||||
|
||||
else:
|
||||
@ -328,7 +328,7 @@ class _TypedStorage:
|
||||
' * no arguments\n'
|
||||
' * (int size)\n'
|
||||
' * (Sequence data)\n'
|
||||
' * (*, _UntypedStorage wrap_storage)')
|
||||
' * (*, UntypedStorage wrap_storage)')
|
||||
|
||||
if device is not None:
|
||||
raise RuntimeError(
|
||||
@ -351,7 +351,7 @@ class _TypedStorage:
|
||||
arg_error_msg +
|
||||
f"\nArgument type not recognized: {type(args[0])}")
|
||||
|
||||
return _TypedStorage(
|
||||
return TypedStorage(
|
||||
*args,
|
||||
dtype=cls.dtype,
|
||||
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu')
|
||||
@ -363,10 +363,10 @@ class _TypedStorage:
|
||||
"\nNo positional arguments should be given when using "
|
||||
"'wrap_storage'")
|
||||
|
||||
if not isinstance(wrap_storage, torch._UntypedStorage):
|
||||
if not isinstance(wrap_storage, torch.UntypedStorage):
|
||||
raise TypeError(
|
||||
arg_error_msg +
|
||||
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
|
||||
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
|
||||
|
||||
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
|
||||
|
||||
@ -376,19 +376,19 @@ class _TypedStorage:
|
||||
f"\nDevice of 'wrap_storage' must be {cls_device}"
|
||||
f", but got {wrap_storage.device.type}")
|
||||
|
||||
return _TypedStorage(
|
||||
return TypedStorage(
|
||||
*args,
|
||||
wrap_storage=wrap_storage,
|
||||
dtype=cls.dtype)
|
||||
|
||||
def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
|
||||
arg_error_msg = (
|
||||
'_TypedStorage.__init__ received an invalid combination '
|
||||
'TypedStorage.__init__ received an invalid combination '
|
||||
'of arguments. Expected one of:\n'
|
||||
' * (*, torch.device device, torch.dtype dtype)\n'
|
||||
' * (int size, *, torch.device device, torch.dtype dtype)\n'
|
||||
' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
|
||||
' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)')
|
||||
' * (*, UntypedStorage wrap_storage, torch.dtype dtype)')
|
||||
|
||||
if wrap_storage is not None:
|
||||
if len(args) != 0:
|
||||
@ -414,10 +414,10 @@ class _TypedStorage:
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
if not isinstance(wrap_storage, torch._UntypedStorage):
|
||||
if not isinstance(wrap_storage, torch.UntypedStorage):
|
||||
raise TypeError(
|
||||
arg_error_msg +
|
||||
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
|
||||
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
|
||||
|
||||
self._storage = wrap_storage
|
||||
|
||||
@ -430,11 +430,11 @@ class _TypedStorage:
|
||||
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
|
||||
|
||||
if len(args) == 0:
|
||||
self._storage = torch._UntypedStorage(device=device)
|
||||
self._storage = torch.UntypedStorage(device=device)
|
||||
|
||||
elif len(args) == 1:
|
||||
if _isint(args[0]):
|
||||
self._storage = torch._UntypedStorage(int(args[0]) * self.element_size(), device=device)
|
||||
self._storage = torch.UntypedStorage(int(args[0]) * self.element_size(), device=device)
|
||||
elif isinstance(args[0], collections.abc.Sequence):
|
||||
self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
|
||||
else:
|
||||
@@ -452,14 +452,15 @@
 def is_cuda(self):
 return self.device.type == 'cuda'
 
-def _untyped(self):
+def untyped(self):
+"""Returns the internal :class:`torch.UntypedStorage`"""
 return self._storage
 
 def _new_wrapped_storage(self, untyped_storage):
-assert type(untyped_storage) == torch._UntypedStorage
+assert type(untyped_storage) == torch.UntypedStorage
 
-if type(self) == _TypedStorage:
-return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
+if type(self) == TypedStorage:
+return TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
 else:
 return type(self)(wrap_storage=untyped_storage)
@ -505,7 +506,7 @@ class _TypedStorage:
|
||||
torch.qint8: torch.int8
|
||||
}
|
||||
tmp_dtype = interpret_dtypes[self.dtype]
|
||||
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage(
|
||||
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage(
|
||||
wrap_storage=self._storage,
|
||||
dtype=tmp_dtype))
|
||||
else:
|
||||
@ -517,12 +518,12 @@ class _TypedStorage:
|
||||
if self.device.type == 'meta':
|
||||
raise NotImplementedError("Not available for 'meta' device type")
|
||||
|
||||
# NOTE: Before _TypedStorage existed, indexing with a slice used to be
|
||||
# NOTE: Before TypedStorage existed, indexing with a slice used to be
|
||||
# possible for <type>Storage objects. However, it would return
|
||||
# a storage view, which would be a hassle to implement in _TypedStorage,
|
||||
# a storage view, which would be a hassle to implement in TypedStorage,
|
||||
# so it was disabled
|
||||
if isinstance(idx, slice):
|
||||
raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__')
|
||||
raise RuntimeError('slices are only supported in UntypedStorage.__getitem__')
|
||||
elif not isinstance(idx, int):
|
||||
raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
|
||||
|
||||
@ -534,7 +535,7 @@ class _TypedStorage:
|
||||
torch.qint32: torch.int32,
|
||||
torch.qint8: torch.int8
|
||||
}
|
||||
return _TypedStorage(
|
||||
return TypedStorage(
|
||||
wrap_storage=self._storage,
|
||||
dtype=interpret_dtypes[self.dtype])[idx]
|
||||
|
||||
@ -543,7 +544,7 @@ class _TypedStorage:
|
||||
return tmp_tensor[idx_wrapped].item()
|
||||
|
||||
def copy_(self, source: T, non_blocking: bool = None):
|
||||
self._storage.copy_(source._untyped(), non_blocking)
|
||||
self._storage.copy_(source.untyped(), non_blocking)
|
||||
return self
|
||||
|
||||
def nbytes(self):
|
||||
@ -564,7 +565,7 @@ class _TypedStorage:
|
||||
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
|
||||
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
||||
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
|
||||
cuda_storage: torch._UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
|
||||
cuda_storage: torch.UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
|
||||
return self._new_wrapped_storage(cuda_storage)
|
||||
|
||||
def element_size(self):
|
||||
@ -596,7 +597,7 @@ class _TypedStorage:
|
||||
return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))
|
||||
|
||||
def __sizeof__(self):
|
||||
return super(_TypedStorage, self).__sizeof__() + self.nbytes()
|
||||
return super(TypedStorage, self).__sizeof__() + self.nbytes()
|
||||
|
||||
def clone(self):
|
||||
"""Returns a copy of this storage"""
|
||||
@ -631,8 +632,8 @@ class _TypedStorage:
|
||||
if device is None:
|
||||
device = 'cpu'
|
||||
device = torch.device(device)
|
||||
untyped_storage = torch._UntypedStorage._new_shared(size * self.element_size(), device=device)
|
||||
return _TypedStorage(
|
||||
untyped_storage = torch.UntypedStorage._new_shared(size * self.element_size(), device=device)
|
||||
return TypedStorage(
|
||||
wrap_storage=untyped_storage,
|
||||
dtype=self.dtype)
|
||||
|
||||
@ -666,34 +667,34 @@ class _TypedStorage:
|
||||
|
||||
@classmethod
|
||||
def _free_weak_ref(cls, *args, **kwargs):
|
||||
return _UntypedStorage._free_weak_ref(*args, **kwargs)
|
||||
return UntypedStorage._free_weak_ref(*args, **kwargs)
|
||||
|
||||
def _weak_ref(self, *args, **kwargs):
|
||||
return self._storage._weak_ref(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
|
||||
if cls == _TypedStorage:
|
||||
if cls == TypedStorage:
|
||||
dtype = torch.get_default_dtype() if dtype is None else dtype
|
||||
device = torch.device('cpu' if device is None else device)
|
||||
if device.type != 'cpu':
|
||||
raise RuntimeError(f'_TypedStorage.from_buffer: Not available for device {device.type}')
|
||||
untyped_storage: torch._UntypedStorage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
raise RuntimeError(f'TypedStorage.from_buffer: Not available for device {device.type}')
|
||||
untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
|
||||
else:
|
||||
if dtype is not None or len(args) == 5:
|
||||
raise RuntimeError((
|
||||
"from_buffer: 'dtype' can only be specified in "
|
||||
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
|
||||
"UntypedStorage.from_buffer and TypedStorage.from_buffer"))
|
||||
if device is not None:
|
||||
raise RuntimeError((
|
||||
"from_buffer: 'device' can only be specified in "
|
||||
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
|
||||
"UntypedStorage.from_buffer and TypedStorage.from_buffer"))
|
||||
|
||||
dtype = cls.dtype
|
||||
untyped_storage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
|
||||
return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
|
||||
return TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
|
||||
|
||||
def _to(self, dtype):
|
||||
if not isinstance(dtype, torch.dtype):
|
||||
@ -770,9 +771,9 @@ class _TypedStorage:
|
||||
shared (bool): whether to share memory
|
||||
size (int): number of elements in the storage
|
||||
"""
|
||||
if cls == _TypedStorage:
|
||||
if cls == TypedStorage:
|
||||
raise RuntimeError('from_file can only be called on derived classes')
|
||||
untyped_storage: _UntypedStorage = _UntypedStorage.from_file(
|
||||
untyped_storage: UntypedStorage = UntypedStorage.from_file(
|
||||
filename,
|
||||
shared,
|
||||
size * torch._utils._element_size(cls.dtype))
|
||||
@ -781,7 +782,7 @@ class _TypedStorage:
|
||||
|
||||
@classmethod
|
||||
def _expired(cls, *args, **kwargs):
|
||||
return _UntypedStorage._expired(*args, **kwargs)
|
||||
return UntypedStorage._expired(*args, **kwargs)
|
||||
|
||||
def is_pinned(self):
|
||||
return self._storage.is_pinned()
|
||||
@ -803,7 +804,7 @@ class _TypedStorage:
|
||||
|
||||
@classmethod
|
||||
def _new_shared_cuda(cls, *args, **kwargs):
|
||||
return torch._UntypedStorage._new_shared_cuda(*args, **kwargs)
|
||||
return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
|
||||
|
||||
def _share_filename_cpu_(self, *args, **kwargs):
|
||||
manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
|
||||
@ -815,7 +816,7 @@ class _TypedStorage:
|
||||
|
||||
@classmethod
|
||||
def _release_ipc_counter(cls, *args, device=None, **kwargs):
|
||||
return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
|
||||
return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
|
||||
|
||||
def _shared_incref(self, *args, **kwargs):
|
||||
return self._storage._shared_incref(*args, **kwargs)
|
||||
@ -840,33 +841,33 @@ class _TypedStorage:
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
_TypedStorage.type.__doc__ = _type.__doc__
|
||||
_TypedStorage.cuda.__doc__ = _cuda.__doc__
|
||||
TypedStorage.type.__doc__ = _type.__doc__
|
||||
TypedStorage.cuda.__doc__ = _cuda.__doc__
|
||||
|
||||
class _LegacyStorageMeta(type):
|
||||
dtype: torch.dtype
|
||||
|
||||
def __instancecheck__(cls, instance):
|
||||
if type(instance) == _TypedStorage:
|
||||
if type(instance) == TypedStorage:
|
||||
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
|
||||
return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
|
||||
return False
|
||||
|
||||
class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
|
||||
class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
|
||||
@classmethod
|
||||
def _new_shared(cls, size):
|
||||
"""Creates a new storage in shared memory with the same data type"""
|
||||
untyped_storage = torch._UntypedStorage._new_shared(size * cls().element_size())
|
||||
untyped_storage = torch.UntypedStorage._new_shared(size * cls().element_size())
|
||||
return cls(wrap_storage=untyped_storage)
|
||||
|
||||
@classmethod
|
||||
def _release_ipc_counter(cls, *args, **kwargs):
|
||||
return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
|
||||
return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _new_shared_filename(cls, manager, obj, size):
|
||||
bytes_size = size * torch._utils._element_size(cls.dtype)
|
||||
return cls(wrap_storage=torch._UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
|
||||
return cls(wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
|
||||
|
||||
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
|
||||
try:
|
||||
|
||||
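The `__instancecheck__` above keeps `isinstance` checks against the legacy `<type>Storage` classes working even though every storage object is now a `TypedStorage`; a small sketch (CPU float tensor assumed):

```python
import torch

s = torch.tensor([1.0, 2.0]).storage()      # CPU, torch.float32

assert type(s) is torch.TypedStorage
assert isinstance(s, torch.FloatStorage)    # dtype and device match
assert not isinstance(s, torch.IntStorage)  # dtype does not match
```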
@ -2369,7 +2369,7 @@ class TestCase(expecttest.TestCase):
|
||||
),
|
||||
sequence_types=(
|
||||
Sequence,
|
||||
torch.storage._TypedStorage,
|
||||
torch.storage.TypedStorage,
|
||||
Sequential,
|
||||
ModuleList,
|
||||
ParameterList,
|
||||
|
||||