From ee28b865ee9c87cce4db0011987baf8d125cc857 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 8 Nov 2022 18:11:01 +0000 Subject: [PATCH] Deprecate TypedStorage, its derived classes, and all of their public methods (#85303) Part of #85302 Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303 Approved by: https://github.com/ezyang --- docs/source/storage.rst | 4 + test/test_autograd.py | 2 +- test/test_cuda.py | 2 +- test/test_torch.py | 121 +++++++++ test/test_view_ops.py | 2 +- .../templates/python_variable_methods.cpp | 2 +- torch/__init__.py | 87 +++++- torch/_deploy.py | 2 +- torch/_dynamo/optimizations/analysis.py | 2 +- torch/_dynamo/optimizations/distributed.py | 4 +- torch/_dynamo/optimizations/training.py | 4 +- torch/_prims/__init__.py | 4 +- torch/_subclasses/fake_tensor.py | 8 +- torch/_subclasses/fake_utils.py | 6 +- torch/_subclasses/meta_utils.py | 16 +- torch/_tensor.py | 34 ++- torch/_utils.py | 4 +- torch/csrc/DynamicTypes.cpp | 2 +- torch/cuda/__init__.py | 63 ++++- torch/cuda/_dynamo_graphs.py | 4 +- .../_shard/checkpoint/filesystem.py | 2 +- torch/distributed/distributed_c10d.py | 2 +- torch/distributed/fsdp/_utils.py | 14 +- torch/distributed/fsdp/flat_param.py | 16 +- .../pipeline/sync/_balance/profile.py | 2 +- torch/distributed/pipeline/sync/stream.py | 2 +- torch/fx/passes/reinplace.py | 36 ++- torch/multiprocessing/reductions.py | 21 +- torch/overrides.py | 1 + torch/package/package_exporter.py | 2 +- torch/package/package_importer.py | 6 +- torch/serialization.py | 39 +-- torch/storage.py | 252 +++++++++++++----- torch/testing/_comparison.py | 27 +- torch/testing/_internal/schema_check_mode.py | 6 +- torch/utils/bundled_inputs.py | 4 +- torch/utils/data/_utils/collate.py | 2 +- 37 files changed, 631 insertions(+), 176 deletions(-) diff --git a/docs/source/storage.rst b/docs/source/storage.rst index 28cf4444fbc9..84fed2f659a7 100644 --- a/docs/source/storage.rst +++ b/docs/source/storage.rst @@ -22,6 +22,10 @@ holds the data as an untyped array of bytes. Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`, which stores all of the data that the :class:`torch.Tensor` views. +.. warning:: + All storage classes except for :class:`torch.UntypedStorage` will be removed + in the future, and :class:`torch.UntypedStorage` will be used in all cases. + .. 
autoclass:: torch.TypedStorage :members: :undoc-members: diff --git a/test/test_autograd.py b/test/test_autograd.py index 7df0b1ddae38..dd3ecf3323d3 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6805,8 +6805,8 @@ for shape in [(1,), ()]: with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x): a = torch.ones(5, requires_grad=True) - warnings.simplefilter('always') with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') y = a * a # should raise two warnings from a being saved twice self.assertEqual(len(w), 2) diff --git a/test/test_cuda.py b/test/test_cuda.py index 9128ea093715..9ecafc45103b 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -595,7 +595,7 @@ class TestCuda(TestCase): self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor)) self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor)) self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage)) - self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage)) + self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage)) q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) diff --git a/test/test_torch.py b/test/test_torch.py index 2247d18285d5..82d0807d81a7 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6470,6 +6470,127 @@ class TestTorch(TestCase): self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage') self.assertIs(complexdouble_storage.dtype, torch.complex128) + # Test that internal versions of functions related to TypedStorage do not + # produce a deprecation warning + def test_typed_storage_internal_no_warning(self): + s0 = torch.FloatStorage(10) + s0_untyped = s0.untyped() + t0 = torch.randn(10) + + funcs = [ + lambda: torch.FloatStorage(_internal=True), + lambda: torch.TypedStorage( + dtype=torch.float, + device='cpu', + _internal=True), + lambda: torch.TypedStorage( + wrap_storage=s0_untyped, + dtype=s0.dtype, + _internal=True), + lambda: torch.FloatStorage._dtype, + lambda: s0._resize_(20), + lambda: s0._size(), + lambda: s0._untyped_storage, + lambda: s0._is_shared(), + lambda: s0._share_memory_(), + lambda: s0._pickle_storage_type(), + lambda: s0._setitem(slice(0, s0._size()), 1), + lambda: s0._element_size(), + lambda: s0._deepcopy({}), + lambda: s0._data_ptr(), + lambda: s0._nbytes(), + lambda: t0._typed_storage(), + ] + + if torch.cuda.is_available(): + s1 = torch.cuda.FloatStorage(10) + s1_untyped = s1.untyped() + t1 = torch.randn(10, device='cuda') + + funcs += [ + lambda: torch.cuda.FloatStorage(_internal=True), + lambda: torch.TypedStorage( + dtype=torch.float, + device='cuda', + _internal=True), + lambda: torch.TypedStorage( + wrap_storage=s1_untyped, + dtype=s1.dtype, + _internal=True), + lambda: torch.cuda.FloatStorage._dtype, + lambda: s1._resize_(20), + lambda: s1._size(), + lambda: s1._untyped_storage, + lambda: s1._is_shared(), + lambda: s1._share_memory_(), + lambda: s1._pickle_storage_type(), + lambda: s1._setitem(slice(0, s1._size()), 1), + lambda: s1._element_size(), + lambda: s1._deepcopy({}), + lambda: s1._data_ptr(), + lambda: s1._nbytes(), + lambda: t1._typed_storage(), + ] + + # Check that each of the TypedStorage internal function calls do not + # produce a deprecation warning + for f in funcs: + with warnings.catch_warnings(): + warnings.filterwarnings('error', "TypedStorage is deprecated") + f() + + # Test that public functions related to TypedStorage produce a deprecation + # warning + def 
test_typed_storage_deprecation_warning(self): + s0 = torch.FloatStorage(10) + funcs = [ + lambda: torch.FloatStorage(), + lambda: torch.FloatStorage.dtype, + lambda: s0.fill_(0), + lambda: s0.is_cuda, + lambda: s0.untyped(), + lambda: len(s0), + lambda: s0[0], + ] + + if torch.cuda.is_available(): + s1 = torch.cuda.FloatStorage(10) + funcs += [ + lambda: torch.cuda.FloatStorage(), + lambda: torch.cuda.FloatStorage.dtype, + lambda: s1.fill_(0), + lambda: s1.is_cuda, + lambda: s1.untyped(), + lambda: len(s1), + lambda: s1[0], + ] + + # Check that each of the TypedStorage function calls produce a warning + # if warnings are reset between each + for f in funcs: + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + f() + self.assertEqual(len(w), 1) + warning = w[0].message + self.assertTrue(warning, DeprecationWarning) + self.assertTrue(re.search( + '^TypedStorage is deprecated', + str(warning))) + + # Check that only one warning is raised from calling multiple + # TypedStorage functions if warnings are not reset between each + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + for f in funcs: + f() + self.assertEqual(len(w), 1) + warning = w[0].message + self.assertTrue(warning, DeprecationWarning) + self.assertTrue(re.search( + '^TypedStorage is deprecated', + str(warning))) + def test_from_file(self): def assert_with_filename(filename): size = 10000 diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 3c5987e65ae7..c4729557c416 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -102,7 +102,7 @@ class TestViewOps(TestCase): # Note: only validates storage on native device types # because some accelerators, like XLA, do not expose storage if base.device.type == 'cpu' or base.device.type == 'cuda': - if base.storage().data_ptr() != other.storage().data_ptr(): + if base._storage().data_ptr() != other._storage().data_ptr(): return False return True diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index e3c0a8b987bd..2cd847b73405 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -979,7 +979,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "storage"); + return handle_torch_function(self, "_storage"); } auto& self_ = THPVariable_Unpack(self); return createPyObject(self_.storage()); diff --git a/torch/__init__.py b/torch/__init__.py index 1a645f53a8a2..ae55f5975542 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -709,7 +709,7 @@ __all__.extend(['e', 'pi', 'nan', 'inf']) ################################################################################ from ._tensor import Tensor -from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage +from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal # NOTE: New Storage classes should never be added. When adding a new # dtype, use torch.storage.TypedStorage directly. 
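The torch/__init__.py hunk below routes every legacy storage class's public `dtype` classproperty through `_warn_typed_storage_removal()` before delegating to a private `_dtype`. The helper itself lives in torch/storage.py and is not shown in full in this excerpt; a minimal sketch of the behavior the new tests rely on (a message starting with "TypedStorage is deprecated", surfaced once per call site under the default warning filter) might look like the following. The exact wording and the warning category here are assumptions, not the patch's definition.

```python
import warnings

def _warn_typed_storage_removal(stacklevel=2):
    # Sketch only: the tests in this patch require just that the message
    # starts with "TypedStorage is deprecated"; the category and the rest of
    # the wording are assumed. Warning from a single call site means the
    # default "once per location" filter collapses repeated calls into one
    # visible warning, which test_typed_storage_deprecation_warning checks.
    message = (
        "TypedStorage is deprecated. It will be removed in the future and "
        "UntypedStorage will be the only storage class."
    )
    warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
```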
@@ -717,86 +717,171 @@ from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage class ByteStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.uint8 class DoubleStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.double class FloatStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.float class HalfStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.half class LongStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.long class IntStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int class ShortStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.short class CharStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int8 class BoolStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bool class BFloat16Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bfloat16 class ComplexDoubleStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cdouble class ComplexFloatStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cfloat class QUInt8Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint8 class QInt8Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.qint8 class QInt32Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.qint32 class QUInt4x2Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint4x2 class QUInt2x4Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint2x4 _storage_classes = { diff --git a/torch/_deploy.py b/torch/_deploy.py index 53769538b6c1..30c022eac879 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -23,7 +23,7 @@ def _save_storages(importer, obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case - storage = obj._storage + storage = 
obj._untyped_storage dtype = obj.dtype else: storage = obj diff --git a/torch/_dynamo/optimizations/analysis.py b/torch/_dynamo/optimizations/analysis.py index 0af70bfa9581..b3f6ed79eb06 100644 --- a/torch/_dynamo/optimizations/analysis.py +++ b/torch/_dynamo/optimizations/analysis.py @@ -27,7 +27,7 @@ class ShapeAliasingAndMutationProp(ShapeProp): def tensor_alias_group(self, value: torch.Tensor): """Assign a unique identifier to the storage of a given tensor""" - storage = StorageWeakRef(value.storage()) + storage = StorageWeakRef(value._typed_storage()) alias_group = self.storage_to_alias_group.get(storage) if alias_group is None: alias_group = next(self.make_alias_group) diff --git a/torch/_dynamo/optimizations/distributed.py b/torch/_dynamo/optimizations/distributed.py index bde786979fcf..b71d85c4e34f 100644 --- a/torch/_dynamo/optimizations/distributed.py +++ b/torch/_dynamo/optimizations/distributed.py @@ -157,7 +157,7 @@ class DDPOptimizer: for name, p in target.named_parameters(): param = target.get_parameter(name) if p.requires_grad and not self._ignore_parameter(param): - buckets[0].size += p.storage().nbytes() + buckets[0].size += p._storage().nbytes() buckets[0].params.append(f"{node.target}_{name}") buckets[0].param_ids.append(id(param)) elif node.op == "get_attr": @@ -165,7 +165,7 @@ class DDPOptimizer: if maybe_param.requires_grad and not self._ignore_parameter( maybe_param ): - buckets[0].size += maybe_param.storage().nbytes() + buckets[0].size += maybe_param._storage().nbytes() buckets[0].params.append(node.target) buckets[0].param_ids.append(id(maybe_param)) diff --git a/torch/_dynamo/optimizations/training.py b/torch/_dynamo/optimizations/training.py index 588956a898f4..af673a2b2c1e 100644 --- a/torch/_dynamo/optimizations/training.py +++ b/torch/_dynamo/optimizations/training.py @@ -381,7 +381,7 @@ def find_input_mutations(g): mutated_inputs = set() for n in g.nodes: if n.op == "placeholder": - inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx) + inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) input_idx += 1 elif n.op == "call_function": if n.target is operator.getitem: @@ -402,7 +402,7 @@ def find_input_mutations(g): # TODO: not correct for args that contain tensors in a struct # like list mutated_inputs |= inputs[ - StorageWeakRef(meta_fk(argument.meta).storage()) + StorageWeakRef(meta_fk(argument.meta)._typed_storage()) ] # TODO: error on unrecognized nodes return mutated_inputs diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 6d40e1071fb5..c40960a22445 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1158,7 +1158,9 @@ def _as_strided_meta( # as_strided to shapes with no elements are trivially valid, so it's OK pass elif isinstance(a, torch.Tensor): - utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset) + utils.check_in_bounds_for_storage( + a._typed_storage(), size, stride, storage_offset + ) return TensorMeta(a, shape=size, strides=stride) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 796b15fedf10..fa58ce23c443 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -156,7 +156,7 @@ class FakeTensorConverter(object): # const_tensor.add_(torch.rand([1])) # all aliases of it must become no longer const assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None - weak_st = StorageWeakRef(fake_tensor.constant.storage()) + weak_st = 
StorageWeakRef(fake_tensor.constant._typed_storage()) # we need a map from a weak storage to all of its corresponding # constant tensors. python doesn't have the weak value equivalent @@ -168,7 +168,7 @@ class FakeTensorConverter(object): def invalidate_constant_aliases(self, tensor): assert not isinstance(tensor, FakeTensor) - weak_st = StorageWeakRef(tensor.storage()) + weak_st = StorageWeakRef(tensor._typed_storage()) if weak_st not in self.constant_storage_mapping: return @@ -1043,7 +1043,7 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce for e in tree_flatten((args, kwargs))[0]: if isinstance(e, torch.Tensor): if not e.is_sparse: - storages.add(e.storage()._cdata) + storages.add(e._typed_storage()._cdata) # TODO: also check metadata change on inputs # proper aliasing/metadata relationship between outputs and inputs will @@ -1053,7 +1053,7 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce if id(e) not in inp_impls and ( isinstance(e, torch.Tensor) and not e.is_sparse - and e.storage()._cdata in storages + and e._typed_storage()._cdata in storages ): raise orig_not_implemented_exception diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index 37ff260c9bd3..d23b12ca8440 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -18,12 +18,12 @@ aten = torch.ops.aten def outputs_alias_inputs(outputs, inputs): input_storages = { - inp.storage()._cdata + inp._typed_storage()._cdata for inp in tree_flatten_only(torch.Tensor, inputs) if torch._C._has_storage(inp) } return any( - torch._C._has_storage(out) and out.storage()._cdata in input_storages + torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages for out in tree_flatten_only(torch.Tensor, outputs) ) @@ -38,7 +38,7 @@ def output_alias_each_other(outputs): for out in tree_flatten_only(torch.Tensor, outputs): if not torch._C._has_storage(out): continue - stor = out.storage()._cdata + stor = out._typed_storage()._cdata if stor in storages: return True storages.add(stor) diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 7e2039f1764f..081f7aa632f9 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -143,7 +143,7 @@ class MetaConverter: if t.is_sparse: weak_st = None else: - weak_st = StorageWeakRef(t.storage()) + weak_st = StorageWeakRef(t._typed_storage()) tensor_ref_key = WeakTensorRefKey(t) def del_ten(): @@ -179,13 +179,9 @@ class MetaConverter: # Use a Weak Ref to s in order to not leak memory swr = StorageWeakRef(s) if swr not in self.storage_memo: - self.storage_memo[swr] = ( - callback( - lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta") - ) - .storage() - .untyped() - ) + self.storage_memo[swr] = callback( + lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta") + )._storage() return self.storage_memo[swr] # This function assumes that it's possible to do the conversion @@ -362,7 +358,7 @@ class MetaConverter: # format here r = r.clone(memory_format=torch.preserve_format) - s = t.storage().untyped() + s = t._storage() swr = StorageWeakRef(s) if ( swr not in self.storage_memo @@ -370,7 +366,7 @@ class MetaConverter: and r.storage_offset() == storage_offset ): # You're normal and happy, install the fresh storage into the memo - self.storage_memo[swr] = r.storage().untyped() + self.storage_memo[swr] = r._storage() else: # You're in crazy town; somehow you gave us a tensor # that wasn't a view, but had 
nonzero storage offset, diff --git a/torch/_tensor.py b/torch/_tensor.py index d0af241c8a22..8ac1ac1eb736 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -132,7 +132,7 @@ class Tensor(torch._C._TensorBase): "different type." ) else: - new_storage = self.storage().__deepcopy__(memo) + new_storage = self._typed_storage()._deepcopy(memo) if self.is_quantized: # quantizer_params can be different type based on torch attribute quantizer_params: Union[ @@ -163,7 +163,9 @@ class Tensor(torch._C._TensorBase): # need to wrap with TypedStorage new_tensor = torch._utils._rebuild_qtensor( torch.storage.TypedStorage( - wrap_storage=new_storage.untyped(), dtype=self.dtype + wrap_storage=new_storage._untyped_storage, + dtype=self.dtype, + _internal=True, ), self.storage_offset(), self.size(), @@ -257,7 +259,17 @@ class Tensor(torch._C._TensorBase): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage, (self,), self) - return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype) + torch.storage._warn_typed_storage_removal() + return self._typed_storage() + + # For internal use only, to avoid raising deprecation warning + def _typed_storage(self): + _storage = self._storage() + if isinstance(_storage, torch.TypedStorage): + _storage = _storage._untyped_storage + return torch.TypedStorage( + wrap_storage=_storage, dtype=self.dtype, _internal=True + ) def _reduce_ex_internal(self, proto): check_serializing_named_tensor(self) @@ -331,7 +343,9 @@ class Tensor(torch._C._TensorBase): # need to wrap with TypedStorage args_qtensor = ( torch.storage.TypedStorage( - wrap_storage=self.storage().untyped(), dtype=self.dtype + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, ), self.storage_offset(), tuple(self.size()), @@ -389,7 +403,9 @@ class Tensor(torch._C._TensorBase): # need to wrap with TypedStorage args = ( torch.storage.TypedStorage( - wrap_storage=self.storage().untyped(), dtype=self.dtype + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, ), self.storage_offset(), tuple(self.size()), @@ -607,7 +623,7 @@ class Tensor(torch._C._TensorBase): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.is_shared, (self,), self) - return self.storage().is_shared() + return self._typed_storage()._is_shared() def share_memory_(self): r"""Moves the underlying storage to shared memory. @@ -617,7 +633,7 @@ class Tensor(torch._C._TensorBase): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.share_memory_, (self,), self) - self.storage().share_memory_() + self._typed_storage()._share_memory_() return self def __reversed__(self): @@ -1059,7 +1075,9 @@ class Tensor(torch._C._TensorBase): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage_type, (self,), self) - return self.storage()._get_legacy_storage_class() + torch.storage._warn_typed_storage_removal() + + return self._typed_storage()._get_legacy_storage_class() def refine_names(self, *names): r"""Refines the dimension names of :attr:`self` according to :attr:`names`. 
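For context on the `Tensor.storage()` change in the torch/_tensor.py hunk above: a minimal usage sketch, mirroring the new tests in test_torch.py, of how the public accessor now emits the deprecation warning while the internal `_typed_storage()` path stays silent. This assumes a build with this patch applied.

```python
import re
import warnings

import torch

t = torch.randn(4)

# Public accessor: expected to warn once the patch is applied.
with warnings.catch_warnings(record=True) as w:
    warnings.resetwarnings()
    t.storage()
assert len(w) == 1
assert re.search('^TypedStorage is deprecated', str(w[0].message))

# Internal accessor used throughout the patch: no warning expected, so turning
# the deprecation message into an error should not raise here.
with warnings.catch_warnings():
    warnings.filterwarnings('error', 'TypedStorage is deprecated')
    t._typed_storage()
```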
diff --git a/torch/_utils.py b/torch/_utils.py
index 8a539d75f565..f178cfbaea4a 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -143,8 +143,8 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
 # be a TypedStorage
 def _rebuild_tensor(storage, storage_offset, size, stride):
     # first construct a tensor with the correct dtype/device
-    t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
-    return t.set_(storage.untyped(), storage_offset, size, stride)
+    t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device)
+    return t.set_(storage._untyped_storage, storage_offset, size, stride)


 def _rebuild_tensor_v2(
diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp
index b3021ffe0d8d..93bb37017ce0 100644
--- a/torch/csrc/DynamicTypes.cpp
+++ b/torch/csrc/DynamicTypes.cpp
@@ -135,7 +135,7 @@ at::Storage createStorageGetType(
   TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
   scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;

-  untyped_storage_obj = PyObject_GetAttrString(obj, "_storage");
+  untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
   TORCH_INTERNAL_ASSERT(untyped_storage_obj);
   Py_DECREF(untyped_storage_obj);

diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 93fa8cf07ac2..a684f2291de2 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -737,11 +737,12 @@ class _CudaBase(object):
     __new__ = _lazy_new


-from torch.storage import _LegacyStorage
+from torch.storage import _LegacyStorage, _warn_typed_storage_removal

 class _CudaLegacyStorage(_LegacyStorage):
     @classmethod
     def from_buffer(cls, *args, **kwargs):
+        _warn_typed_storage_removal()
         raise RuntimeError('from_buffer: Not available for CUDA storage')

     @classmethod
@@ -755,61 +756,121 @@ class _CudaLegacyStorage(_LegacyStorage):
 class ByteStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.uint8

 class DoubleStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.double

 class FloatStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.float

 class HalfStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.half

 class LongStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.long

 class IntStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.int

 class ShortStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.short

 class CharStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.int8

 class BoolStorage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+        _warn_typed_storage_removal()
+        return self._dtype
+
+    @classproperty
+    def _dtype(self):
         return torch.bool

 class BFloat16Storage(_CudaLegacyStorage):
     @classproperty
     def dtype(self):
+
_warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bfloat16 class ComplexDoubleStorage(_CudaLegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cdouble class ComplexFloatStorage(_CudaLegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cfloat del _LegacyStorage diff --git a/torch/cuda/_dynamo_graphs.py b/torch/cuda/_dynamo_graphs.py index 07ebed6fadf0..6c577c317776 100644 --- a/torch/cuda/_dynamo_graphs.py +++ b/torch/cuda/_dynamo_graphs.py @@ -89,7 +89,7 @@ def find_input_mutations(g): mutated_inputs = set() for n in g.nodes: if n.op == 'placeholder': - inputs[StorageWeakRef(n.meta[FK].storage())].add(input_idx) + inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx) input_idx += 1 elif n.op == 'call_function': if n.target is operator.getitem: @@ -109,7 +109,7 @@ def find_input_mutations(g): if mut_arg: # TODO: not correct for args that contain tensors in a struct # like list - mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK].storage())] + mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())] # TODO: error on unrecognized nodes return mutated_inputs diff --git a/torch/distributed/_shard/checkpoint/filesystem.py b/torch/distributed/_shard/checkpoint/filesystem.py index ece9000b3ddf..9788853d9aa6 100644 --- a/torch/distributed/_shard/checkpoint/filesystem.py +++ b/torch/distributed/_shard/checkpoint/filesystem.py @@ -51,7 +51,7 @@ DEFAULT_SUFIX = ".distcp" def _trim(tensor: torch.Tensor) -> torch.Tensor: tensor = tensor.detach().cpu() - if tensor.storage().size() != tensor.numel(): + if tensor._typed_storage()._size() != tensor.numel(): tensor = tensor.clone() return tensor diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 32b0949a3e34..41d0ee21d3e3 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1896,7 +1896,7 @@ def all_gather_multigpu( def _object_to_tensor(obj, device): f = io.BytesIO() _pickler(f).dump(obj) - byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined] + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. # Otherwise, it will casue 100X slowdown. # See: https://github.com/pytorch/pytorch/issues/65696 diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py index bf7937451a29..5efb376e6645 100644 --- a/torch/distributed/fsdp/_utils.py +++ b/torch/distributed/fsdp/_utils.py @@ -69,14 +69,14 @@ def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool: bool: ``True`` if this method allocated storage and ``False`` if the storage was already allocated. 
""" - already_allocated = tensor.storage().size() == size.numel() + already_allocated = tensor._typed_storage()._size() == size.numel() if not already_allocated: - tensor_storage_size = tensor.storage().size() + tensor_storage_size = tensor._typed_storage()._size() p_assert( tensor_storage_size == 0, f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}", ) - tensor.storage().resize_(size.numel()) + tensor._typed_storage()._resize_(size.numel()) return not already_allocated @@ -89,23 +89,23 @@ def _free_storage(tensor: torch.Tensor) -> bool: bool: ``True`` if the method freed the storage and ``False`` if the storage was already freed. """ - already_freed = tensor.storage().size() == 0 + already_freed = tensor._typed_storage()._size() == 0 if not already_freed: p_assert( tensor.storage_offset() == 0, "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" f"storage offset: {tensor.storage_offset()}\n" - f"storage size: {tensor.storage().size()}\n" + f"storage size: {tensor._typed_storage()._size()}\n" f"tensor shape: {tensor.shape}", ) - tensor.storage().resize_(0) + tensor._typed_storage()._resize_(0) return not already_freed def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool: """Returns if ``x`` and ``y`` share the same storage.""" # NOTE: CPU and GPU tensors are ensured to have different data pointers. - return x.storage().data_ptr() == y.storage().data_ptr() + return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr() def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None: diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index ee693648fb34..0978f0875a28 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -493,7 +493,7 @@ class FlatParamHandle: flat_param.storage_offset() == 0, "The `FlatParameter` is not the sole occupant of its storage", ) - orig_storage = flat_param.storage() + orig_storage = flat_param._typed_storage() sharded_flat_param, numel_padded = FlatParamHandle._get_shard( flat_param, self.rank, self.world_size ) @@ -501,8 +501,8 @@ class FlatParamHandle: start = sharded_flat_param.numel() * self.rank end = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive self._init_shard_metadata(numel_padded, start, end) - if orig_storage.size() > 0: - orig_storage.resize_(0) + if orig_storage._size() > 0: + orig_storage._resize_(0) if self._use_orig_params: self._use_sharded_views() @@ -838,7 +838,7 @@ class FlatParamHandle: return False unsharded_flat_param = self._get_padded_unsharded_flat_param() already_unsharded = ( - unsharded_flat_param.storage().size() == unsharded_flat_param.numel() + unsharded_flat_param._typed_storage()._size() == unsharded_flat_param.numel() ) return not already_unsharded @@ -1141,9 +1141,9 @@ class FlatParamHandle: # the padded unsharded flattened parameter as expected # NOTE: This check is not strictly needed for correctness but is a # useful sanity check since the tensor should only be used internally. 
- unpadded_storage_ptr = self.flat_param.storage().data_ptr() + unpadded_storage_ptr = self.flat_param._typed_storage()._data_ptr() padded_storage_ptr = ( - self._get_padded_unsharded_flat_param().storage().data_ptr() + self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr() ) p_assert( unpadded_storage_ptr == padded_storage_ptr, @@ -1824,7 +1824,7 @@ class FlatParamHandle: @staticmethod def _check_storage_freed(tensor: Tensor): - storage_size: int = tensor.storage().size() + storage_size: int = tensor._typed_storage()._size() p_assert( storage_size == 0, f"Expects storage to be freed but got storage with size {storage_size}", @@ -1832,7 +1832,7 @@ class FlatParamHandle: @staticmethod def _check_storage_allocated(tensor: Tensor): - storage_size: int = tensor.storage().size() + storage_size: int = tensor._typed_storage()._size() p_assert(storage_size > 0, "Expects storage to be allocated") def _check_low_precision_shard(self): diff --git a/torch/distributed/pipeline/sync/_balance/profile.py b/torch/distributed/pipeline/sync/_balance/profile.py index 9759a4b6262a..fa1a0c06a8e3 100644 --- a/torch/distributed/pipeline/sync/_balance/profile.py +++ b/torch/distributed/pipeline/sync/_balance/profile.py @@ -107,7 +107,7 @@ def profile_sizes( latent_size = memory_after - memory_before # Analyze size of parameters. - param_size = sum(p.storage().nbytes() for p in layer.parameters()) + param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters()) # Combine size of parameters and activations with normalize scales. size = latent_size * latent_scale + param_size * param_scale diff --git a/torch/distributed/pipeline/sync/stream.py b/torch/distributed/pipeline/sync/stream.py index 56b699343739..59fedf865a42 100644 --- a/torch/distributed/pipeline/sync/stream.py +++ b/torch/distributed/pipeline/sync/stream.py @@ -104,7 +104,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: # # Issue: https://github.com/pytorch/pytorch/issues/27366 # - tensor = tensor.new_empty([0]).set_(tensor.storage()) + tensor = tensor.new_empty([0]).set_(tensor._typed_storage()) # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type] diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index ff24ef97f545..86986a85acc8 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -100,8 +100,8 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter): # Assert here that this is actually the case, and their storages are the same. 
assert isinstance(node.meta['fake_result'], FakeTensor) assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) - view_storage = StorageWeakRef(node.meta['fake_result'].storage()) - base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage()) + view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) + base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage()) assert view_storage == base_storage return result @@ -176,7 +176,7 @@ _VIEW_INVERSE_MAP = { def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): def _add_if_tensor(x, set_): if isinstance(x, FakeTensor): - set_.add(StorageWeakRef(x.storage())) + set_.add(StorageWeakRef(x._typed_storage())) nodes_used_after = set() for t in tensor_aliases: @@ -452,7 +452,7 @@ def reinplace(gm, *sample_args): # Useful debug printing # def _print(x): # if isinstance(x, FakeTensor): - # print(f'fake_result: {StorageWeakRef(x.storage()).cdata}') + # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}') # for n in gm.graph.nodes: # print(n.format_node()) @@ -468,7 +468,10 @@ def reinplace(gm, *sample_args): # so we know not to re-inplace them. # NOTE: later, we'll need to add an optimization for fully recovering performance # on programs that mutate inputs. - input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder') + input_storages = set( + StorageWeakRef( + node.meta['fake_result']._typed_storage() + ) for node in gm.graph.nodes if node.op == 'placeholder') # We also need to know for a given node, what are all of its aliasing nodes. @@ -478,7 +481,7 @@ def reinplace(gm, *sample_args): # Tree-mapping because some ops can return lists of tensors. def _add_to_map(x): if isinstance(x, FakeTensor): - storage_to_nodes[StorageWeakRef(x.storage())].add(n) + storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) tree_map(_add_to_map, n.meta['fake_result']) # inplace-ify functional ops, subject to the constraints written below. @@ -529,7 +532,7 @@ def reinplace(gm, *sample_args): # Step 1b: ensure that the op we're trying to re-inplace isn't a program input self_arg_name = self_arg.name - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) if self_arg_storage in input_storages: # TODO: later, add the optimization for handling `copy_()` calls in the graph. continue @@ -539,7 +542,7 @@ def reinplace(gm, *sample_args): # so we prevent re-inplacing in this case. continue - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage()) + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) self_aliases = storage_to_nodes[self_arg_storage] # First, we find all later usages of any of the aliases of self_arg. @@ -594,7 +597,7 @@ def reinplace(gm, *sample_args): # Hmm... morally I think we also want to keep the `fake_result` metadata # up to date here, but I'm not sure how easy it is to do. # Maybe it's fine to wait until the end of the pass to update it. 
- curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage()) + curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) @@ -624,8 +627,14 @@ def reinplace(gm, *sample_args): old_flattened_res, _ = tree_flatten(old.meta['fake_result']) node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result']) - old_res_storage = set(StorageWeakRef(x.storage()) for x in old_flattened_res if isinstance(x, FakeTensor)) - node_res_storage = set(StorageWeakRef(x.storage()) for x in node_flattened_res if isinstance(x, FakeTensor)) + old_res_storage = set( + StorageWeakRef( + x._typed_storage() + ) for x in old_flattened_res if isinstance(x, FakeTensor)) + node_res_storage = set( + StorageWeakRef( + x._typed_storage() + ) for x in node_flattened_res if isinstance(x, FakeTensor)) # This will happen if we're updating a view op, e.g. # e.g. replacing @@ -639,7 +648,10 @@ def reinplace(gm, *sample_args): # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: new_flattened_res, _ = tree_flatten(new.meta['fake_result']) - new_res_storage = set(StorageWeakRef(x.storage()) for x in new_flattened_res if isinstance(x, FakeTensor)) + new_res_storage = set( + StorageWeakRef( + x._typed_storage() + ) for x in new_flattened_res if isinstance(x, FakeTensor)) assert len(new_res_storage) == 1 (old_ref,) = old_res_storage (new_ref,) = new_res_storage diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index 403b28d6a63c..4fcccb47685c 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -113,7 +113,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required): # If storage_handle is None, storage points to nullptr. if storage_handle is None or storage_size_bytes == 0: - storage = storage_cls(0, dtype=dtype, device=storage_device) + storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) else: storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes)) if storage is None: @@ -132,8 +132,10 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, # We already ref counting this Storage, but producer needs new ref-counters to be released. 
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device) + _storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage + t = torch._utils._rebuild_tensor( - torch.storage.TypedStorage(wrap_storage=storage.untyped(), dtype=dtype), + torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True), tensor_offset, tensor_size, tensor_stride) if tensor_cls == torch.nn.parameter.Parameter: @@ -147,7 +149,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset, def reduce_tensor(tensor): - storage = tensor.storage() + storage = tensor._typed_storage() if tensor.requires_grad and not tensor.is_leaf: raise RuntimeError("Cowardly refusing to serialize non-leaf tensor which requires_grad, " @@ -248,7 +250,7 @@ def reduce_tensor(tensor): # eliminated it so that we could just use tensor views to implement the same # thing. # - if storage.is_cuda: + if storage._untyped_storage.device.type == 'cuda': (device, handle, storage_size_bytes, @@ -325,7 +327,8 @@ def rebuild_storage_filename(cls, manager, handle, size, dtype=None): untyped_storage: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size) storage = torch.TypedStorage( wrap_storage=untyped_storage, - dtype=dtype) + dtype=dtype, + _internal=True) shared_cache[handle] = StorageWeakRef(storage) return storage._shared_decref() @@ -334,18 +337,18 @@ def rebuild_storage_empty(cls): return cls() def rebuild_typed_storage(storage, dtype): - return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype) + return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True) # Use for torch.storage.TypedStorage def reduce_typed_storage(storage): - return (rebuild_typed_storage, (storage._storage, storage.dtype)) + return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype)) def rebuild_typed_storage_child(storage, storage_type): - return storage_type(wrap_storage=storage) + return storage_type(wrap_storage=storage, _internal=True) # Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage def reduce_typed_storage_child(storage): - return (rebuild_typed_storage_child, (storage._storage, type(storage))) + return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage))) def reduce_storage(storage): from . 
import get_sharing_strategy diff --git a/torch/overrides.py b/torch/overrides.py index ce7872f9d1ab..cb4402235e1a 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -273,6 +273,7 @@ def get_ignored_functions() -> Set[Callable]: Tensor.to_sparse_csc, Tensor.to_sparse_bsr, Tensor.to_sparse_bsc, + Tensor._typed_storage, Tensor._reduce_ex_internal, Tensor._fix_weakref, Tensor._make_wrapper_subclass, diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index a95f105d2474..7f6af38468e2 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -887,7 +887,7 @@ class PackageExporter: if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case - untyped_storage = obj._storage + untyped_storage = obj._untyped_storage storage_type_str = obj.pickle_storage_type() storage_type = getattr(torch, storage_type_str) storage_numel = obj.size() diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 7bf945c70c0b..3db37128b03b 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -208,14 +208,14 @@ class PackageImporter(Importer): name = f"{key}.storage" if storage_context.has_storage(name): - storage = storage_context.get_storage(name, dtype).storage() + storage = storage_context.get_storage(name, dtype)._typed_storage() else: tensor = self.zip_reader.get_storage_from_record( ".data/" + name, size, dtype ) if isinstance(self.zip_reader, torch._C.PyTorchFileReader): storage_context.add_storage(name, tensor) - storage = tensor.storage() + storage = tensor._typed_storage() loaded_storages[key] = restore_location(storage, location) def persistent_load(saved_id): @@ -239,7 +239,7 @@ class PackageImporter(Importer): # TODO: Once we decide to break serialization FC, we can # stop wrapping with TypedStorage return torch.storage.TypedStorage( - wrap_storage=storage.untyped(), dtype=dtype + wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True ) elif typename == "reduce_package": # to fix BC breaking change, objects on this load path diff --git a/torch/serialization.py b/torch/serialization.py index 53d060019408..d123a955ad96 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -469,12 +469,12 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, this case # can be deleted - storage = obj._storage + storage = obj._untyped_storage storage_dtype = obj.dtype - storage_type_str = obj.pickle_storage_type() + storage_type_str = obj._pickle_storage_type() storage_type = getattr(torch, storage_type_str) dtype = obj.dtype - storage_numel = obj.size() + storage_numel = obj._size() elif isinstance(obj, torch.UntypedStorage): storage = obj @@ -597,11 +597,11 @@ def _save(obj, zip_file, pickle_module, pickle_protocol): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, this case # can be deleted - storage = obj._storage + storage = obj._untyped_storage storage_dtype = obj.dtype - storage_type_str = obj.pickle_storage_type() + storage_type_str = obj._pickle_storage_type() storage_type = getattr(torch, storage_type_str) - storage_numel = obj.size() + storage_numel = obj._size() else: storage = obj @@ -893,14 +893,15 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): for i in range(num_storages): args = pickle_module.load(f, 
**pickle_load_args) key, location, storage_type = args - dtype = storage_type.dtype + dtype = storage_type._dtype obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype)) obj = restore_location(obj, location) # TODO: Once we decide to break serialization FC, we can # stop wrapping with TypedStorage deserialized_objects[key] = torch.storage.TypedStorage( wrap_storage=obj, - dtype=dtype) + dtype=dtype, + _internal=True) storage_views = pickle_module.load(f, **pickle_load_args) for target_cdata, root_cdata, offset, numel in storage_views: @@ -910,8 +911,9 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): # TODO: Once we decide to break serialization FC, we can # stop wrapping with TypedStorage deserialized_objects[target_cdata] = torch.storage.TypedStorage( - wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size], - dtype=root.dtype) + wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size], + dtype=root.dtype, + _internal=True) tar.extract('tensors', path=tmpdir) with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f: @@ -927,7 +929,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)) storage_offset, = struct.unpack(' self.size()) or (idx < -self.size()): + if (idx > self._size()) or (idx < -self._size()): raise IndexError( f'index {idx} out of range for storage of size {self.size()}') if idx > 0: return idx else: - return idx % self.size() + return idx % self._size() else: - if (idx >= self.size()) or (idx < -self.size()): + if (idx >= self._size()) or (idx < -self._size()): raise IndexError( f'index {idx} out of range for storage of size {self.size()}') - return idx % self.size() + return idx % self._size() def __setitem__(self, idx, value): + _warn_typed_storage_removal() + return self._setitem(idx, value) + + def _setitem(self, idx, value): if not isinstance(idx, (int, slice)): raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") if torch.is_storage(value): @@ -506,16 +536,22 @@ class TypedStorage: torch.qint8: torch.int8 } tmp_dtype = interpret_dtypes[self.dtype] - tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage( - wrap_storage=self._storage, - dtype=tmp_dtype)) + tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self._untyped_storage.device) + tmp_tensor.set_(TypedStorage( + wrap_storage=self._untyped_storage, + dtype=tmp_dtype, + _internal=True)) else: - tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self) + tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self) tmp_tensor[idx] = value def __getitem__(self, idx): - if self.device.type == 'meta': + _warn_typed_storage_removal() + return self._getitem(idx) + + def _getitem(self, idx): + if self._untyped_storage.device.type == 'meta': raise NotImplementedError("Not available for 'meta' device type") # NOTE: Before TypedStorage existed, indexing with a slice used to be @@ -536,21 +572,32 @@ class TypedStorage: torch.qint8: torch.int8 } return TypedStorage( - wrap_storage=self._storage, - dtype=interpret_dtypes[self.dtype])[idx] + wrap_storage=self._untyped_storage, + dtype=interpret_dtypes[self.dtype], + _internal=True)._getitem(idx) idx_wrapped = self._maybe_wrap_index(idx) - tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self) + tmp_tensor = torch.tensor([], dtype=self.dtype, 
device=self._untyped_storage.device).set_(self) return tmp_tensor[idx_wrapped].item() def copy_(self, source: T, non_blocking: bool = None): - self._storage.copy_(source.untyped(), non_blocking) + _warn_typed_storage_removal() + if isinstance(source, TypedStorage): + self._untyped_storage.copy_(source._untyped_storage, non_blocking) + else: + self._untyped_storage.copy_(source, non_blocking) return self def nbytes(self): - return self._storage.nbytes() + _warn_typed_storage_removal() + return self._nbytes() + + # For internal use only, to avoid deprecation warning + def _nbytes(self): + return self._untyped_storage.nbytes() def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]: + _warn_typed_storage_removal() if dtype is None: legacy_class = self._get_legacy_storage_class() @@ -560,21 +607,29 @@ class TypedStorage: return '.'.join([self.__module__, type(self).__name__]) else: - return self._storage.type(dtype, non_blocking) + return self._untyped_storage.type(dtype, non_blocking) def cuda(self, device=None, non_blocking=False, **kwargs) -> T: + _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create CUDA storage with quantized dtype") - cuda_storage: torch.UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs) + cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs) return self._new_wrapped_storage(cuda_storage) def element_size(self): + _warn_typed_storage_removal() + return self._element_size() + + # For internal use only, to avoid deprecation warning + def _element_size(self): return torch._utils._element_size(self.dtype) def get_device(self) -> int: - return self._storage.get_device() + _warn_typed_storage_removal() + return self._untyped_storage.get_device() def __str__(self): + _warn_typed_storage_removal() info_str = ( f'[{torch.typename(self)}(dtype={self.dtype}, ' f'device={self.device}) of size {len(self)}]') @@ -585,35 +640,48 @@ class TypedStorage: return data_str + '\n' + info_str def __repr__(self): + _warn_typed_storage_removal() return str(self) def __iter__(self): + _warn_typed_storage_removal() return iter(map(lambda i: self[i], range(self.size()))) def __copy__(self): - return self._new_wrapped_storage(copy.copy(self._storage)) + _warn_typed_storage_removal() + return self._new_wrapped_storage(copy.copy(self._untyped_storage)) def __deepcopy__(self, memo): - return self._new_wrapped_storage(copy.deepcopy(self._storage, memo)) + _warn_typed_storage_removal() + return self._deepcopy(memo) + + # For internal use only, to avoid deprecation warning + def _deepcopy(self, memo): + return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo)) def __sizeof__(self): + _warn_typed_storage_removal() return super(TypedStorage, self).__sizeof__() + self.nbytes() def clone(self): """Returns a copy of this storage""" - return self._new_wrapped_storage(self._storage.clone()) + _warn_typed_storage_removal() + return self._new_wrapped_storage(self._untyped_storage.clone()) def tolist(self): """Returns a list containing the elements of this storage""" + _warn_typed_storage_removal() return list(self) def cpu(self): """Returns a CPU copy of this storage if it's not already on the CPU""" - return self._new_wrapped_storage(self._storage.cpu()) + _warn_typed_storage_removal() + return self._new_wrapped_storage(self._untyped_storage.cpu()) def pin_memory(self): """Coppies the storage to pinned memory, if 
 it's not already pinned."""
-        return self._new_wrapped_storage(self._storage.pin_memory())
+        _warn_typed_storage_removal()
+        return self._new_wrapped_storage(self._untyped_storage.pin_memory())
 
     def share_memory_(self):
         """Moves the storage to shared memory.
@@ -624,7 +692,12 @@ class TypedStorage:
 
         Returns: self
         """
-        self._storage.share_memory_()
+        _warn_typed_storage_removal()
+        return self._share_memory_()
+
+    # For internal use only, to avoid deprecation warning
+    def _share_memory_(self):
+        self._untyped_storage.share_memory_()
         return self
 
     def _new_shared(self, size, *, device=None):
@@ -632,25 +705,37 @@ class TypedStorage:
         if device is None:
             device = 'cpu'
         device = torch.device(device)
-        untyped_storage = torch.UntypedStorage._new_shared(size * self.element_size(), device=device)
+        untyped_storage = torch.UntypedStorage._new_shared(size * self._element_size(), device=device)
         return TypedStorage(
             wrap_storage=untyped_storage,
-            dtype=self.dtype)
+            dtype=self.dtype,
+            _internal=True)
 
     @property
     def _cdata(self):
-        return self._storage._cdata
+        return self._untyped_storage._cdata
 
     @property
     def device(self):
-        return self._storage.device
+        _warn_typed_storage_removal()
+        return self._untyped_storage.device
 
     def size(self):
+        _warn_typed_storage_removal()
+        return self._size()
+
+    # For internal use only, to avoid deprecation warning
+    def _size(self):
         # NB: don't indirect through __len__, as that requires
         # an int to be returned
-        return self.nbytes() // self.element_size()
+        return self._untyped_storage.nbytes() // self._element_size()
 
     def pickle_storage_type(self):
+        _warn_typed_storage_removal()
+        return self._pickle_storage_type()
+
+    # For internal use only, to avoid deprecation warning
+    def _pickle_storage_type(self):
         try:
             return _dtype_to_storage_type_map()[self.dtype]
         except KeyError:
@@ -662,20 +747,35 @@ class TypedStorage:
         return (_load_from_bytes, (b.getvalue(),))
 
     def data_ptr(self):
-        return self._storage.data_ptr()
+        _warn_typed_storage_removal()
+        return self._data_ptr()
+
+    # For internal use only, to avoid deprecation warning
+    def _data_ptr(self):
+        return self._untyped_storage.data_ptr()
 
     def resize_(self, size):
-        self._storage.resize_(size * self.element_size())
+        _warn_typed_storage_removal()
+        self._resize_(size)
+
+    # For internal use only, to avoid deprecation warning
+    def _resize_(self, size):
+        self._untyped_storage.resize_(size * self._element_size())
 
     @classmethod
     def _free_weak_ref(cls, *args, **kwargs):
         return UntypedStorage._free_weak_ref(*args, **kwargs)
 
     def _weak_ref(self, *args, **kwargs):
-        return self._storage._weak_ref(*args, **kwargs)
+        return self._untyped_storage._weak_ref(*args, **kwargs)
 
     @classmethod
-    def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
+    def from_buffer(cls, *args, **kwargs):
+        _warn_typed_storage_removal()
+        return cls._from_buffer(*args, **kwargs)
+
+    @classmethod
+    def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
         if cls == TypedStorage:
             dtype = torch.get_default_dtype() if dtype is None else dtype
             device = torch.device('cpu' if device is None else device)
@@ -693,65 +793,80 @@ class TypedStorage:
                     "from_buffer: 'device' can only be specified in "
                     "UntypedStorage.from_buffer and TypedStorage.from_buffer"))
 
-            dtype = cls.dtype
+            dtype = cls._dtype
             untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
 
-        return TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
+        return TypedStorage(
+            wrap_storage=untyped_storage,
+            dtype=dtype,
+            _internal=True)
 
     def _to(self, dtype):
         if not isinstance(dtype, torch.dtype):
             raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
-        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
+        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype)._typed_storage()
         if storage.data_ptr() == self.data_ptr():
             storage = storage.clone()
         return storage
 
     def double(self):
         """Casts this storage to double type"""
+        _warn_typed_storage_removal()
         return self._to(torch.double)
 
     def float(self):
         """Casts this storage to float type"""
+        _warn_typed_storage_removal()
         return self._to(torch.float)
 
     def half(self):
         """Casts this storage to half type"""
+        _warn_typed_storage_removal()
         return self._to(torch.half)
 
     def long(self):
         """Casts this storage to long type"""
+        _warn_typed_storage_removal()
         return self._to(torch.long)
 
     def int(self):
         """Casts this storage to int type"""
+        _warn_typed_storage_removal()
         return self._to(torch.int)
 
     def short(self):
         """Casts this storage to short type"""
+        _warn_typed_storage_removal()
         return self._to(torch.short)
 
     def char(self):
         """Casts this storage to char type"""
+        _warn_typed_storage_removal()
         return self._to(torch.int8)
 
     def byte(self):
         """Casts this storage to byte type"""
+        _warn_typed_storage_removal()
         return self._to(torch.uint8)
 
     def bool(self):
         """Casts this storage to bool type"""
+        _warn_typed_storage_removal()
         return self._to(torch.bool)
 
     def bfloat16(self):
         """Casts this storage to bfloat16 type"""
+        _warn_typed_storage_removal()
        return self._to(torch.bfloat16)
 
     def complex_double(self):
         """Casts this storage to complex double type"""
+        _warn_typed_storage_removal()
         return self._to(torch.cdouble)
 
     def complex_float(self):
         """Casts this storage to complex float type"""
+        _warn_typed_storage_removal()
         return self._to(torch.cfloat)
 
     @classmethod
@@ -773,6 +888,7 @@ class TypedStorage:
             shared (bool): whether to share memory
             size (int): number of elements in the storage
         """
+        _warn_typed_storage_removal()
         if cls == TypedStorage:
             raise RuntimeError('from_file can only be called on derived classes')
         untyped_storage: UntypedStorage = UntypedStorage.from_file(
@@ -787,33 +903,39 @@ class TypedStorage:
         return UntypedStorage._expired(*args, **kwargs)
 
     def is_pinned(self):
-        return self._storage.is_pinned()
+        _warn_typed_storage_removal()
+        return self._untyped_storage.is_pinned()
 
     def _write_file(self, *args, **kwargs):
-        return self._storage._write_file(*args, **kwargs)
+        return self._untyped_storage._write_file(*args, **kwargs)
 
     def _set_from_file(self, *args, **kwargs):
-        return self._storage._set_from_file(*args, **kwargs)
+        return self._untyped_storage._set_from_file(*args, **kwargs)
 
     def _set_cdata(self, *args, **kwargs):
-        return self._storage._set_cdata(*args, **kwargs)
+        return self._untyped_storage._set_cdata(*args, **kwargs)
 
     def _share_cuda_(self, *args, **kwargs):
-        return self._storage._share_cuda_(*args, **kwargs)
+        return self._untyped_storage._share_cuda_(*args, **kwargs)
 
     def is_shared(self):
-        return self._storage.is_shared()
+        _warn_typed_storage_removal()
+        return self._is_shared()
+
+    # For internal use only, to avoid deprecation warning
+    def _is_shared(self):
+        return self._untyped_storage.is_shared()
 
     @classmethod
     def _new_shared_cuda(cls, *args, **kwargs):
         return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
 
     def _share_filename_cpu_(self, *args, **kwargs):
-        manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
-        return manager_handle, storage_handle, size // self.element_size()
+        manager_handle, storage_handle, size = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
+        return manager_handle, storage_handle, size // self._element_size()
 
     def _shared_decref(self):
-        self._storage._shared_decref()
+        self._untyped_storage._shared_decref()
         return self
 
     @classmethod
@@ -821,11 +943,11 @@ class TypedStorage:
         return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
 
     def _shared_incref(self, *args, **kwargs):
-        return self._storage._shared_incref(*args, **kwargs)
+        return self._untyped_storage._shared_incref(*args, **kwargs)
 
     def _share_fd_cpu_(self, *args, **kwargs):
-        fd, size = self._storage._share_fd_cpu_(*args, **kwargs)
-        return fd, size // self.element_size()
+        fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
+        return fd, size // self._element_size()
 
     def _get_legacy_storage_class(self):
         if self.dtype not in _dtype_to_storage_type_map():
@@ -859,7 +981,7 @@ class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
     @classmethod
     def _new_shared(cls, size):
         """Creates a new storage in shared memory with the same data type"""
-        untyped_storage = torch.UntypedStorage._new_shared(size * cls().element_size())
+        untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size())
         return cls(wrap_storage=untyped_storage)
 
     @classmethod
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index d15cae4b1bb5..6999986f5294 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -927,9 +927,34 @@ def originate_pairs(
     Returns:
         (List[Pair]): Originated pairs.
     """
+    if (
+        isinstance(actual, torch.TypedStorage)
+        and isinstance(expected, torch.TypedStorage)
+    ):
+        actual_len = actual._size()
+        expected_len = expected._size()
+        if actual_len != expected_len:
+            raise ErrorMeta(
+                AssertionError, f"The length of the sequences mismatch: {actual_len} != {expected_len}", id=id
+            )
+
+        pairs = []
+        for idx in range(actual_len):
+            pairs.extend(
+                originate_pairs(
+                    actual._getitem(idx),
+                    expected._getitem(idx),
+                    pair_types=pair_types,
+                    sequence_types=sequence_types,
+                    mapping_types=mapping_types,
+                    id=(*id, idx),
+                    **options,
+                )
+            )
+        return pairs
     # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
     # "a" == "a"[0][0]...
-    if (
+    elif (
         isinstance(actual, sequence_types)
         and not isinstance(actual, str)
         and isinstance(expected, sequence_types)
diff --git a/torch/testing/_internal/schema_check_mode.py b/torch/testing/_internal/schema_check_mode.py
index 9d118719af6b..9fda9d95e159 100644
--- a/torch/testing/_internal/schema_check_mode.py
+++ b/torch/testing/_internal/schema_check_mode.py
@@ -47,7 +47,7 @@ class SchemaCheckMode(TorchDispatchMode):
                 before.size() == after.size() and
                 torch.allclose(before, after, equal_nan=True) and
                 md[0] == after.stride() and
-                md[1] == after.storage()._cdata
+                md[1] == after._typed_storage()._cdata
             )
         return False
 
@@ -76,12 +76,12 @@ class SchemaCheckMode(TorchDispatchMode):
             if not type(e) == torch.Tensor:
                 try:
                     current = e.elem
-                    return (deepcopy(current.stride()), current.storage()._cdata)
+                    return (deepcopy(current.stride()), current._typed_storage()._cdata)
                 except AttributeError as t:
                     return None
             # Sparse CSR tensors do not have strides or storage
             elif (e.layout != torch.sparse_csr):
-                return (deepcopy(e.stride()), e.storage()._cdata)
+                return (deepcopy(e.stride()), e._typed_storage()._cdata)
             return None
 
         self.ops.append(func._schema.name)
diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py
index 1ca2d56616bc..4ae39733ff2e 100644
--- a/torch/utils/bundled_inputs.py
+++ b/torch/utils/bundled_inputs.py
@@ -391,7 +391,7 @@ def _inflate_expr(
     if isinstance(arg, torch.Tensor):
         # Small-storage tensors can just be saved directly.
-        if arg.storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
+        if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
             return arg, ref, None
         # Small contiguous tensors can be cloned to have small storage.
         # TODO: Should we do this even for non-contiguous tensors?
@@ -407,7 +407,7 @@ def _inflate_expr(
             # TODO: Provide more useful diagnostics.
             raise Exception(
                 f"Bundled input argument at position '{ref}' is "
-                f"a tensor with storage size {arg.storage().size()}. "
+                f"a tensor with storage size {arg._typed_storage().size()}. "
                 f"You probably don't want to bundle this as an input. "
             )
     else:
diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py
index 0ba9f25c2c9d..1a00cd4514f5 100644
--- a/torch/utils/data/_utils/collate.py
+++ b/torch/utils/data/_utils/collate.py
@@ -158,7 +158,7 @@ def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[
         # If we're in a background process, concatenate directly into a
         # shared memory tensor to avoid an extra copy
         numel = sum(x.numel() for x in batch)
-        storage = elem.storage()._new_shared(numel, device=elem.device)
+        storage = elem._typed_storage()._new_shared(numel, device=elem.device)
         out = elem.new(storage).resize_(len(batch), *list(elem.size()))
         return torch.stack(batch, 0, out=out)
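
For readers following the torch/storage.py hunks above: the changes apply one warn-then-delegate pattern throughout. Each public TypedStorage method now calls _warn_typed_storage_removal() and then forwards to an underscore-prefixed twin (size -> _size, resize_ -> _resize_, and so on), so internal call sites can use the private variants without emitting the deprecation warning. The sketch below is illustrative only and not code from this patch; the helper _warn_storage_removal and the class _DemoStorage are hypothetical stand-ins for the real torch internals.

import warnings


def _warn_storage_removal():
    # Hypothetical stand-in for the warning helper used above.
    # stacklevel=3 points the warning at the user's call site, not at this helper.
    warnings.warn(
        "TypedStorage is deprecated and will be removed in the future; "
        "UntypedStorage will be used in all cases.",
        UserWarning,
        stacklevel=3,
    )


class _DemoStorage:
    """Toy storage wrapper demonstrating the warn-then-delegate pattern."""

    def __init__(self, nbytes, element_size):
        self._nbytes = nbytes
        self._elem_size = element_size

    # Public API: emit the deprecation warning, then delegate.
    def size(self):
        _warn_storage_removal()
        return self._size()

    # For internal use only, to avoid the deprecation warning.
    def _size(self):
        return self._nbytes // self._elem_size


s = _DemoStorage(nbytes=40, element_size=4)
print(s.size())    # warns, prints 10
print(s._size())   # internal path, no warning, prints 10

Keeping a parallel set of private methods, rather than filtering the warning, leaves internal code paths warning-free without touching the global warnings state.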