Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)

Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
This commit is contained in:
Kurt Mohler
2022-11-08 18:11:01 +00:00
committed by PyTorch MergeBot
parent 53ca5ad347
commit ee28b865ee
37 changed files with 631 additions and 176 deletions

View File

@ -22,6 +22,10 @@ holds the data as an untyped array of bytes.
Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.
.. warning::
All storage classes except for :class:`torch.UntypedStorage` will be removed
in the future, and :class:`torch.UntypedStorage` will be used in all cases.
.. autoclass:: torch.TypedStorage
:members:
:undoc-members:

View File

@ -6805,8 +6805,8 @@ for shape in [(1,), ()]:
with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
a = torch.ones(5, requires_grad=True)
warnings.simplefilter('always')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
y = a * a
# should raise two warnings from a being saved twice
self.assertEqual(len(w), 2)

View File

@ -595,7 +595,7 @@ class TestCuda(TestCase):
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

View File

@ -6470,6 +6470,127 @@ class TestTorch(TestCase):
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
self.assertIs(complexdouble_storage.dtype, torch.complex128)
# Test that internal versions of functions related to TypedStorage do not
# produce a deprecation warning
def test_typed_storage_internal_no_warning(self):
s0 = torch.FloatStorage(10)
s0_untyped = s0.untyped()
t0 = torch.randn(10)
funcs = [
lambda: torch.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cpu',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s0_untyped,
dtype=s0.dtype,
_internal=True),
lambda: torch.FloatStorage._dtype,
lambda: s0._resize_(20),
lambda: s0._size(),
lambda: s0._untyped_storage,
lambda: s0._is_shared(),
lambda: s0._share_memory_(),
lambda: s0._pickle_storage_type(),
lambda: s0._setitem(slice(0, s0._size()), 1),
lambda: s0._element_size(),
lambda: s0._deepcopy({}),
lambda: s0._data_ptr(),
lambda: s0._nbytes(),
lambda: t0._typed_storage(),
]
if torch.cuda.is_available():
s1 = torch.cuda.FloatStorage(10)
s1_untyped = s1.untyped()
t1 = torch.randn(10, device='cuda')
funcs += [
lambda: torch.cuda.FloatStorage(_internal=True),
lambda: torch.TypedStorage(
dtype=torch.float,
device='cuda',
_internal=True),
lambda: torch.TypedStorage(
wrap_storage=s1_untyped,
dtype=s1.dtype,
_internal=True),
lambda: torch.cuda.FloatStorage._dtype,
lambda: s1._resize_(20),
lambda: s1._size(),
lambda: s1._untyped_storage,
lambda: s1._is_shared(),
lambda: s1._share_memory_(),
lambda: s1._pickle_storage_type(),
lambda: s1._setitem(slice(0, s1._size()), 1),
lambda: s1._element_size(),
lambda: s1._deepcopy({}),
lambda: s1._data_ptr(),
lambda: s1._nbytes(),
lambda: t1._typed_storage(),
]
# Check that each of the TypedStorage internal function calls do not
# produce a deprecation warning
for f in funcs:
with warnings.catch_warnings():
warnings.filterwarnings('error', "TypedStorage is deprecated")
f()
# Test that public functions related to TypedStorage produce a deprecation
# warning
def test_typed_storage_deprecation_warning(self):
s0 = torch.FloatStorage(10)
funcs = [
lambda: torch.FloatStorage(),
lambda: torch.FloatStorage.dtype,
lambda: s0.fill_(0),
lambda: s0.is_cuda,
lambda: s0.untyped(),
lambda: len(s0),
lambda: s0[0],
]
if torch.cuda.is_available():
s1 = torch.cuda.FloatStorage(10)
funcs += [
lambda: torch.cuda.FloatStorage(),
lambda: torch.cuda.FloatStorage.dtype,
lambda: s1.fill_(0),
lambda: s1.is_cuda,
lambda: s1.untyped(),
lambda: len(s1),
lambda: s1[0],
]
# Check that each of the TypedStorage function calls produce a warning
# if warnings are reset between each
for f in funcs:
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
f()
self.assertEqual(len(w), 1)
warning = w[0].message
self.assertTrue(warning, DeprecationWarning)
self.assertTrue(re.search(
'^TypedStorage is deprecated',
str(warning)))
# Check that only one warning is raised from calling multiple
# TypedStorage functions if warnings are not reset between each
with warnings.catch_warnings(record=True) as w:
warnings.resetwarnings()
for f in funcs:
f()
self.assertEqual(len(w), 1)
warning = w[0].message
self.assertTrue(warning, DeprecationWarning)
self.assertTrue(re.search(
'^TypedStorage is deprecated',
str(warning)))
def test_from_file(self):
def assert_with_filename(filename):
size = 10000

View File

@ -102,7 +102,7 @@ class TestViewOps(TestCase):
# Note: only validates storage on native device types
# because some accelerators, like XLA, do not expose storage
if base.device.type == 'cpu' or base.device.type == 'cuda':
if base.storage().data_ptr() != other.storage().data_ptr():
if base._storage().data_ptr() != other._storage().data_ptr():
return False
return True

View File

@ -979,7 +979,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "storage");
return handle_torch_function(self, "_storage");
}
auto& self_ = THPVariable_Unpack(self);
return createPyObject(self_.storage());

View File

