Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 53ca5ad347
Commit: ee28b865ee
@ -22,6 +22,10 @@ holds the data as an untyped array of bytes.

Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`,
which stores all of the data that the :class:`torch.Tensor` views.

.. warning::
    All storage classes except for :class:`torch.UntypedStorage` will be removed
    in the future, and :class:`torch.UntypedStorage` will be used in all cases.

.. autoclass:: torch.TypedStorage
    :members:
    :undoc-members:

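To make the documented deprecation concrete, here is a minimal sketch (not part of the diff) of what user code observes on a build that contains this commit: constructing a legacy typed storage such as torch.FloatStorage emits the new UserWarning, while torch.UntypedStorage stays silent. The warning text is the one added by _warn_typed_storage_removal further down in this commit.

import warnings

import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = torch.FloatStorage(4)       # deprecated typed storage -> warns
    untyped = torch.UntypedStorage(16)   # forward-compatible class -> no warning

deprecation_msgs = [
    str(w.message) for w in caught
    if "TypedStorage is deprecated" in str(w.message)
]
print(len(deprecation_msgs))  # >= 1: only the FloatStorage line contributed warnings
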
@ -6805,8 +6805,8 @@ for shape in [(1,), ()]:
with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
    a = torch.ones(5, requires_grad=True)

    warnings.simplefilter('always')
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        y = a * a
        # should raise two warnings from a being saved twice
        self.assertEqual(len(w), 2)

@ -595,7 +595,7 @@ class TestCuda(TestCase):
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage))
self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

@ -6470,6 +6470,127 @@ class TestTorch(TestCase):
|
||||
self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage')
|
||||
self.assertIs(complexdouble_storage.dtype, torch.complex128)
|
||||
|
||||
# Test that internal versions of functions related to TypedStorage do not
|
||||
# produce a deprecation warning
|
||||
def test_typed_storage_internal_no_warning(self):
|
||||
s0 = torch.FloatStorage(10)
|
||||
s0_untyped = s0.untyped()
|
||||
t0 = torch.randn(10)
|
||||
|
||||
funcs = [
|
||||
lambda: torch.FloatStorage(_internal=True),
|
||||
lambda: torch.TypedStorage(
|
||||
dtype=torch.float,
|
||||
device='cpu',
|
||||
_internal=True),
|
||||
lambda: torch.TypedStorage(
|
||||
wrap_storage=s0_untyped,
|
||||
dtype=s0.dtype,
|
||||
_internal=True),
|
||||
lambda: torch.FloatStorage._dtype,
|
||||
lambda: s0._resize_(20),
|
||||
lambda: s0._size(),
|
||||
lambda: s0._untyped_storage,
|
||||
lambda: s0._is_shared(),
|
||||
lambda: s0._share_memory_(),
|
||||
lambda: s0._pickle_storage_type(),
|
||||
lambda: s0._setitem(slice(0, s0._size()), 1),
|
||||
lambda: s0._element_size(),
|
||||
lambda: s0._deepcopy({}),
|
||||
lambda: s0._data_ptr(),
|
||||
lambda: s0._nbytes(),
|
||||
lambda: t0._typed_storage(),
|
||||
]
|
||||
|
||||
if torch.cuda.is_available():
|
||||
s1 = torch.cuda.FloatStorage(10)
|
||||
s1_untyped = s1.untyped()
|
||||
t1 = torch.randn(10, device='cuda')
|
||||
|
||||
funcs += [
|
||||
lambda: torch.cuda.FloatStorage(_internal=True),
|
||||
lambda: torch.TypedStorage(
|
||||
dtype=torch.float,
|
||||
device='cuda',
|
||||
_internal=True),
|
||||
lambda: torch.TypedStorage(
|
||||
wrap_storage=s1_untyped,
|
||||
dtype=s1.dtype,
|
||||
_internal=True),
|
||||
lambda: torch.cuda.FloatStorage._dtype,
|
||||
lambda: s1._resize_(20),
|
||||
lambda: s1._size(),
|
||||
lambda: s1._untyped_storage,
|
||||
lambda: s1._is_shared(),
|
||||
lambda: s1._share_memory_(),
|
||||
lambda: s1._pickle_storage_type(),
|
||||
lambda: s1._setitem(slice(0, s1._size()), 1),
|
||||
lambda: s1._element_size(),
|
||||
lambda: s1._deepcopy({}),
|
||||
lambda: s1._data_ptr(),
|
||||
lambda: s1._nbytes(),
|
||||
lambda: t1._typed_storage(),
|
||||
]
|
||||
|
||||
# Check that each of the TypedStorage internal function calls do not
|
||||
# produce a deprecation warning
|
||||
for f in funcs:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('error', "TypedStorage is deprecated")
|
||||
f()
|
||||
|
||||
# Test that public functions related to TypedStorage produce a deprecation
|
||||
# warning
|
||||
def test_typed_storage_deprecation_warning(self):
|
||||
s0 = torch.FloatStorage(10)
|
||||
funcs = [
|
||||
lambda: torch.FloatStorage(),
|
||||
lambda: torch.FloatStorage.dtype,
|
||||
lambda: s0.fill_(0),
|
||||
lambda: s0.is_cuda,
|
||||
lambda: s0.untyped(),
|
||||
lambda: len(s0),
|
||||
lambda: s0[0],
|
||||
]
|
||||
|
||||
if torch.cuda.is_available():
|
||||
s1 = torch.cuda.FloatStorage(10)
|
||||
funcs += [
|
||||
lambda: torch.cuda.FloatStorage(),
|
||||
lambda: torch.cuda.FloatStorage.dtype,
|
||||
lambda: s1.fill_(0),
|
||||
lambda: s1.is_cuda,
|
||||
lambda: s1.untyped(),
|
||||
lambda: len(s1),
|
||||
lambda: s1[0],
|
||||
]
|
||||
|
||||
# Check that each of the TypedStorage function calls produce a warning
|
||||
# if warnings are reset between each
|
||||
for f in funcs:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.resetwarnings()
|
||||
f()
|
||||
self.assertEqual(len(w), 1)
|
||||
warning = w[0].message
|
||||
self.assertTrue(warning, DeprecationWarning)
|
||||
self.assertTrue(re.search(
|
||||
'^TypedStorage is deprecated',
|
||||
str(warning)))
|
||||
|
||||
# Check that only one warning is raised from calling multiple
|
||||
# TypedStorage functions if warnings are not reset between each
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.resetwarnings()
|
||||
for f in funcs:
|
||||
f()
|
||||
self.assertEqual(len(w), 1)
|
||||
warning = w[0].message
|
||||
self.assertTrue(warning, DeprecationWarning)
|
||||
self.assertTrue(re.search(
|
||||
'^TypedStorage is deprecated',
|
||||
str(warning)))
|
||||
|
||||
def test_from_file(self):
|
||||
def assert_with_filename(filename):
|
||||
size = 10000
|
||||
|
||||
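The two tests above pin down the intended behaviour: underscore-prefixed internal helpers stay silent, while every public TypedStorage entry point warns. For downstream code that wants to manage the warning during migration, a sketch like the following (not part of the diff, plain warnings-module usage) works:

import warnings

import torch

# Escalate the deprecation warning to an error to locate remaining call sites...
with warnings.catch_warnings():
    warnings.filterwarnings("error", "TypedStorage is deprecated")
    try:
        torch.FloatStorage(10)
    except UserWarning as e:
        print("found a deprecated call site:", e)

# ...or silence it process-wide while migrating to torch.UntypedStorage.
warnings.filterwarnings("ignore", message="TypedStorage is deprecated")
torch.FloatStorage(10)  # no warning is shown after the filter above
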
@ -102,7 +102,7 @@ class TestViewOps(TestCase):
# Note: only validates storage on native device types
# because some accelerators, like XLA, do not expose storage
if base.device.type == 'cpu' or base.device.type == 'cuda':
    if base.storage().data_ptr() != other.storage().data_ptr():
    if base._storage().data_ptr() != other._storage().data_ptr():
        return False

return True

@ -979,7 +979,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
{
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function(self, "storage");
    return handle_torch_function(self, "_storage");
  }
  auto& self_ = THPVariable_Unpack(self);
  return createPyObject(self_.storage());

@ -709,7 +709,7 @@ __all__.extend(['e', 'pi', 'nan', 'inf'])
################################################################################

from ._tensor import Tensor
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal

# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.
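The NOTE above is the authors' migration guidance; the following is a small illustrative sketch (not from the diff) of building a storage for an arbitrary dtype through torch.storage.TypedStorage directly instead of adding a per-dtype legacy class. On builds containing this commit the public constructor itself emits the deprecation warning.

import torch

# Hypothetical illustration: a bfloat16 storage without a dedicated legacy class.
s = torch.storage.TypedStorage(4, dtype=torch.bfloat16, device="cpu")  # warns: TypedStorage is deprecated
print(s.dtype)   # torch.bfloat16
print(s.size())  # 4 (also warns, since size() is a public method)
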
@ -717,86 +717,171 @@ from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage
|
||||
class ByteStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.uint8
|
||||
|
||||
class DoubleStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.double
|
||||
|
||||
class FloatStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.float
|
||||
|
||||
class HalfStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.half
|
||||
|
||||
class LongStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.long
|
||||
|
||||
class IntStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.int
|
||||
|
||||
class ShortStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.short
|
||||
|
||||
class CharStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.int8
|
||||
|
||||
class BoolStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.bool
|
||||
|
||||
class BFloat16Storage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
class ComplexDoubleStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.cdouble
|
||||
|
||||
class ComplexFloatStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.cfloat
|
||||
|
||||
class QUInt8Storage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.quint8
|
||||
|
||||
class QInt8Storage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.qint8
|
||||
|
||||
class QInt32Storage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.qint32
|
||||
|
||||
class QUInt4x2Storage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.quint4x2
|
||||
|
||||
class QUInt2x4Storage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.quint2x4
|
||||
|
||||
_storage_classes = {
|
||||
|
||||
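Every legacy class above now exposes the warning-free classproperty _dtype alongside the deprecated public dtype. A short sketch (not from the diff; _dtype is private API, shown only to illustrate the split) of the difference as seen from user code:

import warnings

import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    public = torch.FloatStorage.dtype    # public classproperty: warns
    private = torch.FloatStorage._dtype  # internal counterpart: silent

print(public, private)  # torch.float32 torch.float32
print(sum("TypedStorage is deprecated" in str(w.message) for w in caught))  # expected 1: only .dtype warned
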
@ -23,7 +23,7 @@ def _save_storages(importer, obj):
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# remove this case
|
||||
storage = obj._storage
|
||||
storage = obj._untyped_storage
|
||||
dtype = obj.dtype
|
||||
else:
|
||||
storage = obj
|
||||
|
||||
@ -27,7 +27,7 @@ class ShapeAliasingAndMutationProp(ShapeProp):
|
||||
|
||||
def tensor_alias_group(self, value: torch.Tensor):
|
||||
"""Assign a unique identifier to the storage of a given tensor"""
|
||||
storage = StorageWeakRef(value.storage())
|
||||
storage = StorageWeakRef(value._typed_storage())
|
||||
alias_group = self.storage_to_alias_group.get(storage)
|
||||
if alias_group is None:
|
||||
alias_group = next(self.make_alias_group)
|
||||
|
||||
@ -157,7 +157,7 @@ class DDPOptimizer:
|
||||
for name, p in target.named_parameters():
|
||||
param = target.get_parameter(name)
|
||||
if p.requires_grad and not self._ignore_parameter(param):
|
||||
buckets[0].size += p.storage().nbytes()
|
||||
buckets[0].size += p._storage().nbytes()
|
||||
buckets[0].params.append(f"{node.target}_{name}")
|
||||
buckets[0].param_ids.append(id(param))
|
||||
elif node.op == "get_attr":
|
||||
@ -165,7 +165,7 @@ class DDPOptimizer:
|
||||
if maybe_param.requires_grad and not self._ignore_parameter(
|
||||
maybe_param
|
||||
):
|
||||
buckets[0].size += maybe_param.storage().nbytes()
|
||||
buckets[0].size += maybe_param._storage().nbytes()
|
||||
buckets[0].params.append(node.target)
|
||||
buckets[0].param_ids.append(id(maybe_param))
|
||||
|
||||
|
||||
@ -381,7 +381,7 @@ def find_input_mutations(g):
|
||||
mutated_inputs = set()
|
||||
for n in g.nodes:
|
||||
if n.op == "placeholder":
|
||||
inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx)
|
||||
inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
|
||||
input_idx += 1
|
||||
elif n.op == "call_function":
|
||||
if n.target is operator.getitem:
|
||||
@ -402,7 +402,7 @@ def find_input_mutations(g):
|
||||
# TODO: not correct for args that contain tensors in a struct
|
||||
# like list
|
||||
mutated_inputs |= inputs[
|
||||
StorageWeakRef(meta_fk(argument.meta).storage())
|
||||
StorageWeakRef(meta_fk(argument.meta)._typed_storage())
|
||||
]
|
||||
# TODO: error on unrecognized nodes
|
||||
return mutated_inputs
|
||||
|
||||
@ -1158,7 +1158,9 @@ def _as_strided_meta(
|
||||
# as_strided to shapes with no elements are trivially valid, so it's OK
|
||||
pass
|
||||
elif isinstance(a, torch.Tensor):
|
||||
utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset)
|
||||
utils.check_in_bounds_for_storage(
|
||||
a._typed_storage(), size, stride, storage_offset
|
||||
)
|
||||
|
||||
return TensorMeta(a, shape=size, strides=stride)
|
||||
|
||||
|
||||
@ -156,7 +156,7 @@ class FakeTensorConverter(object):
|
||||
# const_tensor.add_(torch.rand([1]))
|
||||
# all aliases of it must become no longer const
|
||||
assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
|
||||
weak_st = StorageWeakRef(fake_tensor.constant.storage())
|
||||
weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())
|
||||
|
||||
# we need a map from a weak storage to all of its corresponding
|
||||
# constant tensors. python doesn't have the weak value equivalent
|
||||
@ -168,7 +168,7 @@ class FakeTensorConverter(object):
|
||||
def invalidate_constant_aliases(self, tensor):
|
||||
assert not isinstance(tensor, FakeTensor)
|
||||
|
||||
weak_st = StorageWeakRef(tensor.storage())
|
||||
weak_st = StorageWeakRef(tensor._typed_storage())
|
||||
if weak_st not in self.constant_storage_mapping:
|
||||
return
|
||||
|
||||
@ -1043,7 +1043,7 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce
|
||||
for e in tree_flatten((args, kwargs))[0]:
|
||||
if isinstance(e, torch.Tensor):
|
||||
if not e.is_sparse:
|
||||
storages.add(e.storage()._cdata)
|
||||
storages.add(e._typed_storage()._cdata)
|
||||
|
||||
# TODO: also check metadata change on inputs
|
||||
# proper aliasing/metadata relationship between outputs and inputs will
|
||||
@ -1053,7 +1053,7 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce
|
||||
if id(e) not in inp_impls and (
|
||||
isinstance(e, torch.Tensor)
|
||||
and not e.is_sparse
|
||||
and e.storage()._cdata in storages
|
||||
and e._typed_storage()._cdata in storages
|
||||
):
|
||||
raise orig_not_implemented_exception
|
||||
|
||||
|
||||
@ -18,12 +18,12 @@ aten = torch.ops.aten
|
||||
|
||||
def outputs_alias_inputs(outputs, inputs):
|
||||
input_storages = {
|
||||
inp.storage()._cdata
|
||||
inp._typed_storage()._cdata
|
||||
for inp in tree_flatten_only(torch.Tensor, inputs)
|
||||
if torch._C._has_storage(inp)
|
||||
}
|
||||
return any(
|
||||
torch._C._has_storage(out) and out.storage()._cdata in input_storages
|
||||
torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
|
||||
for out in tree_flatten_only(torch.Tensor, outputs)
|
||||
)
|
||||
|
||||
@ -38,7 +38,7 @@ def output_alias_each_other(outputs):
|
||||
for out in tree_flatten_only(torch.Tensor, outputs):
|
||||
if not torch._C._has_storage(out):
|
||||
continue
|
||||
stor = out.storage()._cdata
|
||||
stor = out._typed_storage()._cdata
|
||||
if stor in storages:
|
||||
return True
|
||||
storages.add(stor)
|
||||
|
||||
@ -143,7 +143,7 @@ class MetaConverter:
|
||||
if t.is_sparse:
|
||||
weak_st = None
|
||||
else:
|
||||
weak_st = StorageWeakRef(t.storage())
|
||||
weak_st = StorageWeakRef(t._typed_storage())
|
||||
tensor_ref_key = WeakTensorRefKey(t)
|
||||
|
||||
def del_ten():
|
||||
@ -179,13 +179,9 @@ class MetaConverter:
|
||||
# Use a Weak Ref to s in order to not leak memory
|
||||
swr = StorageWeakRef(s)
|
||||
if swr not in self.storage_memo:
|
||||
self.storage_memo[swr] = (
|
||||
callback(
|
||||
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
|
||||
)
|
||||
.storage()
|
||||
.untyped()
|
||||
)
|
||||
self.storage_memo[swr] = callback(
|
||||
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
|
||||
)._storage()
|
||||
return self.storage_memo[swr]
|
||||
|
||||
# This function assumes that it's possible to do the conversion
|
||||
@ -362,7 +358,7 @@ class MetaConverter:
|
||||
# format here
|
||||
r = r.clone(memory_format=torch.preserve_format)
|
||||
|
||||
s = t.storage().untyped()
|
||||
s = t._storage()
|
||||
swr = StorageWeakRef(s)
|
||||
if (
|
||||
swr not in self.storage_memo
|
||||
@ -370,7 +366,7 @@ class MetaConverter:
|
||||
and r.storage_offset() == storage_offset
|
||||
):
|
||||
# You're normal and happy, install the fresh storage into the memo
|
||||
self.storage_memo[swr] = r.storage().untyped()
|
||||
self.storage_memo[swr] = r._storage()
|
||||
else:
|
||||
# You're in crazy town; somehow you gave us a tensor
|
||||
# that wasn't a view, but had nonzero storage offset,
|
||||
|
||||
@ -132,7 +132,7 @@ class Tensor(torch._C._TensorBase):
|
||||
"different type."
|
||||
)
|
||||
else:
|
||||
new_storage = self.storage().__deepcopy__(memo)
|
||||
new_storage = self._typed_storage()._deepcopy(memo)
|
||||
if self.is_quantized:
|
||||
# quantizer_params can be different type based on torch attribute
|
||||
quantizer_params: Union[
|
||||
@ -163,7 +163,9 @@ class Tensor(torch._C._TensorBase):
|
||||
# need to wrap with TypedStorage
|
||||
new_tensor = torch._utils._rebuild_qtensor(
|
||||
torch.storage.TypedStorage(
|
||||
wrap_storage=new_storage.untyped(), dtype=self.dtype
|
||||
wrap_storage=new_storage._untyped_storage,
|
||||
dtype=self.dtype,
|
||||
_internal=True,
|
||||
),
|
||||
self.storage_offset(),
|
||||
self.size(),
|
||||
@ -257,7 +259,17 @@ class Tensor(torch._C._TensorBase):
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.storage, (self,), self)

    return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
    torch.storage._warn_typed_storage_removal()
    return self._typed_storage()

# For internal use only, to avoid raising deprecation warning
def _typed_storage(self):
    _storage = self._storage()
    if isinstance(_storage, torch.TypedStorage):
        _storage = _storage._untyped_storage
    return torch.TypedStorage(
        wrap_storage=_storage, dtype=self.dtype, _internal=True
    )

def _reduce_ex_internal(self, proto):
    check_serializing_named_tensor(self)

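With this hunk, Tensor.storage() keeps its return type but now routes through the deprecation warning, while the warning-free path (_typed_storage) stays private. A small sketch (not from the diff) of the observable difference on a build containing this commit:

import warnings

import torch

t = torch.arange(4.0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    s = t.storage()  # public accessor: still returns a typed storage, but warns

print(any("TypedStorage is deprecated" in str(w.message) for w in caught))  # True
print(s.dtype)  # torch.float32 (plain attribute access, no extra warning)
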
@ -331,7 +343,9 @@ class Tensor(torch._C._TensorBase):
|
||||
# need to wrap with TypedStorage
|
||||
args_qtensor = (
|
||||
torch.storage.TypedStorage(
|
||||
wrap_storage=self.storage().untyped(), dtype=self.dtype
|
||||
wrap_storage=self._typed_storage()._untyped_storage,
|
||||
dtype=self.dtype,
|
||||
_internal=True,
|
||||
),
|
||||
self.storage_offset(),
|
||||
tuple(self.size()),
|
||||
@ -389,7 +403,9 @@ class Tensor(torch._C._TensorBase):
|
||||
# need to wrap with TypedStorage
|
||||
args = (
|
||||
torch.storage.TypedStorage(
|
||||
wrap_storage=self.storage().untyped(), dtype=self.dtype
|
||||
wrap_storage=self._typed_storage()._untyped_storage,
|
||||
dtype=self.dtype,
|
||||
_internal=True,
|
||||
),
|
||||
self.storage_offset(),
|
||||
tuple(self.size()),
|
||||
@ -607,7 +623,7 @@ class Tensor(torch._C._TensorBase):
|
||||
"""
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.is_shared, (self,), self)
|
||||
return self.storage().is_shared()
|
||||
return self._typed_storage()._is_shared()
|
||||
|
||||
def share_memory_(self):
|
||||
r"""Moves the underlying storage to shared memory.
|
||||
@ -617,7 +633,7 @@ class Tensor(torch._C._TensorBase):
|
||||
"""
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.share_memory_, (self,), self)
|
||||
self.storage().share_memory_()
|
||||
self._typed_storage()._share_memory_()
|
||||
return self
|
||||
|
||||
def __reversed__(self):
|
||||
@ -1059,7 +1075,9 @@ class Tensor(torch._C._TensorBase):
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.storage_type, (self,), self)
|
||||
|
||||
return self.storage()._get_legacy_storage_class()
|
||||
torch.storage._warn_typed_storage_removal()
|
||||
|
||||
return self._typed_storage()._get_legacy_storage_class()
|
||||
|
||||
def refine_names(self, *names):
|
||||
r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
|
||||
|
||||
@ -143,8 +143,8 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
|
||||
# be a TypedStorage
|
||||
def _rebuild_tensor(storage, storage_offset, size, stride):
|
||||
# first construct a tensor with the correct dtype/device
|
||||
t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
|
||||
return t.set_(storage.untyped(), storage_offset, size, stride)
|
||||
t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device)
|
||||
return t.set_(storage._untyped_storage, storage_offset, size, stride)
|
||||
|
||||
|
||||
def _rebuild_tensor_v2(
|
||||
|
||||
@ -135,7 +135,7 @@ at::Storage createStorageGetType(
|
||||
TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj));
|
||||
scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
|
||||
|
||||
untyped_storage_obj = PyObject_GetAttrString(obj, "_storage");
|
||||
untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage");
|
||||
TORCH_INTERNAL_ASSERT(untyped_storage_obj);
|
||||
Py_DECREF(untyped_storage_obj);
|
||||
|
||||
|
||||
@ -737,11 +737,12 @@ class _CudaBase(object):
|
||||
|
||||
__new__ = _lazy_new
|
||||
|
||||
from torch.storage import _LegacyStorage
|
||||
from torch.storage import _LegacyStorage, _warn_typed_storage_removal
|
||||
|
||||
class _CudaLegacyStorage(_LegacyStorage):
|
||||
@classmethod
|
||||
def from_buffer(cls, *args, **kwargs):
|
||||
_warn_typed_storage_removal()
|
||||
raise RuntimeError('from_buffer: Not available for CUDA storage')
|
||||
|
||||
@classmethod
|
||||
@ -755,61 +756,121 @@ class _CudaLegacyStorage(_LegacyStorage):
|
||||
class ByteStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.uint8
|
||||
|
||||
class DoubleStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.double
|
||||
|
||||
class FloatStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.float
|
||||
|
||||
class HalfStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.half
|
||||
|
||||
class LongStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.long
|
||||
|
||||
class IntStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.int
|
||||
|
||||
class ShortStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.short
|
||||
|
||||
class CharStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.int8
|
||||
|
||||
class BoolStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.bool
|
||||
|
||||
class BFloat16Storage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
class ComplexDoubleStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.cdouble
|
||||
|
||||
class ComplexFloatStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self._dtype
|
||||
|
||||
@classproperty
|
||||
def _dtype(self):
|
||||
return torch.cfloat
|
||||
|
||||
del _LegacyStorage
|
||||
|
||||
@ -89,7 +89,7 @@ def find_input_mutations(g):
|
||||
mutated_inputs = set()
|
||||
for n in g.nodes:
|
||||
if n.op == 'placeholder':
|
||||
inputs[StorageWeakRef(n.meta[FK].storage())].add(input_idx)
|
||||
inputs[StorageWeakRef(n.meta[FK]._typed_storage())].add(input_idx)
|
||||
input_idx += 1
|
||||
elif n.op == 'call_function':
|
||||
if n.target is operator.getitem:
|
||||
@ -109,7 +109,7 @@ def find_input_mutations(g):
|
||||
if mut_arg:
|
||||
# TODO: not correct for args that contain tensors in a struct
|
||||
# like list
|
||||
mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK].storage())]
|
||||
mutated_inputs |= inputs[StorageWeakRef(argument.meta[FK]._typed_storage())]
|
||||
# TODO: error on unrecognized nodes
|
||||
return mutated_inputs
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ DEFAULT_SUFIX = ".distcp"
|
||||
|
||||
def _trim(tensor: torch.Tensor) -> torch.Tensor:
|
||||
tensor = tensor.detach().cpu()
|
||||
if tensor.storage().size() != tensor.numel():
|
||||
if tensor._typed_storage()._size() != tensor.numel():
|
||||
tensor = tensor.clone()
|
||||
return tensor
|
||||
|
||||
|
||||
@ -1896,7 +1896,7 @@ def all_gather_multigpu(
|
||||
def _object_to_tensor(obj, device):
|
||||
f = io.BytesIO()
|
||||
_pickler(f).dump(obj)
|
||||
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
|
||||
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
|
||||
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
|
||||
# Otherwise, it will casue 100X slowdown.
|
||||
# See: https://github.com/pytorch/pytorch/issues/65696
|
||||
|
||||
@ -69,14 +69,14 @@ def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
|
||||
bool: ``True`` if this method allocated storage and ``False`` if the
|
||||
storage was already allocated.
|
||||
"""
|
||||
already_allocated = tensor.storage().size() == size.numel()
|
||||
already_allocated = tensor._typed_storage()._size() == size.numel()
|
||||
if not already_allocated:
|
||||
tensor_storage_size = tensor.storage().size()
|
||||
tensor_storage_size = tensor._typed_storage()._size()
|
||||
p_assert(
|
||||
tensor_storage_size == 0,
|
||||
f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
|
||||
)
|
||||
tensor.storage().resize_(size.numel())
|
||||
tensor._typed_storage()._resize_(size.numel())
|
||||
return not already_allocated
|
||||
|
||||
|
||||
@ -89,23 +89,23 @@ def _free_storage(tensor: torch.Tensor) -> bool:
|
||||
bool: ``True`` if the method freed the storage and ``False`` if the
|
||||
storage was already freed.
|
||||
"""
|
||||
already_freed = tensor.storage().size() == 0
|
||||
already_freed = tensor._typed_storage()._size() == 0
|
||||
if not already_freed:
|
||||
p_assert(
|
||||
tensor.storage_offset() == 0,
|
||||
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
|
||||
f"storage offset: {tensor.storage_offset()}\n"
|
||||
f"storage size: {tensor.storage().size()}\n"
|
||||
f"storage size: {tensor._typed_storage()._size()}\n"
|
||||
f"tensor shape: {tensor.shape}",
|
||||
)
|
||||
tensor.storage().resize_(0)
|
||||
tensor._typed_storage()._resize_(0)
|
||||
return not already_freed
|
||||
|
||||
|
||||
def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
|
||||
"""Returns if ``x`` and ``y`` share the same storage."""
|
||||
# NOTE: CPU and GPU tensors are ensured to have different data pointers.
|
||||
return x.storage().data_ptr() == y.storage().data_ptr()
|
||||
return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
|
||||
|
||||
|
||||
def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
|
||||
|
||||
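The FSDP helpers above reduce to two storage-level facts: a tensor's storage can be resized independently of the tensor's shape metadata, and two tensors alias exactly when their storages share a data pointer. A sketch of the aliasing check using only public tensor/storage calls (the diff itself uses the internal _typed_storage helpers to avoid the deprecation warning):

import torch

x = torch.arange(6.0)
y = x.view(2, 3)   # a view aliases x's storage
z = x.clone()      # a clone gets its own storage


def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Mirrors _same_storage() above, via the public (deprecated but functional) accessor.
    return a.storage().data_ptr() == b.storage().data_ptr()


print(same_storage(x, y))  # True
print(same_storage(x, z))  # False
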
@ -493,7 +493,7 @@ class FlatParamHandle:
|
||||
flat_param.storage_offset() == 0,
|
||||
"The `FlatParameter` is not the sole occupant of its storage",
|
||||
)
|
||||
orig_storage = flat_param.storage()
|
||||
orig_storage = flat_param._typed_storage()
|
||||
sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
|
||||
flat_param, self.rank, self.world_size
|
||||
)
|
||||
@ -501,8 +501,8 @@ class FlatParamHandle:
|
||||
start = sharded_flat_param.numel() * self.rank
|
||||
end = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive
|
||||
self._init_shard_metadata(numel_padded, start, end)
|
||||
if orig_storage.size() > 0:
|
||||
orig_storage.resize_(0)
|
||||
if orig_storage._size() > 0:
|
||||
orig_storage._resize_(0)
|
||||
if self._use_orig_params:
|
||||
self._use_sharded_views()
|
||||
|
||||
@ -838,7 +838,7 @@ class FlatParamHandle:
|
||||
return False
|
||||
unsharded_flat_param = self._get_padded_unsharded_flat_param()
|
||||
already_unsharded = (
|
||||
unsharded_flat_param.storage().size() == unsharded_flat_param.numel()
|
||||
unsharded_flat_param._typed_storage()._size() == unsharded_flat_param.numel()
|
||||
)
|
||||
return not already_unsharded
|
||||
|
||||
@ -1141,9 +1141,9 @@ class FlatParamHandle:
|
||||
# the padded unsharded flattened parameter as expected
|
||||
# NOTE: This check is not strictly needed for correctness but is a
|
||||
# useful sanity check since the tensor should only be used internally.
|
||||
unpadded_storage_ptr = self.flat_param.storage().data_ptr()
|
||||
unpadded_storage_ptr = self.flat_param._typed_storage()._data_ptr()
|
||||
padded_storage_ptr = (
|
||||
self._get_padded_unsharded_flat_param().storage().data_ptr()
|
||||
self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
|
||||
)
|
||||
p_assert(
|
||||
unpadded_storage_ptr == padded_storage_ptr,
|
||||
@ -1824,7 +1824,7 @@ class FlatParamHandle:
|
||||
|
||||
@staticmethod
|
||||
def _check_storage_freed(tensor: Tensor):
|
||||
storage_size: int = tensor.storage().size()
|
||||
storage_size: int = tensor._typed_storage()._size()
|
||||
p_assert(
|
||||
storage_size == 0,
|
||||
f"Expects storage to be freed but got storage with size {storage_size}",
|
||||
@ -1832,7 +1832,7 @@ class FlatParamHandle:
|
||||
|
||||
@staticmethod
|
||||
def _check_storage_allocated(tensor: Tensor):
|
||||
storage_size: int = tensor.storage().size()
|
||||
storage_size: int = tensor._typed_storage()._size()
|
||||
p_assert(storage_size > 0, "Expects storage to be allocated")
|
||||
|
||||
def _check_low_precision_shard(self):
|
||||
|
||||
@ -107,7 +107,7 @@ def profile_sizes(
|
||||
latent_size = memory_after - memory_before
|
||||
|
||||
# Analyze size of parameters.
|
||||
param_size = sum(p.storage().nbytes() for p in layer.parameters())
|
||||
param_size = sum(p._typed_storage()._nbytes() for p in layer.parameters())
|
||||
|
||||
# Combine size of parameters and activations with normalize scales.
|
||||
size = latent_size * latent_scale + param_size * param_scale
|
||||
|
||||
@ -104,7 +104,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
|
||||
#
|
||||
# Issue: https://github.com/pytorch/pytorch/issues/27366
|
||||
#
|
||||
tensor = tensor.new_empty([0]).set_(tensor.storage())
|
||||
tensor = tensor.new_empty([0]).set_(tensor._typed_storage())
|
||||
|
||||
# Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
|
||||
tensor.record_stream(as_cuda(stream)) # type: ignore[arg-type]
|
||||
|
||||
@ -100,8 +100,8 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
|
||||
# Assert here that this is actually the case, and their storages are the same.
|
||||
assert isinstance(node.meta['fake_result'], FakeTensor)
|
||||
assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
|
||||
view_storage = StorageWeakRef(node.meta['fake_result'].storage())
|
||||
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result'].storage())
|
||||
view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
|
||||
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
|
||||
assert view_storage == base_storage
|
||||
return result
|
||||
|
||||
@ -176,7 +176,7 @@ _VIEW_INVERSE_MAP = {
|
||||
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
|
||||
def _add_if_tensor(x, set_):
|
||||
if isinstance(x, FakeTensor):
|
||||
set_.add(StorageWeakRef(x.storage()))
|
||||
set_.add(StorageWeakRef(x._typed_storage()))
|
||||
|
||||
nodes_used_after = set()
|
||||
for t in tensor_aliases:
|
||||
@ -452,7 +452,7 @@ def reinplace(gm, *sample_args):
|
||||
# Useful debug printing
|
||||
# def _print(x):
|
||||
# if isinstance(x, FakeTensor):
|
||||
# print(f'fake_result: {StorageWeakRef(x.storage()).cdata}')
|
||||
# print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
|
||||
|
||||
# for n in gm.graph.nodes:
|
||||
# print(n.format_node())
|
||||
@ -468,7 +468,10 @@ def reinplace(gm, *sample_args):
|
||||
# so we know not to re-inplace them.
|
||||
# NOTE: later, we'll need to add an optimization for fully recovering performance
|
||||
# on programs that mutate inputs.
|
||||
input_storages = set(StorageWeakRef(node.meta['fake_result'].storage()) for node in gm.graph.nodes if node.op == 'placeholder')
|
||||
input_storages = set(
|
||||
StorageWeakRef(
|
||||
node.meta['fake_result']._typed_storage()
|
||||
) for node in gm.graph.nodes if node.op == 'placeholder')
|
||||
|
||||
|
||||
# We also need to know for a given node, what are all of its aliasing nodes.
|
||||
@ -478,7 +481,7 @@ def reinplace(gm, *sample_args):
|
||||
# Tree-mapping because some ops can return lists of tensors.
|
||||
def _add_to_map(x):
|
||||
if isinstance(x, FakeTensor):
|
||||
storage_to_nodes[StorageWeakRef(x.storage())].add(n)
|
||||
storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
|
||||
tree_map(_add_to_map, n.meta['fake_result'])
|
||||
|
||||
# inplace-ify functional ops, subject to the constraints written below.
|
||||
@ -529,7 +532,7 @@ def reinplace(gm, *sample_args):
|
||||
|
||||
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
|
||||
self_arg_name = self_arg.name
|
||||
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
|
||||
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
|
||||
if self_arg_storage in input_storages:
|
||||
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
|
||||
continue
|
||||
@ -539,7 +542,7 @@ def reinplace(gm, *sample_args):
|
||||
# so we prevent re-inplacing in this case.
|
||||
continue
|
||||
|
||||
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result'].storage())
|
||||
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
|
||||
self_aliases = storage_to_nodes[self_arg_storage]
|
||||
|
||||
# First, we find all later usages of any of the aliases of self_arg.
|
||||
@ -594,7 +597,7 @@ def reinplace(gm, *sample_args):
|
||||
# Hmm... morally I think we also want to keep the `fake_result` metadata
|
||||
# up to date here, but I'm not sure how easy it is to do.
|
||||
# Maybe it's fine to wait until the end of the pass to update it.
|
||||
curr_node_storage = StorageWeakRef(node.meta['fake_result'].storage())
|
||||
curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
|
||||
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
|
||||
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
|
||||
|
||||
@ -624,8 +627,14 @@ def reinplace(gm, *sample_args):
|
||||
old_flattened_res, _ = tree_flatten(old.meta['fake_result'])
|
||||
node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result'])
|
||||
|
||||
old_res_storage = set(StorageWeakRef(x.storage()) for x in old_flattened_res if isinstance(x, FakeTensor))
|
||||
node_res_storage = set(StorageWeakRef(x.storage()) for x in node_flattened_res if isinstance(x, FakeTensor))
|
||||
old_res_storage = set(
|
||||
StorageWeakRef(
|
||||
x._typed_storage()
|
||||
) for x in old_flattened_res if isinstance(x, FakeTensor))
|
||||
node_res_storage = set(
|
||||
StorageWeakRef(
|
||||
x._typed_storage()
|
||||
) for x in node_flattened_res if isinstance(x, FakeTensor))
|
||||
|
||||
# This will happen if we're updating a view op, e.g.
|
||||
# e.g. replacing
|
||||
@ -639,7 +648,10 @@ def reinplace(gm, *sample_args):
|
||||
# We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
|
||||
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
|
||||
new_flattened_res, _ = tree_flatten(new.meta['fake_result'])
|
||||
new_res_storage = set(StorageWeakRef(x.storage()) for x in new_flattened_res if isinstance(x, FakeTensor))
|
||||
new_res_storage = set(
|
||||
StorageWeakRef(
|
||||
x._typed_storage()
|
||||
) for x in new_flattened_res if isinstance(x, FakeTensor))
|
||||
assert len(new_res_storage) == 1
|
||||
(old_ref,) = old_res_storage
|
||||
(new_ref,) = new_res_storage
|
||||
|
||||
@ -113,7 +113,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
|
||||
requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required):
|
||||
# If storage_handle is None, storage points to nullptr.
|
||||
if storage_handle is None or storage_size_bytes == 0:
|
||||
storage = storage_cls(0, dtype=dtype, device=storage_device)
|
||||
storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
|
||||
else:
|
||||
storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes))
|
||||
if storage is None:
|
||||
@ -132,8 +132,10 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
|
||||
# We already ref counting this Storage, but producer needs new ref-counters to be released.
|
||||
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
|
||||
|
||||
_storage = storage if isinstance(storage, torch.UntypedStorage) else storage._untyped_storage
|
||||
|
||||
t = torch._utils._rebuild_tensor(
|
||||
torch.storage.TypedStorage(wrap_storage=storage.untyped(), dtype=dtype),
|
||||
torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
|
||||
tensor_offset, tensor_size, tensor_stride)
|
||||
|
||||
if tensor_cls == torch.nn.parameter.Parameter:
|
||||
@ -147,7 +149,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
|
||||
|
||||
|
||||
def reduce_tensor(tensor):
|
||||
storage = tensor.storage()
|
||||
storage = tensor._typed_storage()
|
||||
|
||||
if tensor.requires_grad and not tensor.is_leaf:
|
||||
raise RuntimeError("Cowardly refusing to serialize non-leaf tensor which requires_grad, "
|
||||
@ -248,7 +250,7 @@ def reduce_tensor(tensor):
|
||||
# eliminated it so that we could just use tensor views to implement the same
|
||||
# thing.
|
||||
#
|
||||
if storage.is_cuda:
|
||||
if storage._untyped_storage.device.type == 'cuda':
|
||||
(device,
|
||||
handle,
|
||||
storage_size_bytes,
|
||||
@ -325,7 +327,8 @@ def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
|
||||
untyped_storage: torch.UntypedStorage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
|
||||
storage = torch.TypedStorage(
|
||||
wrap_storage=untyped_storage,
|
||||
dtype=dtype)
|
||||
dtype=dtype,
|
||||
_internal=True)
|
||||
shared_cache[handle] = StorageWeakRef(storage)
|
||||
return storage._shared_decref()
|
||||
|
||||
@ -334,18 +337,18 @@ def rebuild_storage_empty(cls):
|
||||
return cls()
|
||||
|
||||
def rebuild_typed_storage(storage, dtype):
|
||||
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype)
|
||||
return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)
|
||||
|
||||
# Use for torch.storage.TypedStorage
|
||||
def reduce_typed_storage(storage):
|
||||
return (rebuild_typed_storage, (storage._storage, storage.dtype))
|
||||
return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))
|
||||
|
||||
def rebuild_typed_storage_child(storage, storage_type):
|
||||
return storage_type(wrap_storage=storage)
|
||||
return storage_type(wrap_storage=storage, _internal=True)
|
||||
|
||||
# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
|
||||
def reduce_typed_storage_child(storage):
|
||||
return (rebuild_typed_storage_child, (storage._storage, type(storage)))
|
||||
return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))
|
||||
|
||||
def reduce_storage(storage):
|
||||
from . import get_sharing_strategy
|
||||
|
||||
@ -273,6 +273,7 @@ def get_ignored_functions() -> Set[Callable]:
|
||||
Tensor.to_sparse_csc,
|
||||
Tensor.to_sparse_bsr,
|
||||
Tensor.to_sparse_bsc,
|
||||
Tensor._typed_storage,
|
||||
Tensor._reduce_ex_internal,
|
||||
Tensor._fix_weakref,
|
||||
Tensor._make_wrapper_subclass,
|
||||
|
||||
@ -887,7 +887,7 @@ class PackageExporter:
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# remove this case
|
||||
untyped_storage = obj._storage
|
||||
untyped_storage = obj._untyped_storage
|
||||
storage_type_str = obj.pickle_storage_type()
|
||||
storage_type = getattr(torch, storage_type_str)
|
||||
storage_numel = obj.size()
|
||||
|
||||
@ -208,14 +208,14 @@ class PackageImporter(Importer):
|
||||
name = f"{key}.storage"
|
||||
|
||||
if storage_context.has_storage(name):
|
||||
storage = storage_context.get_storage(name, dtype).storage()
|
||||
storage = storage_context.get_storage(name, dtype)._typed_storage()
|
||||
else:
|
||||
tensor = self.zip_reader.get_storage_from_record(
|
||||
".data/" + name, size, dtype
|
||||
)
|
||||
if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
|
||||
storage_context.add_storage(name, tensor)
|
||||
storage = tensor.storage()
|
||||
storage = tensor._typed_storage()
|
||||
loaded_storages[key] = restore_location(storage, location)
|
||||
|
||||
def persistent_load(saved_id):
|
||||
@ -239,7 +239,7 @@ class PackageImporter(Importer):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with TypedStorage
|
||||
return torch.storage.TypedStorage(
|
||||
wrap_storage=storage.untyped(), dtype=dtype
|
||||
wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
|
||||
)
|
||||
elif typename == "reduce_package":
|
||||
# to fix BC breaking change, objects on this load path
|
||||
|
||||
@ -469,12 +469,12 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, this case
|
||||
# can be deleted
|
||||
storage = obj._storage
|
||||
storage = obj._untyped_storage
|
||||
storage_dtype = obj.dtype
|
||||
storage_type_str = obj.pickle_storage_type()
|
||||
storage_type_str = obj._pickle_storage_type()
|
||||
storage_type = getattr(torch, storage_type_str)
|
||||
dtype = obj.dtype
|
||||
storage_numel = obj.size()
|
||||
storage_numel = obj._size()
|
||||
|
||||
elif isinstance(obj, torch.UntypedStorage):
|
||||
storage = obj
|
||||
@ -597,11 +597,11 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
|
||||
if isinstance(obj, torch.storage.TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, this case
|
||||
# can be deleted
|
||||
storage = obj._storage
|
||||
storage = obj._untyped_storage
|
||||
storage_dtype = obj.dtype
|
||||
storage_type_str = obj.pickle_storage_type()
|
||||
storage_type_str = obj._pickle_storage_type()
|
||||
storage_type = getattr(torch, storage_type_str)
|
||||
storage_numel = obj.size()
|
||||
storage_numel = obj._size()
|
||||
|
||||
else:
|
||||
storage = obj
|
||||
@ -893,14 +893,15 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
for i in range(num_storages):
|
||||
args = pickle_module.load(f, **pickle_load_args)
|
||||
key, location, storage_type = args
|
||||
dtype = storage_type.dtype
|
||||
dtype = storage_type._dtype
|
||||
obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
|
||||
obj = restore_location(obj, location)
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[key] = torch.storage.TypedStorage(
|
||||
wrap_storage=obj,
|
||||
dtype=dtype)
|
||||
dtype=dtype,
|
||||
_internal=True)
|
||||
|
||||
storage_views = pickle_module.load(f, **pickle_load_args)
|
||||
for target_cdata, root_cdata, offset, numel in storage_views:
|
||||
@ -910,8 +911,9 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
|
||||
wrap_storage=root._storage[offset_bytes:offset_bytes + numel * element_size],
|
||||
dtype=root.dtype)
|
||||
wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
|
||||
dtype=root.dtype,
|
||||
_internal=True)
|
||||
|
||||
tar.extract('tensors', path=tmpdir)
|
||||
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
|
||||
@ -927,7 +929,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
|
||||
storage_offset, = struct.unpack('<q', f.read(8))
|
||||
tensor = torch.tensor([], dtype=storage.dtype).set_(
|
||||
storage._storage, storage_offset, numel, stride)
|
||||
storage._untyped_storage, storage_offset, numel, stride)
|
||||
deserialized_objects[key] = tensor
|
||||
|
||||
pickle_file = tar.extractfile('pickle')
|
||||
@ -962,7 +964,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[root_key] = torch.storage.TypedStorage(
|
||||
wrap_storage=restore_location(obj, location),
|
||||
dtype=dtype)
|
||||
dtype=dtype,
|
||||
_internal=True)
|
||||
|
||||
typed_storage = deserialized_objects[root_key]
|
||||
if view_metadata is not None:
|
||||
@ -973,8 +976,9 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with TypedStorage
|
||||
deserialized_objects[view_key] = torch.storage.TypedStorage(
|
||||
wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
|
||||
dtype=dtype)
|
||||
wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
|
||||
dtype=dtype,
|
||||
_internal=True)
|
||||
res = deserialized_objects[view_key]
|
||||
|
||||
else:
|
||||
@ -1023,7 +1027,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||
for key in deserialized_storage_keys:
|
||||
assert key in deserialized_objects
|
||||
typed_storage = deserialized_objects[key]
|
||||
typed_storage._storage._set_from_file(
|
||||
typed_storage._untyped_storage._set_from_file(
|
||||
f, offset, f_should_read_directly,
|
||||
torch._utils._element_size(typed_storage.dtype))
|
||||
if offset is not None:
|
||||
@ -1082,12 +1086,13 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
|
||||
def load_tensor(dtype, numel, key, location):
|
||||
name = f'data/{key}'
|
||||
|
||||
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage).storage().untyped()
|
||||
storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# stop wrapping with TypedStorage
|
||||
loaded_storages[key] = torch.storage.TypedStorage(
|
||||
wrap_storage=restore_location(storage, location),
|
||||
dtype=dtype)
|
||||
dtype=dtype,
|
||||
_internal=True)
|
||||
|
||||
def persistent_load(saved_id):
|
||||
assert isinstance(saved_id, tuple)
|
||||
|
||||
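The serialization changes above (torch/package and torch/serialization.py) swap the internal wrapping over to _untyped_storage and the underscore helpers, but the on-disk format and the public save/load API are unchanged. A quick sketch (not from the diff) of the unaffected user-level round trip:

import io

import torch

buf = io.BytesIO()
torch.save({"w": torch.randn(3)}, buf)

buf.seek(0)
state = torch.load(buf)
print(state["w"].shape)  # torch.Size([3])
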
torch/storage.py
@ -7,6 +7,7 @@ from typing import Any, TypeVar, Type, Union, cast
|
||||
import copy
|
||||
import collections
|
||||
from functools import lru_cache
|
||||
import warnings
|
||||
try:
|
||||
import numpy as np
|
||||
HAS_NUMPY = True
|
||||
@ -131,7 +132,7 @@ class _StorageBase(object):
|
||||
def _to(self, dtype):
|
||||
if not isinstance(dtype, torch.dtype):
|
||||
raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
|
||||
storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage()
|
||||
storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype)._typed_storage()
|
||||
if storage.data_ptr() == self.data_ptr():
|
||||
storage = storage.clone()
|
||||
return storage
|
||||
@ -297,7 +298,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
return tmp_tensor.storage().untyped()
|
||||
return tmp_tensor._typed_storage()._untyped_storage
|
||||
|
||||
def _isint(x):
|
||||
if HAS_NUMPY:
|
||||
@ -305,16 +306,32 @@ def _isint(x):
    else:
        return isinstance(x, int)

def _warn_typed_storage_removal():
    message = (
        "TypedStorage is deprecated. It will be removed in the future and "
        "UntypedStorage will be the only storage class. This should only matter "
        "to you if you are using storages directly."
    )
    warnings.warn(message, UserWarning)

class TypedStorage:
    is_sparse = False

    dtype: torch.dtype

    @property
    def _dtype(self):
        return self.dtype

    def fill_(self, value):
        self[0:len(self)] = value
        _warn_typed_storage_removal()
        self._setitem(slice(0, self._size()), value)
        return self

    def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
    def __new__(cls, *args, wrap_storage=None, dtype=None, device=None, _internal=False):
        if not _internal:
            _warn_typed_storage_removal()

        if cls == torch.storage._LegacyStorage:
            raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")

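_warn_typed_storage_removal and the _internal keyword are the core mechanism of the whole patch: every public entry point warns unless the flag was passed by PyTorch's own code. A sketch of that mechanism (not from the diff; _internal is private API and shown only for illustration):

import warnings

import torch


def warns_deprecation(make_storage) -> bool:
    # Returns True if constructing the storage emits the new deprecation warning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        make_storage()
    return any("TypedStorage is deprecated" in str(w.message) for w in caught)


print(warns_deprecation(lambda: torch.TypedStorage(dtype=torch.float, device="cpu")))  # True
print(warns_deprecation(  # False: the private flag suppresses the warning
    lambda: torch.TypedStorage(dtype=torch.float, device="cpu", _internal=True)))
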
@ -353,8 +370,9 @@ class TypedStorage:
|
||||
|
||||
return TypedStorage(
|
||||
*args,
|
||||
dtype=cls.dtype,
|
||||
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu')
|
||||
dtype=cls._dtype,
|
||||
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
|
||||
_internal=True)
|
||||
|
||||
else:
|
||||
if len(args) != 0:
|
||||
@ -379,9 +397,12 @@ class TypedStorage:
|
||||
return TypedStorage(
|
||||
*args,
|
||||
wrap_storage=wrap_storage,
|
||||
dtype=cls.dtype)
|
||||
dtype=cls.dtype,
|
||||
_internal=True)
|
||||
|
||||
def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
|
||||
def __init__(self, *args, device=None, dtype=None, wrap_storage=None, _internal=False):
|
||||
if not _internal:
|
||||
_warn_typed_storage_removal()
|
||||
arg_error_msg = (
|
||||
'TypedStorage.__init__ received an invalid combination '
|
||||
'of arguments. Expected one of:\n'
|
||||
@ -419,7 +440,7 @@ class TypedStorage:
|
||||
arg_error_msg +
|
||||
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
|
||||
|
||||
self._storage = wrap_storage
|
||||
self._untyped_storage = wrap_storage
|
||||
|
||||
else:
|
||||
self.dtype = torch.get_default_dtype() if dtype is None else dtype
|
||||
@ -430,13 +451,13 @@ class TypedStorage:
|
||||
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
|
||||
|
||||
if len(args) == 0:
|
||||
self._storage = torch.UntypedStorage(device=device)
|
||||
self._untyped_storage = torch.UntypedStorage(device=device)
|
||||
|
||||
elif len(args) == 1:
|
||||
if _isint(args[0]):
|
||||
self._storage = torch.UntypedStorage(int(args[0]) * self.element_size(), device=device)
|
||||
self._untyped_storage = torch.UntypedStorage(int(args[0]) * self._element_size(), device=device)
|
||||
elif isinstance(args[0], collections.abc.Sequence):
|
||||
self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
|
||||
self._untyped_storage = _get_storage_from_sequence(args[0], self.dtype, device)
|
||||
else:
|
||||
raise TypeError(
|
||||
arg_error_msg +
|
||||
@ -447,30 +468,35 @@ class TypedStorage:
|
||||
arg_error_msg +
|
||||
"\nToo many positional arguments")
|
||||
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
_warn_typed_storage_removal()
|
||||
return self.device.type == 'cuda'
|
||||
|
||||
def untyped(self):
|
||||
"""Returns the internal :class:`torch.UntypedStorage`"""
|
||||
return self._storage
|
||||
_warn_typed_storage_removal()
|
||||
return self._untyped_storage
|
||||
|
||||
def _new_wrapped_storage(self, untyped_storage):
|
||||
assert type(untyped_storage) == torch.UntypedStorage
|
||||
|
||||
if type(self) == TypedStorage:
|
||||
return TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
|
||||
return TypedStorage(
|
||||
wrap_storage=untyped_storage,
|
||||
dtype=self.dtype,
|
||||
_internal=True)
|
||||
else:
|
||||
return type(self)(wrap_storage=untyped_storage)
|
||||
|
||||
def __len__(self):
|
||||
return self._storage.nbytes() // self.element_size()
|
||||
_warn_typed_storage_removal()
|
||||
return self._size()
|
||||
|
||||
def _maybe_wrap_index(self, idx, is_stop=False):
|
||||
if idx is None:
|
||||
if is_stop:
|
||||
return self.size()
|
||||
return self._size()
|
||||
else:
|
||||
return 0
|
||||
|
||||
@ -479,20 +505,24 @@ class TypedStorage:
            raise TypeError(
                f"can't index a {type(self)} with {type(idx)}")
        if is_stop:
            if (idx > self.size()) or (idx < -self.size()):
            if (idx > self._size()) or (idx < -self._size()):
                raise IndexError(
                    f'index {idx} out of range for storage of size {self.size()}')
            if idx > 0:
                return idx
            else:
                return idx % self.size()
                return idx % self._size()
        else:
            if (idx >= self.size()) or (idx < -self.size()):
            if (idx >= self._size()) or (idx < -self._size()):
                raise IndexError(
                    f'index {idx} out of range for storage of size {self.size()}')
            return idx % self.size()
            return idx % self._size()

    def __setitem__(self, idx, value):
        _warn_typed_storage_removal()
        return self._setitem(idx, value)

    def _setitem(self, idx, value):
        if not isinstance(idx, (int, slice)):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
        if torch.is_storage(value):
@ -506,16 +536,22 @@ class TypedStorage:
                torch.qint8: torch.int8
            }
            tmp_dtype = interpret_dtypes[self.dtype]
            tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(TypedStorage(
                wrap_storage=self._storage,
                dtype=tmp_dtype))
            tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self._untyped_storage.device)
            tmp_tensor.set_(TypedStorage(
                wrap_storage=self._untyped_storage,
                dtype=tmp_dtype,
                _internal=True))
        else:
            tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
            tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)

        tmp_tensor[idx] = value

    def __getitem__(self, idx):
        if self.device.type == 'meta':
        _warn_typed_storage_removal()
        return self._getitem(idx)

    def _getitem(self, idx):
        if self._untyped_storage.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")

        # NOTE: Before TypedStorage existed, indexing with a slice used to be

@ -536,21 +572,32 @@ class TypedStorage:
                torch.qint8: torch.int8
            }
            return TypedStorage(
                wrap_storage=self._storage,
                dtype=interpret_dtypes[self.dtype])[idx]
                wrap_storage=self._untyped_storage,
                dtype=interpret_dtypes[self.dtype],
                _internal=True)._getitem(idx)

        idx_wrapped = self._maybe_wrap_index(idx)
        tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
        tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
        return tmp_tensor[idx_wrapped].item()

    def copy_(self, source: T, non_blocking: bool = None):
        self._storage.copy_(source.untyped(), non_blocking)
        _warn_typed_storage_removal()
        if isinstance(source, TypedStorage):
            self._untyped_storage.copy_(source._untyped_storage, non_blocking)
        else:
            self._untyped_storage.copy_(source, non_blocking)
        return self

    def nbytes(self):
        return self._storage.nbytes()
        _warn_typed_storage_removal()
        return self._nbytes()

    # For internal use only, to avoid deprecation warning
    def _nbytes(self):
        return self._untyped_storage.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
        _warn_typed_storage_removal()
        if dtype is None:
            legacy_class = self._get_legacy_storage_class()
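The `nbytes()`/`_nbytes()` split above is the pattern this PR applies to most of `TypedStorage`: the public method emits the deprecation warning, then delegates to an underscore-prefixed twin that library code can call silently. A minimal, self-contained sketch of that pattern; the names `Wrapper` and `_warn_removal` are illustrative only, not part of the PyTorch API:

    import warnings

    def _warn_removal():
        # Illustrative stand-in for _warn_typed_storage_removal()
        warnings.warn("This public method is deprecated", UserWarning, stacklevel=2)

    class Wrapper:
        def __init__(self, payload):
            self._payload = payload

        def nbytes(self):      # public: warns, then delegates
            _warn_removal()
            return self._nbytes()

        def _nbytes(self):     # internal: no warning, safe for library code
            return len(self._payload)

    w = Wrapper(b"\x00" * 16)
    print(w._nbytes())   # library path: silent
    print(w.nbytes())    # user path: emits the deprecation warning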
@ -560,21 +607,29 @@ class TypedStorage:
            return '.'.join([self.__module__, type(self).__name__])

        else:
            return self._storage.type(dtype, non_blocking)
            return self._untyped_storage.type(dtype, non_blocking)

    def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
        _warn_typed_storage_removal()
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            raise RuntimeError("Cannot create CUDA storage with quantized dtype")
        cuda_storage: torch.UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
        cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(cuda_storage)

    def element_size(self):
        _warn_typed_storage_removal()
        return self._element_size()

    # For internal use only, to avoid deprecation warning
    def _element_size(self):
        return torch._utils._element_size(self.dtype)

    def get_device(self) -> int:
        return self._storage.get_device()
        _warn_typed_storage_removal()
        return self._untyped_storage.get_device()

    def __str__(self):
        _warn_typed_storage_removal()
        info_str = (
            f'[{torch.typename(self)}(dtype={self.dtype}, '
            f'device={self.device}) of size {len(self)}]')
@ -585,35 +640,48 @@ class TypedStorage:
        return data_str + '\n' + info_str

    def __repr__(self):
        _warn_typed_storage_removal()
        return str(self)

    def __iter__(self):
        _warn_typed_storage_removal()
        return iter(map(lambda i: self[i], range(self.size())))

    def __copy__(self):
        return self._new_wrapped_storage(copy.copy(self._storage))
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(copy.copy(self._untyped_storage))

    def __deepcopy__(self, memo):
        return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))
        _warn_typed_storage_removal()
        return self._deepcopy(memo)

    # For internal use only, to avoid deprecation warning
    def _deepcopy(self, memo):
        return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))

    def __sizeof__(self):
        _warn_typed_storage_removal()
        return super(TypedStorage, self).__sizeof__() + self.nbytes()

    def clone(self):
        """Returns a copy of this storage"""
        return self._new_wrapped_storage(self._storage.clone())
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.clone())

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        _warn_typed_storage_removal()
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        return self._new_wrapped_storage(self._storage.cpu())
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.cpu())

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned."""
        return self._new_wrapped_storage(self._storage.pin_memory())
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.pin_memory())

    def share_memory_(self):
        """Moves the storage to shared memory.
@ -624,7 +692,12 @@ class TypedStorage:

        Returns: self
        """
        self._storage.share_memory_()
        _warn_typed_storage_removal()
        return self._share_memory_()

    # For internal use only, to avoid deprecation warning
    def _share_memory_(self):
        self._untyped_storage.share_memory_()
        return self

    def _new_shared(self, size, *, device=None):

@ -632,25 +705,37 @@ class TypedStorage:
        if device is None:
            device = 'cpu'
        device = torch.device(device)
        untyped_storage = torch.UntypedStorage._new_shared(size * self.element_size(), device=device)
        untyped_storage = torch.UntypedStorage._new_shared(size * self._element_size(), device=device)
        return TypedStorage(
            wrap_storage=untyped_storage,
            dtype=self.dtype)
            dtype=self.dtype,
            _internal=True)

    @property
    def _cdata(self):
        return self._storage._cdata
        return self._untyped_storage._cdata

    @property
    def device(self):
        return self._storage.device
        _warn_typed_storage_removal()
        return self._untyped_storage.device

    def size(self):
        _warn_typed_storage_removal()
        return self._size()

    # For internal use only, to avoid deprecation warning
    def _size(self):
        # NB: don't indirect through __len__, as that requires
        # an int to be returned
        return self.nbytes() // self.element_size()
        return self._untyped_storage.nbytes() // self._element_size()

    def pickle_storage_type(self):
        _warn_typed_storage_removal()
        return self._pickle_storage_type()

    # For internal use only, to avoid deprecation warning
    def _pickle_storage_type(self):
        try:
            return _dtype_to_storage_type_map()[self.dtype]
        except KeyError:
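As `_new_shared` and `_size` above show, a typed storage of N elements is always backed by an untyped storage of N * element_size bytes. A quick check of that relationship; the 4-byte size of float32 is stated as an assumption rather than queried through the private helpers, and both calls below may emit the new deprecation warning:

    import torch

    s = torch.FloatStorage(10)       # deprecated, still functional; may warn
    u = s.untyped()                  # underlying byte storage; may also warn

    assert u.nbytes() == 10 * 4      # float32 assumed to be 4 bytes per element
    assert len(s) == u.nbytes() // 4 # typed size == byte count // element size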
@ -662,20 +747,35 @@ class TypedStorage:
        return (_load_from_bytes, (b.getvalue(),))

    def data_ptr(self):
        return self._storage.data_ptr()
        _warn_typed_storage_removal()
        return self._data_ptr()

    # For internal use only, to avoid deprecation warning
    def _data_ptr(self):
        return self._untyped_storage.data_ptr()

    def resize_(self, size):
        self._storage.resize_(size * self.element_size())
        _warn_typed_storage_removal()
        self._resize_(size)

    # For internal use only, to avoid deprecation warning
    def _resize_(self, size):
        self._untyped_storage.resize_(size * self._element_size())

    @classmethod
    def _free_weak_ref(cls, *args, **kwargs):
        return UntypedStorage._free_weak_ref(*args, **kwargs)

    def _weak_ref(self, *args, **kwargs):
        return self._storage._weak_ref(*args, **kwargs)
        return self._untyped_storage._weak_ref(*args, **kwargs)

    @classmethod
    def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
    def from_buffer(cls, *args, **kwargs):
        _warn_typed_storage_removal()
        return cls._from_buffer(*args, **kwargs)

    @classmethod
    def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
        if cls == TypedStorage:
            dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device('cpu' if device is None else device)
@ -693,65 +793,80 @@ class TypedStorage:
                    "from_buffer: 'device' can only be specified in "
                    "UntypedStorage.from_buffer and TypedStorage.from_buffer"))

            dtype = cls.dtype
            dtype = cls._dtype
            untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)

        return TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
        return TypedStorage(
            wrap_storage=untyped_storage,
            dtype=dtype,
            _internal=True)

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype)._typed_storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage

    def double(self):
        """Casts this storage to double type"""
        _warn_typed_storage_removal()
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        _warn_typed_storage_removal()
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        _warn_typed_storage_removal()
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        _warn_typed_storage_removal()
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        _warn_typed_storage_removal()
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        _warn_typed_storage_removal()
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        _warn_typed_storage_removal()
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        _warn_typed_storage_removal()
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        _warn_typed_storage_removal()
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        _warn_typed_storage_removal()
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        _warn_typed_storage_removal()
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        _warn_typed_storage_removal()
        return self._to(torch.cfloat)

    @classmethod
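All of the dtype cast methods above funnel into `_to`, which round-trips the data through a temporary tensor view of the storage. Code that wants the same effect without the deprecated storage casts can do the conversion at the tensor level instead; a hedged sketch, since the diff itself does not spell out a recommended replacement:

    import torch

    t = torch.arange(4, dtype=torch.float32)

    # Deprecated: casting the storage itself, e.g. t.storage().double(),
    # now emits the TypedStorage deprecation warning.

    # Equivalent data at the tensor level: cast the tensor.
    d = t.to(torch.double)
    print(d.dtype, d.element_size())    # torch.float64, 8 bytes per element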
@ -773,6 +888,7 @@ class TypedStorage:
            shared (bool): whether to share memory
            size (int): number of elements in the storage
        """
        _warn_typed_storage_removal()
        if cls == TypedStorage:
            raise RuntimeError('from_file can only be called on derived classes')
        untyped_storage: UntypedStorage = UntypedStorage.from_file(

@ -787,33 +903,39 @@ class TypedStorage:
        return UntypedStorage._expired(*args, **kwargs)

    def is_pinned(self):
        return self._storage.is_pinned()
        _warn_typed_storage_removal()
        return self._untyped_storage.is_pinned()

    def _write_file(self, *args, **kwargs):
        return self._storage._write_file(*args, **kwargs)
        return self._untyped_storage._write_file(*args, **kwargs)

    def _set_from_file(self, *args, **kwargs):
        return self._storage._set_from_file(*args, **kwargs)
        return self._untyped_storage._set_from_file(*args, **kwargs)

    def _set_cdata(self, *args, **kwargs):
        return self._storage._set_cdata(*args, **kwargs)
        return self._untyped_storage._set_cdata(*args, **kwargs)

    def _share_cuda_(self, *args, **kwargs):
        return self._storage._share_cuda_(*args, **kwargs)
        return self._untyped_storage._share_cuda_(*args, **kwargs)

    def is_shared(self):
        return self._storage.is_shared()
        _warn_typed_storage_removal()
        return self._is_shared()

    # For internal use only, to avoid deprecation warning
    def _is_shared(self):
        return self._untyped_storage.is_shared()

    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs):
        return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)

    def _share_filename_cpu_(self, *args, **kwargs):
        manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
        return manager_handle, storage_handle, size // self.element_size()
        manager_handle, storage_handle, size = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
        return manager_handle, storage_handle, size // self._element_size()

    def _shared_decref(self):
        self._storage._shared_decref()
        self._untyped_storage._shared_decref()
        return self

    @classmethod

@ -821,11 +943,11 @@ class TypedStorage:
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    def _shared_incref(self, *args, **kwargs):
        return self._storage._shared_incref(*args, **kwargs)
        return self._untyped_storage._shared_incref(*args, **kwargs)

    def _share_fd_cpu_(self, *args, **kwargs):
        fd, size = self._storage._share_fd_cpu_(*args, **kwargs)
        return fd, size // self.element_size()
        fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
        return fd, size // self._element_size()

    def _get_legacy_storage_class(self):
        if self.dtype not in _dtype_to_storage_type_map():

@ -859,7 +981,7 @@ class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
    @classmethod
    def _new_shared(cls, size):
        """Creates a new storage in shared memory with the same data type"""
        untyped_storage = torch.UntypedStorage._new_shared(size * cls().element_size())
        untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size())
        return cls(wrap_storage=untyped_storage)

    @classmethod
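The sharing helpers above all translate element counts into byte counts before calling into `UntypedStorage`. A small sketch of shared memory at the tensor level, which exercises the same delegation path without calling the deprecated storage API directly:

    import torch

    t = torch.zeros(4)
    t.share_memory_()        # moves the backing storage into shared memory
    print(t.is_shared())     # True; internally routed through the storage helpers above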
@ -927,9 +927,34 @@ def originate_pairs(
    Returns:
        (List[Pair]): Originated pairs.
    """
    if (
        isinstance(actual, torch.TypedStorage)
        and isinstance(expected, torch.TypedStorage)
    ):
        actual_len = actual._size()
        expected_len = expected._size()
        if actual_len != expected_len:
            raise ErrorMeta(
                AssertionError, f"The length of the sequences mismatch: {actual_len} != {expected_len}", id=id
            )

        pairs = []
        for idx in range(actual_len):
            pairs.extend(
                originate_pairs(
                    actual._getitem(idx),
                    expected._getitem(idx),
                    pair_types=pair_types,
                    sequence_types=sequence_types,
                    mapping_types=mapping_types,
                    id=(*id, idx),
                    **options,
                )
            )
        return pairs
    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
    # "a" == "a"[0][0]...
    if (
    elif (
        isinstance(actual, sequence_types)
        and not isinstance(actual, str)
        and isinstance(expected, sequence_types)
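This hunk teaches the pair-origination machinery to walk two TypedStorages element by element through the warning-free `_size()`/`_getitem()` helpers. In practice that means the comparison utilities can check storages without tripping the new deprecation warning; a hedged usage sketch, assuming storage inputs are accepted by the public comparison entry point as the hunk suggests:

    import torch

    a = torch.FloatStorage([1.0, 2.0, 3.0])   # deprecated constructor, still functional
    b = torch.FloatStorage([1.0, 2.0, 3.0])

    # Element-wise comparison routed through originate_pairs as patched above.
    torch.testing.assert_close(a, b)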
@ -47,7 +47,7 @@ class SchemaCheckMode(TorchDispatchMode):
                before.size() == after.size() and
                torch.allclose(before, after, equal_nan=True) and
                md[0] == after.stride() and
                md[1] == after.storage()._cdata
                md[1] == after._typed_storage()._cdata
            )
        return False

@ -76,12 +76,12 @@ class SchemaCheckMode(TorchDispatchMode):
            if not type(e) == torch.Tensor:
                try:
                    current = e.elem
                    return (deepcopy(current.stride()), current.storage()._cdata)
                    return (deepcopy(current.stride()), current._typed_storage()._cdata)
                except AttributeError as t:
                    return None
            # Sparse CSR tensors do not have strides or storage
            elif (e.layout != torch.sparse_csr):
                return (deepcopy(e.stride()), e.storage()._cdata)
                return (deepcopy(e.stride()), e._typed_storage()._cdata)
            return None

        self.ops.append(func._schema.name)
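SchemaCheckMode only needs a stable identity for a tensor's storage, so it switches from the public `storage()` accessor to the internal `_typed_storage()` to avoid warning during dispatch. The underlying idea is that two tensors alias when their storage identity matches; a small sketch that mirrors the check in the hunk, using the same private helpers purely for illustration:

    import torch

    base = torch.arange(6)
    view = base[2:]    # a view over the same allocation

    # _typed_storage() and _cdata are private helpers, as used in the hunk above.
    print(base._typed_storage()._cdata == view._typed_storage()._cdata)   # True: shared storage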
@ -391,7 +391,7 @@ def _inflate_expr(

    if isinstance(arg, torch.Tensor):
        # Small-storage tensors can just be saved directly.
        if arg.storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
        if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
            return arg, ref, None
        # Small contiguous tensors can be cloned to have small storage.
        # TODO: Should we do this even for non-contiguous tensors?

@ -407,7 +407,7 @@ def _inflate_expr(
        # TODO: Provide more useful diagnostics.
        raise Exception(
            f"Bundled input argument at position '{ref}' is "
            f"a tensor with storage size {arg.storage().size()}. "
            f"a tensor with storage size {arg._typed_storage().size()}. "
            f"You probably don't want to bundle this as an input. "
        )
    else:
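The bundled-inputs check cares about the size of the backing storage rather than the tensor's shape, because a tiny view can pin a huge allocation. A quick illustration of why the two diverge; `_typed_storage()` and `_size()` are the private, warning-free helpers this PR relies on:

    import torch

    big = torch.zeros(1_000_000)
    small_view = big[:10]                         # only 10 elements visible...

    print(small_view.numel())                     # 10
    print(small_view._typed_storage()._size())    # 1000000: the whole backing storage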
@ -158,7 +158,7 @@ def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            storage = elem._typed_storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
            return torch.stack(batch, 0, out=out)
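The collate path pre-allocates a single shared-memory storage large enough for the whole batch and stacks directly into it, so DataLoader worker processes can hand the batch to the main process without an extra copy. A hedged sketch of the same idea outside the DataLoader internals, reusing the private helpers the hunk relies on and assuming a uniform batch shape:

    import torch

    batch = [torch.randn(3, 4) for _ in range(8)]
    elem = batch[0]

    numel = sum(x.numel() for x in batch)
    # _typed_storage() and _new_shared() are private helpers, used here only for illustration.
    storage = elem._typed_storage()._new_shared(numel, device=elem.device)
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))

    stacked = torch.stack(batch, 0, out=out)
    print(stacked.shape, stacked.is_shared())   # torch.Size([8, 3, 4]) True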