@ -709,7 +709,7 @@ __all__.extend(['e', 'pi', 'nan', 'inf'])
################################################################################
from ._tensor import Tensor
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal
# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.
@ -717,86 +717,171 @@ from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.uint8
class DoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.double
class FloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.float
class HalfStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.half
class LongStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.long
class IntStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.int
class ShortStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.short
class CharStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.int8
class BoolStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.bool
class BFloat16Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.bfloat16
class ComplexDoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.cdouble
class ComplexFloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.cfloat
class QUInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.quint8
class QInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.qint8
class QInt32Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.qint32
class QUInt4x2Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.quint4x2
class QUInt2x4Storage(_LegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.quint2x4
_storage_classes = {

View File

@ -23,7 +23,7 @@ def _save_storages(importer, obj):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
storage = obj._untyped_storage
dtype = obj.dtype
else:
storage = obj

View File

@ -27,7 +27,7 @@ class ShapeAliasingAndMutationProp(ShapeProp):
def tensor_alias_group(self, value: torch.Tensor):
"""Assign a unique identifier to the storage of a given tensor"""
storage = StorageWeakRef(value.storage())
storage = StorageWeakRef(value._typed_storage())
alias_group = self.storage_to_alias_group.get(storage)
if alias_group is None:
alias_group = next(self.make_alias_group)

View File

@ -157,7 +157,7 @@ class DDPOptimizer:
for name, p in target.named_parameters():
param = target.get_parameter(name)
if p.requires_grad and not self._ignore_parameter(param):
buckets[0].size += p.storage().nbytes()
buckets[0].size += p._storage().nbytes()
buckets[0].params.append(f"{node.target}_{name}")
buckets[0].param_ids.append(id(param))
elif node.op == "get_attr":
@ -165,7 +165,7 @@ class DDPOptimizer:
if maybe_param.requires_grad and not self._ignore_parameter(
maybe_param
):
buckets[0].size += maybe_param.storage().nbytes()
buckets[0].size += maybe_param._storage().nbytes()
buckets[0].params.append(node.target)
buckets[0].param_ids.append(id(maybe_param))

View File

@ -381,7 +381,7 @@ def find_input_mutations(g):
mutated_inputs = set()
for n in g.nodes:
if n.op == "placeholder":
inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx)
inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
input_idx += 1
elif n.op == "call_function":
if n.target is operator.getitem:
@ -402,7 +402,7 @@ def find_input_mutations(g):
# TODO: not correct for args that contain tensors in a struct
# like list
mutated_inputs |= inputs[
StorageWeakRef(meta_fk(argument.meta).storage())
StorageWeakRef(meta_fk(argument.meta)._typed_storage())
]
# TODO: error on unrecognized nodes
return mutated_inputs

View File

@ -1158,7 +1158,9 @@ def _as_strided_meta(
# as_strided to shapes with no elements are trivially valid, so it's OK
pass
elif isinstance(a, torch.Tensor):
utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset)
utils.check_in_bounds_for_storage(
a._typed_storage(), size, stride, storage_offset
)
return TensorMeta(a, shape=size, strides=stride)

View File

@ -156,7 +156,7 @@ class FakeTensorConverter(object):
# const_tensor.add_(torch.rand([1]))
# all aliases of it must become no longer const
assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
weak_st = StorageWeakRef(fake_tensor.constant.storage())
weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())
# we need a map from a weak storage to all of its corresponding
# constant tensors. python doesn't have the weak value equivalent
@ -168,7 +168,7 @@ class FakeTensorConverter(object):
def invalidate_constant_aliases(self, tensor):
assert not isinstance(tensor, FakeTensor)
weak_st = StorageWeakRef(tensor.storage())
weak_st = StorageWeakRef(tensor._typed_storage())
if weak_st not in self.constant_storage_mapping:
return
@ -1043,7 +1043,7 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce
for e in tree_flatten((args, kwargs))[0]:
if isinstance(e, torch.Tensor):
if not e.is_sparse:
storages.add(e.storage()._cdata)
storages.add(e._typed_storage()._cdata)
# TODO: also check metadata change on inputs
# proper aliasing/metadata relationship between outputs and inputs will
@ -1053,7 +1053,7 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce
if id(e) not in inp_impls and (
isinstance(e, torch.Tensor)
and not e.is_sparse
and e.storage()._cdata in storages
and e._typed_storage()._cdata in storages
):
raise orig_not_implemented_exception

View File

@ -18,12 +18,12 @@ aten = torch.ops.aten
def outputs_alias_inputs(outputs, inputs):
input_storages = {
inp.storage()._cdata
inp._typed_storage()._cdata
for inp in tree_flatten_only(torch.Tensor, inputs)
if torch._C._has_storage(inp)
}
return any(
torch._C._has_storage(out) and out.storage()._cdata in input_storages
torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
for out in tree_flatten_only(torch.Tensor, outputs)
)
@ -38,7 +38,7 @@ def output_alias_each_other(outputs):
for out in tree_flatten_only(torch.Tensor, outputs):
if not torch._C._has_storage(out):
continue
stor = out.storage()._cdata
stor = out._typed_storage()._cdata
if stor in storages:
return True
storages.add(stor)

View File

@ -143,7 +143,7 @@ class MetaConverter:
if t.is_sparse:
weak_st = None
else:
weak_st = StorageWeakRef(t.storage())
weak_st = StorageWeakRef(t._typed_storage())
tensor_ref_key = WeakTensorRefKey(t)
def del_ten():
@ -179,13 +179,9 @@ class MetaConverter:
# Use a Weak Ref to s in order to not leak memory
swr = StorageWeakRef(s)
if swr not in self.storage_memo:
self.storage_memo[swr] = (
callback(
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
)
.storage()
.untyped()
)
self.storage_memo[swr] = callback(
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
)._storage()
return self.storage_memo[swr]
# This function assumes that it's possible to do the conversion
@ -362,7 +358,7 @@ class MetaConverter:
# format here
r = r.clone(memory_format=torch.preserve_format)
s = t.storage().untyped()
s = t._storage()
swr = StorageWeakRef(s)
if (
swr not in self.storage_memo
@ -370,7 +366,7 @@ class MetaConverter:
and r.storage_offset() == storage_offset
):
# You're normal and happy, install the fresh storage into the memo
self.storage_memo[swr] = r.storage().untyped()
self.storage_memo[swr] = r._storage()
else:
# You're in crazy town; somehow you gave us a tensor
# that wasn't a view, but had nonzero storage offset,

View File

@ -132,7 +132,7 @@ class Tensor(torch._C._TensorBase):
"different type."
)
else:
new_storage = self.storage().__deepcopy__(memo)
new_storage = self._typed_storage()._deepcopy(memo)
if self.is_quantized:
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[
@ -163,7 +163,9 @@ class Tensor(torch._C._TensorBase):
# need to wrap with TypedStorage
new_tensor = torch._utils._rebuild_qtensor(
torch.storage.TypedStorage(
wrap_storage=new_storage.untyped(), dtype=self.dtype
wrap_storage=new_storage._untyped_storage,
dtype=self.dtype,
_internal=True,
),
self.storage_offset(),
self.size(),
@ -257,7 +259,17 @@ class Tensor(torch._C._TensorBase):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.storage, (self,), self)
return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
torch.storage._warn_typed_storage_removal()
return self._typed_storage()
# For internal use only, to avoid raising deprecation warning
def _typed_storage(self):
_storage = self._storage()
if isinstance(_storage, torch.TypedStorage):
_storage = _storage._untyped_storage
return torch.TypedStorage(
wrap_storage=_storage, dtype=self.dtype, _internal=True
)
def _reduce_ex_internal(self, proto):
check_serializing_named_tensor(self)
@ -331,7 +343,9 @@ class Tensor(torch._C._TensorBase):
# need to wrap with TypedStorage
args_qtensor = (
torch.storage.TypedStorage(
wrap_storage=self.storage().untyped(), dtype=self.dtype
wrap_storage=self._typed_storage()._untyped_storage,
dtype=self.dtype,
_internal=True,
),
self.storage_offset(),
tuple(self.size()),
@ -389,7 +403,9 @@ class Tensor(torch._C._TensorBase):
# need to wrap with TypedStorage
args = (
torch.storage.TypedStorage(
wrap_storage=self.storage().untyped(), dtype=self.dtype
wrap_storage=self._typed_storage()._untyped_storage,
dtype=self.dtype,
_internal=True,
),
self.storage_offset(),
tuple(self.size()),
@ -607,7 +623,7 @@ class Tensor(torch._C._TensorBase):
"""
if has_torch_function_unary(self):
return handle_torch_function(Tensor.is_shared, (self,), self)
return self.storage().is_shared()
return self._typed_storage()._is_shared()
def share_memory_(self):
r"""Moves the underlying storage to shared memory.
@ -617,7 +633,7 @@ class Tensor(torch._C._TensorBase):
"""
if has_torch_function_unary(self):
return handle_torch_function(Tensor.share_memory_, (self,), self)
self.storage().share_memory_()
self._typed_storage()._share_memory_()
return self
def __reversed__(self):
@ -1059,7 +1075,9 @@ class Tensor(torch._C._TensorBase):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.storage_type, (self,), self)
return self.storage()._get_legacy_storage_class()
torch.storage._warn_typed_storage_removal()
return self._typed_storage()._get_legacy_storage_class()
def refine_names(self, *names):
r"""Refines the dimension names of :attr:`self` according to :attr:`names`.

View File

@ -143,8 +143,8 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
# first construct a tensor with the correct dtype/device
t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
return t.set_(storage.untyped(), storage_offset, size, stride)
t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device)
return t.set_(storage._untyped_storage, storage_offset, size, stride)
def _rebuild_tensor_v2(

View File

@ -135,7 +135,7 @@ at::Storage createStorageGetType(
TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
untyped_storage_obj = PyObject_GetAttrString(obj, "_storage");
untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
TORCH_INTERNAL_ASSERT(untyped_storage_obj);
Py_DECREF(untyped_storage_obj);

View File

@ -737,11 +737,12 @@ class _CudaBase(object):
__new__ = _lazy_new
from torch.storage import _LegacyStorage
from torch.storage import _LegacyStorage, _warn_typed_storage_removal
class _CudaLegacyStorage(_LegacyStorage):
@classmethod
def from_buffer(cls, *args, **kwargs):
_warn_typed_storage_removal()
raise RuntimeError('from_buffer: Not available for CUDA storage')
@classmethod
@ -755,61 +756,121 @@ class _CudaLegacyStorage(_LegacyStorage):
class ByteStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.uint8
class DoubleStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.double
class FloatStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.float
class HalfStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.half
class LongStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.long
class IntStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.int
class ShortStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.short
class CharStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.int8
class BoolStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.bool
class BFloat16Storage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.bfloat16
class ComplexDoubleStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.cdouble
class ComplexFloatStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.cfloat
del _LegacyStorage

View File

@ -89,7 +89,7 @@ def find_input_mutations(g):
mutated_inputs = set()
for n in g.nodes:
if n.op == 'placeholder':
inputs[StorageWeakRef(n.meta[FK].storage())].add(input_idx)
inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx)
input_idx += 1
elif n.op == 'call_function':
if n.target is operator.getitem:
@ -109,7 +109,7 @@ def find_input_mutations(g):
if mut_arg:
# TODO: not correct for args that contain tensors in a struct
# like list
mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK].storage())]
mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())]
# TODO: error on unrecognized nodes
return mutated_inputs

View File

@ -51,7 +51,7 @@ DEFAULT_SUFIX = ".distcp"
def _trim(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach().cpu()
if tensor.storage().size() != tensor.numel():
if tensor._typed_storage()._size() != tensor.numel():
tensor = tensor.clone()
return tensor

View File

@ -1896,7 +1896,7 @@ def all_gather_multigpu(
def _object_to_tensor(obj, device):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will casue 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696

View File

@ -69,14 +69,14 @@ def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
bool: ``True`` if this method allocated storage and ``False`` if the
storage was already allocated.
"""
already_allocated = tensor.storage().size() == size.numel()
already_allocated = tensor._typed_storage()._size() == size.numel()
if not already_allocated:
tensor_storage_size = tensor.storage().size()
tensor_storage_size = tensor._typed_storage()._size()
p_assert(
tensor_storage_size == 0,
f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
)
tensor.storage().resize_(size.numel())
tensor._typed_storage()._resize_(size.numel())
return not already_allocated
@ -89,23 +89,23 @@ def _free_storage(tensor: torch.Tensor) -> bool:
bool: ``True`` if the method freed the storage and ``False`` if the
storage was already freed.
"""
already_freed = tensor.storage().size() == 0
already_freed = tensor._typed_storage()._size() == 0
if not already_freed:
p_assert(
tensor.storage_offset() == 0,
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
f"storage offset: {tensor.storage_offset()}\n"
f"storage size: {tensor.storage().size()}\n"
f"storage size: {tensor._typed_storage()._size()}\n"
f"tensor shape: {tensor.shape}",
)
tensor.storage().resize_(0)
tensor._typed_storage()._resize_(0)
return not already_freed
def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
"""Returns if ``x`` and ``y`` share the same storage."""
# NOTE: CPU and GPU tensors are ensured to have different data pointers.
return x.storage().data_ptr() == y.storage().data_ptr()
return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:

View File

@ -493,7 +493,7 @@ class FlatParamHandle:
flat_param.storage_offset() == 0,
"The `FlatParameter` is not the sole occupant of its storage",
)
orig_storage = flat_param.storage()
orig_storage = flat_param._typed_storage()
sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
flat_param, self.rank, self.world_size
)
@ -501,8 +501,8 @@ class FlatParamHandle:
start = sharded_flat_param.numel() * self.rank
end = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive
self._init_shard_metadata(numel_padded, start, end)
if orig_storage.size() > 0:
orig_storage.resize_(0)
if orig_storage._size() > 0:
orig_storage._resize_(0)
if self._use_orig_params:
self._use_sharded_views()
@ -838,7 +838,7 @@ class FlatParamHandle:
return False
unsharded_flat_param = self._get_padded_unsharded_flat_param()
already_unsharded = (
unsharded_flat_param.storage().size() == unsharded_flat_param.numel()
unsharded_flat_param._typed_storage()._size() == unsharded_flat_param.numel()
)
return not already_unsharded
@ -1141,9 +1141,9 @@ class FlatParamHandle:
# the padded unsharded flattened parameter as expected
# NOTE: This check is not strictly needed for correctness but is a
# useful sanity check since the tensor should only be used internally.
unpadded_storage_ptr = self.flat_param.storage().data_ptr()
unpadded_storage_ptr = self.flat_param._typed_storage()._data_ptr()
padded_storage_ptr = (
self._get_padded_unsharded_flat_param().storage().data_ptr()
self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
)
p_assert(
unpadded_storage_ptr == padded_storage_ptr,
@ -1824,7 +1824,7 @@ class FlatParamHandle:
@staticmethod
def _check_storage_freed(tensor: Tensor):
storage_size: int = tensor.storage().size()
storage_size: int = tensor._typed_storage()._size()
p_assert(
storage_size == 0,
f"Expects storage to be freed but got storage with size {storage_size}",
@ -1832,7 +1832,7 @@ class FlatParamHandle:
@staticmethod
def _check_storage_allocated(tensor: Tensor):
storage_size: int = tensor.storage().size()
storage_size: int = tensor._typed_storage()._size()
p_assert(storage_size > 0, "Expects storage to be allocated")
def _check_low_precision_shard(self):

View File

@ -107,7 +107,7 @@ def profile_sizes(
latent_size = memory_after - memory_before
# Analyze size of parameters.
param_size = sum(p.storage().nbytes() for p in layer.parameters())
param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters())
# Combine size of parameters and activations with normalize scales.
size = latent_size * latent_scale + param_size * param_scale

View File

@ -104,7 +104,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
#
# Issue: https://github.com/pytorch/pytorch/issues/27366
#
tensor = tensor.new_empty([0]).set_(tensor.storage())
tensor = tensor.new_empty([0]).set_(tensor._typed_storage())
# Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type]

View File

@ -100,8 +100,8 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
# Assert here that this is actually the case, and their storages are the same.
assert isinstance(node.meta['fake_result'], FakeTensor)
assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
view_storage = StorageWeakRef(node.meta['fake_result'].storage())
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage())
view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
assert view_storage == base_storage
return result
@ -176,7 +176,7 @@ _VIEW_INVERSE_MAP = {
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
def _add_if_tensor(x, set_):
if isinstance(x, FakeTensor):
set_.add(StorageWeakRef(x.storage()))
set_.add(StorageWeakRef(x._typed_storage()))
nodes_used_after = set()
for t in tensor_aliases:
@ -452,7 +452,7 @@ def reinplace(gm, *sample_args):
# Useful debug printing
# def _print(x):
# if isinstance(x, FakeTensor):
# print(f'fake_result: {StorageWeakRef(x.storage()).cdata}')
# print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
# for n in gm.graph.nodes:
# print(n.format_node())
@ -468,7 +468,10 @@ def reinplace(gm, *sample_args):
# so we know not to re-inplace them.
# NOTE: later, we'll need to add an optimization for fully recovering performance
# on programs that mutate inputs.
input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder')
input_storages = set(
StorageWeakRef(
node.meta['fake_result']._typed_storage()
) for node in gm.graph.nodes if node.op == 'placeholder')
# We also need to know for a given node, what are all of its aliasing nodes.
@ -478,7 +481,7 @@ def reinplace(gm, *sample_args):
# Tree-mapping because some ops can return lists of tensors.
def _add_to_map(x):
if isinstance(x, FakeTensor):
storage_to_nodes[StorageWeakRef(x.storage())].add(n)
storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
tree_map(_add_to_map, n.meta['fake_result'])
# inplace-ify functional ops, subject to the constraints written below.
@ -529,7 +532,7 @@ def reinplace(gm, *sample_args):
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
self_arg_name = self_arg.name
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
continue
@ -539,7 +542,7 @@ def reinplace(gm, *sample_args):
# so we prevent re-inplacing in this case.
continue
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
self_aliases = storage_to_nodes[self_arg_storage]
# First, we find all later usages of any of the aliases of self_arg.
@ -594,7 +597,7 @@ def reinplace(gm, *sample_args):
# Hmm... morally I think we also want to keep the `fake_result` metadata
# up to date here, but I'm not sure how easy it is to do.
# Maybe it's fine to wait until the end of the pass to update it.
curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage())
curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
@ -624,8 +627,14 @@ def reinplace(gm, *sample_args):
old_flattened_res, _ = tree_flatten(old.meta['fake_result'])
node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result'])
old_res_storage = set(StorageWeakRef(x.storage()) for x in old_flattened_res if isinstance(x, FakeTensor))
node_res_storage = set(StorageWeakRef(x.storage()) for x in node_flattened_res if isinstance(x, FakeTensor))
old_res_storage = set(
StorageWeakRef(
x._typed_storage()
) for x in old_flattened_res if isinstance(x, FakeTensor))
node_res_storage = set(
StorageWeakRef(
x._typed_storage()
) for x in node_flattened_res if isinstance(x, FakeTensor))
# This will happen if we're updating a view op, e.g.
# e.g. replacing
@ -639,7 +648,10 @@ def reinplace(gm, *sample_args):
# We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
new_flattened_res, _ = tree_flatten(new.meta['fake_result'])
new_res_storage = set(StorageWeakRef(x.storage()) for x in new_flattened_res if isinstance(x, FakeTensor))
new_res_storage = set(
StorageWeakRef(
x._typed_storage()
) for x in new_flattened_res if isinstance(x, FakeTensor))
assert len(new_res_storage) == 1
(old_ref,) = old_res_storage
(new_ref,) = new_res_storage

View File

@ -113,7 +113,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required):
# If storage_handle is None, storage points to nullptr.
if storage_handle is None or storage_size_bytes == 0:
storage = storage_cls(0, dtype=dtype, device=storage_device)
storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
else:
storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes))
if storage is None:
@ -132,8 +132,10 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
# We already ref counting this Storage, but producer needs new ref-counters to be released.
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
_storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage
t = torch._utils._rebuild_tensor(
torch.storage.TypedStorage(wrap_storage=storage.untyped(), dtype=dtype),
torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
tensor_offset, tensor_size, tensor_stride)
if tensor_cls == torch.nn.parameter.Parameter:
@ -147,7 +149,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
def reduce_tensor(tensor):
storage = tensor.storage()
storage = tensor._typed_storage()
if tensor.requires_grad and not tensor.is_leaf:
raise RuntimeError("Cowardly refusing to serialize non-leaf tensor which requires_grad, "
@ -248,7 +250,7 @@ def reduce_tensor(tensor):
# eliminated it so that we could just use tensor views to implement the same
# thing.
#
if storage.is_cuda:
if storage._untyped_storage.device.type == 'cuda':
(device,
handle,
storage_size_bytes,
@ -325,7 +327,8 @@ def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
untyped_storage: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
storage = torch.TypedStorage(
wrap_storage=untyped_storage,
dtype=dtype)
dtype=dtype,
_internal=True)
shared_cache[handle] = StorageWeakRef(storage)
return storage._shared_decref()
@ -334,18 +337,18 @@ def rebuild_storage_empty(cls):
return cls()
def rebuild_typed_storage(storage, dtype):
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype)
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)
# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
return (rebuild_typed_storage, (storage._storage, storage.dtype))
return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))
def rebuild_typed_storage_child(storage, storage_type):
return storage_type(wrap_storage=storage)
return storage_type(wrap_storage=storage, _internal=True)
# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
return (rebuild_typed_storage_child, (storage._storage, type(storage)))
return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))
def reduce_storage(storage):
from . import get_sharing_strategy

View File

@ -273,6 +273,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor.to_sparse_csc,
Tensor.to_sparse_bsr,
Tensor.to_sparse_bsc,
Tensor._typed_storage,
Tensor._reduce_ex_internal,
Tensor._fix_weakref,
Tensor._make_wrapper_subclass,

View File

@ -887,7 +887,7 @@ class PackageExporter:
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
untyped_storage = obj._storage
untyped_storage = obj._untyped_storage
storage_type_str = obj.pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj.size()

View File

@ -208,14 +208,14 @@ class PackageImporter(Importer):
name = f"{key}.storage"
if storage_context.has_storage(name):
storage = storage_context.get_storage(name, dtype).storage()
storage = storage_context.get_storage(name, dtype)._typed_storage()
else:
tensor = self.zip_reader.get_storage_from_record(
".data/" + name, size, dtype
)
if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
storage_context.add_storage(name, tensor)
storage = tensor.storage()
storage = tensor._typed_storage()
loaded_storages[key] = restore_location(storage, location)
def persistent_load(saved_id):
@ -239,7 +239,7 @@ class PackageImporter(Importer):
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
return torch.storage.TypedStorage(
wrap_storage=storage.untyped(), dtype=dtype
wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
)
elif typename == "reduce_package":
# to fix BC breaking change, objects on this load path

View File

@ -469,12 +469,12 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
storage = obj._untyped_storage
storage_dtype = obj.dtype
storage_type_str = obj.pickle_storage_type()
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
dtype = obj.dtype
storage_numel = obj.size()
storage_numel = obj._size()
elif isinstance(obj, torch.UntypedStorage):
storage = obj
@ -597,11 +597,11 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
storage = obj._untyped_storage
storage_dtype = obj.dtype
storage_type_str = obj.pickle_storage_type()
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj.size()
storage_numel = obj._size()
else:
storage = obj
@ -893,14 +893,15 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
for i in range(num_storages):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type.dtype
dtype = storage_type._dtype
obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
obj = restore_location(obj, location)
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[key] = torch.storage.TypedStorage(
wrap_storage=obj,
dtype=dtype)
dtype=dtype,
_internal=True)
storage_views = pickle_module.load(f, **pickle_load_args)
for target_cdata, root_cdata, offset, numel in storage_views:
@ -910,8 +911,9 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
dtype=root.dtype)
wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
dtype=root.dtype,
_internal=True)
tar.extract('tensors', path=tmpdir)
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
@ -927,7 +929,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset, = struct.unpack('<q', f.read(8))
tensor = torch.tensor([], dtype=storage.dtype).set_(
storage._storage, storage_offset, numel, stride)
storage._untyped_storage, storage_offset, numel, stride)
deserialized_objects[key] = tensor
pickle_file = tar.extractfile('pickle')
@ -962,7 +964,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
# stop wrapping with TypedStorage
deserialized_objects[root_key] = torch.storage.TypedStorage(
wrap_storage=restore_location(obj, location),
dtype=dtype)
dtype=dtype,
_internal=True)
typed_storage = deserialized_objects[root_key]
if view_metadata is not None:
@ -973,8 +976,9 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[view_key] = torch.storage.TypedStorage(
wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
dtype=dtype)
wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
dtype=dtype,
_internal=True)
res = deserialized_objects[view_key]
else:
@ -1023,7 +1027,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
for key in deserialized_storage_keys:
assert key in deserialized_objects
typed_storage = deserialized_objects[key]
typed_storage._storage._set_from_file(
typed_storage._untyped_storage._set_from_file(
f, offset, f_should_read_directly,
torch._utils._element_size(typed_storage.dtype))
if offset is not None:
@ -1082,12 +1086,13 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
def load_tensor(dtype, numel, key, location):
name = f'data/{key}'
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage().untyped()
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
loaded_storages[key] = torch.storage.TypedStorage(
wrap_storage=restore_location(storage, location),
dtype=dtype)
dtype=dtype,
_internal=True)
def persistent_load(saved_id):
assert isinstance(saved_id, tuple)

View File

@ -7,6 +7,7 @@ from typing import Any, TypeVar, Type, Union, cast
import copy
import collections
from functools import lru_cache
import warnings
try:
import numpy as np
HAS_NUMPY = True
@ -131,7 +132,7 @@ class _StorageBase(object):
def _to(self, dtype):
if not isinstance(dtype, torch.dtype):
raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage()
storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype)._typed_storage()
if storage.data_ptr() == self.data_ptr():
storage = storage.clone()
return storage
@ -297,7 +298,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
dtype=dtype,
device=device)
return tmp_tensor.storage().untyped()
return tmp_tensor._typed_storage()._untyped_storage
def _isint(x):
if HAS_NUMPY:
@ -305,16 +306,32 @@ def _isint(x):
else:
return isinstance(x, int)
def _warn_typed_storage_removal():
message = (
"TypedStorage is deprecated. It will be removed in the future and "
"UntypedStorage will be the only storage class. This should only matter "
"to you if you are using storages directly."
)
warnings.warn(message, UserWarning)
class TypedStorage:
is_sparse = False
dtype: torch.dtype
@property
def _dtype(self):
return self.dtype
def fill_(self, value):
self[0:len(self)] = value
_warn_typed_storage_removal()
self._setitem(slice(0, self._size()), value)
return self
def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
def __new__(cls, *args, wrap_storage=None, dtype=None, device=None, _internal=False):
if not _internal:
_warn_typed_storage_removal()
if cls == torch.storage._LegacyStorage:
raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
@ -353,8 +370,9 @@ class TypedStorage:
return TypedStorage(
*args,
dtype=cls.dtype,
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu')
dtype=cls._dtype,
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
_internal=True)
else:
if len(args) != 0:
@ -379,9 +397,12 @@ class TypedStorage:
return TypedStorage(
*args,
wrap_storage=wrap_storage,
dtype=cls.dtype)
dtype=cls.dtype,
_internal=True)
def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
def __init__(self, *args, device=None, dtype=None, wrap_storage=None, _internal=False):
if not _internal:
_warn_typed_storage_removal()
arg_error_msg = (
'TypedStorage.__init__ received an invalid combination '
'of arguments. Expected one of:\n'
@ -419,7 +440,7 @@ class TypedStorage:
arg_error_msg +
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
self._storage = wrap_storage
self._untyped_storage = wrap_storage
else:
self.dtype = torch.get_default_dtype() if dtype is None else dtype
@ -430,13 +451,13 @@ class TypedStorage:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
if len(args) == 0:
self._storage = torch.UntypedStorage(device=device)
self._untyped_storage = torch.UntypedStorage(device=device)
elif len(args) == 1:
if _isint(args[0]):
self._storage = torch.UntypedStorage(int(args[0]) * self.element_size(), device=device)
self._untyped_storage = torch.UntypedStorage(int(args[0]) * self._element_size(), device=device)
elif isinstance(args[0], collections.abc.Sequence):
self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
self._untyped_storage = _get_storage_from_sequence(args[0], self.dtype, device)
else:
raise TypeError(
arg_error_msg +
@ -447,30 +468,35 @@ class TypedStorage:
arg_error_msg +
"\nToo many positional arguments")
@property
def is_cuda(self):
_warn_typed_storage_removal()
return self.device.type == 'cuda'
def untyped(self):
"""Returns the internal :class:`torch.UntypedStorage`"""
return self._storage
_warn_typed_storage_removal()
return self._untyped_storage
def _new_wrapped_storage(self, untyped_storage):
assert type(untyped_storage) == torch.UntypedStorage
if type(self) == TypedStorage:
return TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
return TypedStorage(
wrap_storage=untyped_storage,
dtype=self.dtype,
_internal=True)
else:
return type(self)(wrap_storage=untyped_storage)
def __len__(self):
return self._storage.nbytes() // self.element_size()
_warn_typed_storage_removal()
return self._size()
def _maybe_wrap_index(self, idx, is_stop=False):
if idx is None:
if is_stop:
return self.size()
return self._size()
else:
return 0
@ -479,20 +505,24 @@ class TypedStorage:
raise TypeError(
f"can't index a {type(self)} with {type(idx)}")
if is_stop:
if (idx > self.size()) or (idx < -self.size()):
if (idx > self._size()) or (idx < -self._size()):
raise IndexError(
f'index {idx} out of range for storage of size {self.size()}')
if idx > 0:
return idx
else:
return idx % self.size()
return idx % self._size()
else:
if (idx >= self.size()) or (idx < -self.size()):
if (idx >= self._size()) or (idx < -self._size()):
raise IndexError(
f'index {idx} out of range for storage of size {self.size()}')
return idx % self.size()
return idx % self._size()
def __setitem__(self, idx, value):
_warn_typed_storage_removal()
return self._setitem(idx, value)
def _setitem(self, idx, value):
if not isinstance(idx, (int, slice)):
raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
if torch.is_storage(value):
@ -506,16 +536,22 @@ class TypedStorage:
torch.qint8: torch.int8
}
tmp_dtype = interpret_dtypes[self.dtype]
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage(
wrap_storage=self._storage,
dtype=tmp_dtype))
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self._untyped_storage.device)
tmp_tensor.set_(TypedStorage(
wrap_storage=self._untyped_storage,
dtype=tmp_dtype,
_internal=True))
else:
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
tmp_tensor[idx] = value
def __getitem__(self, idx):
if self.device.type == 'meta':
_warn_typed_storage_removal()
return self._getitem(idx)
def _getitem(self, idx):
if self._untyped_storage.device.type == 'meta':
raise NotImplementedError("Not available for 'meta' device type")
# NOTE: Before TypedStorage existed, indexing with a slice used to be
@ -536,21 +572,32 @@ class TypedStorage:
torch.qint8: torch.int8
}
return TypedStorage(
wrap_storage=self._storage,
dtype=interpret_dtypes[self.dtype])[idx]
wrap_storage=self._untyped_storage,
dtype=interpret_dtypes[self.dtype],
_internal=True)._getitem(idx)
idx_wrapped = self._maybe_wrap_index(idx)
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
return tmp_tensor[idx_wrapped].item()
def copy_(self, source: T, non_blocking: bool = None):
self._storage.copy_(source.untyped(), non_blocking)
_warn_typed_storage_removal()
if isinstance(source, TypedStorage):
self._untyped_storage.copy_(source._untyped_storage, non_blocking)
else:
self._untyped_storage.copy_(source, non_blocking)
return self
def nbytes(self):
return self._storage.nbytes()
_warn_typed_storage_removal()
return self._nbytes()
# For internal use only, to avoid deprecation warning
def _nbytes(self):
return self._untyped_storage.nbytes()
def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
_warn_typed_storage_removal()
if dtype is None:
legacy_class = self._get_legacy_storage_class()
@ -560,21 +607,29 @@ class TypedStorage:
return '.'.join([self.__module__, type(self).__name__])
else:
return self._storage.type(dtype, non_blocking)
return self._untyped_storage.type(dtype, non_blocking)
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage: torch.UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)
def element_size(self):
_warn_typed_storage_removal()
return self._element_size()
# For internal use only, to avoid deprecation warning
def _element_size(self):
return torch._utils._element_size(self.dtype)
def get_device(self) -> int:
return self._storage.get_device()
_warn_typed_storage_removal()
return self._untyped_storage.get_device()
def __str__(self):
_warn_typed_storage_removal()
info_str = (
f'[{torch.typename(self)}(dtype={self.dtype}, '
f'device={self.device}) of size {len(self)}]')
@ -585,35 +640,48 @@ class TypedStorage:
return data_str + '\n' + info_str
def __repr__(self):
_warn_typed_storage_removal()
return str(self)
def __iter__(self):
_warn_typed_storage_removal()
return iter(map(lambda i: self[i], range(self.size())))
def __copy__(self):
return self._new_wrapped_storage(copy.copy(self._storage))
_warn_typed_storage_removal()
return self._new_wrapped_storage(copy.copy(self._untyped_storage))
def __deepcopy__(self, memo):
return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))
_warn_typed_storage_removal()
return self._deepcopy(memo)
# For internal use only, to avoid deprecation warning
def _deepcopy(self, memo):
return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))
def __sizeof__(self):
_warn_typed_storage_removal()
return super(TypedStorage, self).__sizeof__() + self.nbytes()
def clone(self):
"""Returns a copy of this storage"""
return self._new_wrapped_storage(self._storage.clone())
_warn_typed_storage_removal()
return self._new_wrapped_storage(self._untyped_storage.clone())
def tolist(self):
"""Returns a list containing the elements of this storage"""
_warn_typed_storage_removal()
return list(self)
def cpu(self):
"""Returns a CPU copy of this storage if it's not already on the CPU"""
return self._new_wrapped_storage(self._storage.cpu())
_warn_typed_storage_removal()
return self._new_wrapped_storage(self._untyped_storage.cpu())
def pin_memory(self):
"""Coppies the storage to pinned memory, if it's not already pinned."""
return self._new_wrapped_storage(self._storage.pin_memory())
_warn_typed_storage_removal()
return self._new_wrapped_storage(self._untyped_storage.pin_memory())
def share_memory_(self):
"""Moves the storage to shared memory.
@ -624,7 +692,12 @@ class TypedStorage:
Returns: self
"""
self._storage.share_memory_()
_warn_typed_storage_removal()
return self._share_memory_()
# For internal use only, to avoid deprecation warning
def _share_memory_(self):
self._untyped_storage.share_memory_()
return self
def _new_shared(self, size, *, device=None):
@ -632,25 +705,37 @@ class TypedStorage:
if device is None:
device = 'cpu'
device = torch.device(device)
untyped_storage = torch.UntypedStorage._new_shared(size * self.element_size(), device=device)
untyped_storage = torch.UntypedStorage._new_shared(size * self._element_size(), device=device)
return TypedStorage(
wrap_storage=untyped_storage,
dtype=self.dtype)
dtype=self.dtype,
_internal=True)
@property
def _cdata(self):
return self._storage._cdata
return self._untyped_storage._cdata
@property
def device(self):
return self._storage.device
_warn_typed_storage_removal()
return self._untyped_storage.device
def size(self):
_warn_typed_storage_removal()
return self._size()
# For internal use only, to avoid deprecation warning
def _size(self):
# NB: don't indirect through __len__, as that requires
# an int to be returned
return self.nbytes() // self.element_size()
return self._untyped_storage.nbytes() // self._element_size()
def pickle_storage_type(self):
_warn_typed_storage_removal()
return self._pickle_storage_type()
# For internal use only, to avoid deprecation warning
def _pickle_storage_type(self):
try:
return _dtype_to_storage_type_map()[self.dtype]
except KeyError:
@ -662,20 +747,35 @@ class TypedStorage:
return (_load_from_bytes, (b.getvalue(),))
def data_ptr(self):
return self._storage.data_ptr()
_warn_typed_storage_removal()
return self._data_ptr()
# For internal use only, to avoid deprecation warning
def _data_ptr(self):
return self._untyped_storage.data_ptr()
def resize_(self, size):
self._storage.resize_(size * self.element_size())
_warn_typed_storage_removal()
self._resize_(size)
# For internal use only, to avoid deprecation warning
def _resize_(self, size):
self._untyped_storage.resize_(size * self._element_size())
@classmethod
def _free_weak_ref(cls, *args, **kwargs):
return UntypedStorage._free_weak_ref(*args, **kwargs)
def _weak_ref(self, *args, **kwargs):
return self._storage._weak_ref(*args, **kwargs)
return self._untyped_storage._weak_ref(*args, **kwargs)
@classmethod
def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
def from_buffer(cls, *args, **kwargs):
_warn_typed_storage_removal()
return cls._from_buffer(*args, **kwargs)
@classmethod
def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
if cls == TypedStorage:
dtype = torch.get_default_dtype() if dtype is None else dtype
device = torch.device('cpu' if device is None else device)
@ -693,65 +793,80 @@ class TypedStorage:
"from_buffer: 'device' can only be specified in "
"UntypedStorage.from_buffer and TypedStorage.from_buffer"))
dtype = cls.dtype
dtype = cls._dtype
untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
return TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
return TypedStorage(
wrap_storage=untyped_storage,
dtype=dtype,
_internal=True)
def _to(self, dtype):
if not isinstance(dtype, torch.dtype):
raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype)._typed_storage()
if storage.data_ptr() == self.data_ptr():
storage = storage.clone()
return storage
def double(self):
"""Casts this storage to double type"""
_warn_typed_storage_removal()
return self._to(torch.double)
def float(self):
"""Casts this storage to float type"""
_warn_typed_storage_removal()
return self._to(torch.float)
def half(self):
"""Casts this storage to half type"""
_warn_typed_storage_removal()
return self._to(torch.half)
def long(self):
"""Casts this storage to long type"""
_warn_typed_storage_removal()
return self._to(torch.long)
def int(self):
"""Casts this storage to int type"""
_warn_typed_storage_removal()
return self._to(torch.int)
def short(self):
"""Casts this storage to short type"""
_warn_typed_storage_removal()
return self._to(torch.short)
def char(self):
"""Casts this storage to char type"""
_warn_typed_storage_removal()
return self._to(torch.int8)
def byte(self):
"""Casts this storage to byte type"""
_warn_typed_storage_removal()
return self._to(torch.uint8)
def bool(self):
"""Casts this storage to bool type"""
_warn_typed_storage_removal()
return self._to(torch.bool)
def bfloat16(self):
"""Casts this storage to bfloat16 type"""
_warn_typed_storage_removal()
return self._to(torch.bfloat16)
def complex_double(self):
"""Casts this storage to complex double type"""
_warn_typed_storage_removal()
return self._to(torch.cdouble)
def complex_float(self):
"""Casts this storage to complex float type"""
_warn_typed_storage_removal()
return self._to(torch.cfloat)
@classmethod
@ -773,6 +888,7 @@ class TypedStorage:
shared (bool): whether to share memory
size (int): number of elements in the storage
"""
_warn_typed_storage_removal()
if cls == TypedStorage:
raise RuntimeError('from_file can only be called on derived classes')
untyped_storage: UntypedStorage = UntypedStorage.from_file(
@ -787,33 +903,39 @@ class TypedStorage:
return UntypedStorage._expired(*args, **kwargs)
def is_pinned(self):
return self._storage.is_pinned()
_warn_typed_storage_removal()
return self._untyped_storage.is_pinned()
def _write_file(self, *args, **kwargs):
return self._storage._write_file(*args, **kwargs)
return self._untyped_storage._write_file(*args, **kwargs)
def _set_from_file(self, *args, **kwargs):
return self._storage._set_from_file(*args, **kwargs)
return self._untyped_storage._set_from_file(*args, **kwargs)
def _set_cdata(self, *args, **kwargs):
return self._storage._set_cdata(*args, **kwargs)
return self._untyped_storage._set_cdata(*args, **kwargs)
def _share_cuda_(self, *args, **kwargs):
return self._storage._share_cuda_(*args, **kwargs)
return self._untyped_storage._share_cuda_(*args, **kwargs)
def is_shared(self):
return self._storage.is_shared()
_warn_typed_storage_removal()
return self._is_shared()
# For internal use only, to avoid deprecation warning
def _is_shared(self):
return self._untyped_storage.is_shared()
@classmethod
def _new_shared_cuda(cls, *args, **kwargs):
return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
def _share_filename_cpu_(self, *args, **kwargs):
manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
return manager_handle, storage_handle, size // self.element_size()
manager_handle, storage_handle, size = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
return manager_handle, storage_handle, size // self._element_size()
def _shared_decref(self):
self._storage._shared_decref()
self._untyped_storage._shared_decref()
return self
@classmethod
@ -821,11 +943,11 @@ class TypedStorage:
return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
def _shared_incref(self, *args, **kwargs):
return self._storage._shared_incref(*args, **kwargs)
return self._untyped_storage._shared_incref(*args, **kwargs)
def _share_fd_cpu_(self, *args, **kwargs):
fd, size = self._storage._share_fd_cpu_(*args, **kwargs)
return fd, size // self.element_size()
fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
return fd, size // self._element_size()
def _get_legacy_storage_class(self):
if self.dtype not in _dtype_to_storage_type_map():
@ -859,7 +981,7 @@ class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
@classmethod
def _new_shared(cls, size):
"""Creates a new storage in shared memory with the same data type"""
untyped_storage = torch.UntypedStorage._new_shared(size * cls().element_size())
untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size())
return cls(wrap_storage=untyped_storage)
@classmethod

View File

@ -927,9 +927,34 @@ def originate_pairs(
Returns:
(List[Pair]): Originated pairs.
"""
if (
isinstance(actual, torch.TypedStorage)
and isinstance(expected, torch.TypedStorage)
):
actual_len = actual._size()
expected_len = expected._size()
if actual_len != expected_len:
raise ErrorMeta(
AssertionError, f"The length of the sequences mismatch: {actual_len} != {expected_len}", id=id
)
pairs = []
for idx in range(actual_len):
pairs.extend(
originate_pairs(
actual._getitem(idx),
expected._getitem(idx),
pair_types=pair_types,
sequence_types=sequence_types,
mapping_types=mapping_types,
id=(*id, idx),
**options,
)
)
return pairs
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
if (
elif (
isinstance(actual, sequence_types)
and not isinstance(actual, str)
and isinstance(expected, sequence_types)

View File

@ -47,7 +47,7 @@ class SchemaCheckMode(TorchDispatchMode):
before.size() == after.size() and
torch.allclose(before, after, equal_nan=True) and
md[0] == after.stride() and
md[1] == after.storage()._cdata
md[1] == after._typed_storage()._cdata
)
return False
@ -76,12 +76,12 @@ class SchemaCheckMode(TorchDispatchMode):
if not type(e) == torch.Tensor:
try:
current = e.elem
return (deepcopy(current.stride()), current.storage()._cdata)
return (deepcopy(current.stride()), current._typed_storage()._cdata)
except AttributeError as t:
return None
# Sparse CSR tensors do not have strides or storage
elif (e.layout != torch.sparse_csr):
return (deepcopy(e.stride()), e.storage()._cdata)
return (deepcopy(e.stride()), e._typed_storage()._cdata)
return None
self.ops.append(func._schema.name)

View File

@ -391,7 +391,7 @@ def _inflate_expr(
if isinstance(arg, torch.Tensor):
# Small-storage tensors can just be saved directly.
if arg.storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
return arg, ref, None
# Small contiguous tensors can be cloned to have small storage.
# TODO: Should we do this even for non-contiguous tensors?
@ -407,7 +407,7 @@ def _inflate_expr(
# TODO: Provide more useful diagnostics.
raise Exception(
f"Bundled input argument at position '{ref}' is "
f"a tensor with storage size {arg.storage().size()}. "
f"a tensor with storage size {arg._typed_storage().size()}. "
f"You probably don't want to bundle this as an input. "
)
else:

View File

@ -158,7 +158,7 @@ def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel, device=elem.device)
storage = elem._typed_storage()._new_shared(numel, device=elem.device)
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
return torch.stack(batch, 0, out=out)