Rename Tensor._storage to Tensor.untyped_storage and update docs (#91414)

Fixes #89224

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91414
Approved by: https://github.com/ezyang
This commit is contained in:
Kurt Mohler
2022-12-28 19:21:34 +00:00
committed by PyTorch MergeBot
parent 5b223c43ec
commit 08a47549af
14 changed files with 45 additions and 29 deletions

View File

@ -637,6 +637,7 @@ Tensor class reference
Tensor.std
Tensor.stft
Tensor.storage
Tensor.untyped_storage
Tensor.storage_offset
Tensor.storage_type
Tensor.stride

View File

@ -211,7 +211,7 @@ class LinearMixedPrecision(nn.Module):
# Shard is never allocated if param_dtype mixed precision is not
# enabled.
if mp_config.param_dtype is not None:
cls.assertEqual(0, param._mp_shard._storage().size())
cls.assertEqual(0, param._mp_shard.untyped_storage().size())
else:
cls.assertFalse(hasattr(param, "_mp_shard"))
elif param_is_sharded:
@ -274,7 +274,7 @@ class TestFSDPMixedPrecision(FSDPTest):
fsdp_units = FSDP.fsdp_modules(fsdp_model)
for fsdp in fsdp_units:
for param in fsdp.params:
self.assertEqual(0, param._mp_shard._storage().size())
self.assertEqual(0, param._mp_shard.untyped_storage().size())
def _reduce_scatter_validate_mp(
self, orig_reduce_scatter, mp_config, *args, **kwargs

View File

@ -875,7 +875,7 @@ class TestSparseCompressed(TestCase):
base.device != other.device)):
return False
if base.device.type == 'cpu' or base.device.type == 'cuda':
if base._storage().data_ptr() != other._storage().data_ptr():
if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
return False
return True

View File

@ -102,7 +102,7 @@ class TestViewOps(TestCase):
# Note: only validates storage on native device types
# because some accelerators, like XLA, do not expose storage
if base.device.type == 'cpu' or base.device.type == 'cuda':
if base._storage().data_ptr() != other._storage().data_ptr():
if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
return False
return True

View File

@ -963,7 +963,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
{
HANDLE_TH_ERRORS
if (check_has_torch_function(self)) {
return handle_torch_function(self, "_storage");
return handle_torch_function(self, "untyped_storage");
}
auto& self_ = THPVariable_Unpack(self);
return createPyObject(self_.storage());
@ -1280,7 +1280,7 @@ PyMethodDef variable_methods[] = {
{"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
{"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
{"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
{"_storage", THPVariable_storage, METH_NOARGS, NULL},
{"untyped_storage", THPVariable_storage, METH_NOARGS, NULL},
{"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
{"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},
{"to", castPyCFunctionWithKeywords(THPVariable_to), METH_VARARGS | METH_KEYWORDS, NULL},

View File

@ -665,7 +665,7 @@ def gen_pyi(
"map2_": [
"def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
],
"storage": ["def _storage(self) -> Storage: ..."],
"storage": ["def untyped_storage(self) -> Storage: ..."],
"storage_type": ["def storage_type(self) -> Storage: ..."],
"type": [
"def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...",

View File

@ -161,7 +161,7 @@ class DDPOptimizer:
for name, p in target.named_parameters():
param = target.get_parameter(name)
if p.requires_grad and not self._ignore_parameter(param):
buckets[0].size += p._storage().nbytes()
buckets[0].size += p.untyped_storage().nbytes()
buckets[0].params.append(f"{node.target}_{name}")
buckets[0].param_ids.append(id(param))
elif node.op == "get_attr":
@ -169,7 +169,7 @@ class DDPOptimizer:
if maybe_param.requires_grad and not self._ignore_parameter(
maybe_param
):
buckets[0].size += maybe_param._storage().nbytes()
buckets[0].size += maybe_param.untyped_storage().nbytes()
buckets[0].params.append(node.target)
buckets[0].param_ids.append(id(maybe_param))

View File

@ -353,7 +353,7 @@ def run_functionalized_fw_and_collect_metadata(f):
# _x_updated_metadata = CompiledFunction.fw_metadata.metadata_mutation_input_info[0]
# x.as_strided_(_x_updated_metadata.size(), _x_updated_metadata.stride(), _x_updated_metadata.storage_offset())
# return out
if StorageWeakRef(arg._storage()) == StorageWeakRef(new_arg._storage()):
if StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()):
# We can use the storage aliasing of the inputs and updated inputs
# to detect when an input was actually updated, or just inplace-viewed.
collect_mutated_input_info.append(MutationType.metadata_only)
@ -429,8 +429,10 @@ def run_functionalized_fw_and_collect_metadata(f):
# This will be more complicated when you have multiple _base tensors aliasing the same
# underlying storage, when we eventually handle that.
# We'll need to ensure that we generate the view off of the right base.
inp_storage_refs = {
StorageWeakRef(inpt._storage()): idx for idx, inpt in enumerate(flat_f_args) if isinstance(inpt, torch.Tensor)}
inp_storage_refs = {}
for idx, inpt in enumerate(flat_f_args):
if isinstance(inpt, torch.Tensor):
inp_storage_refs[StorageWeakRef(inpt.untyped_storage())] = idx
inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)}
inp_storage_refs_set = set(inp_storage_refs)
@ -453,8 +455,8 @@ def run_functionalized_fw_and_collect_metadata(f):
# Note: When detecting input/output aliasing, we NEED to do it using the outer FunctionalTensorWrapper objects.
# In the case where we mutate an input *and* return a view of it, the outer wrappers will still alias,
# but the inner tensors no longer alias.
if isinstance(o, torch.Tensor) and StorageWeakRef(o._storage()) in inp_storage_refs:
aliased_inp_idx = inp_storage_refs[StorageWeakRef(o._storage())]
if isinstance(o, torch.Tensor) and StorageWeakRef(o.untyped_storage()) in inp_storage_refs:
aliased_inp_idx = inp_storage_refs[StorageWeakRef(o.untyped_storage())]
is_exact_input = id(o) in inp_tensor_ids
aliases_intermediate_and_not_input = False
aliased_out_idx[o] = (
@ -1071,7 +1073,7 @@ def merge_view_inputs(
other_args = []
for i, inpt in enumerate(fwd_inputs):
if isinstance(inpt, Tensor):
storage_ref = StorageWeakRef(inpt._storage())
storage_ref = StorageWeakRef(inpt.untyped_storage())
storage_ref_to_idx[storage_ref].append(i)
else:
other_args.append(inpt)
@ -1118,7 +1120,7 @@ def merge_view_inputs(
if len(non_none_bases) == 0:
# Case where none of the aliases require gradients
example_idx = aliased_input_indices[0]
synthetic_base = torch.Tensor(fwd_inputs[example_idx]._storage())
synthetic_base = torch.Tensor(fwd_inputs[example_idx].untyped_storage())
else:
# Case where all of the aliases require gradients, and have the same _base.
synthetic_base = non_none_bases[0]

View File

@ -149,7 +149,7 @@ class MetaConverter:
if swr not in self.storage_memo:
self.storage_memo[swr] = callback(
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
)._storage()
).untyped_storage()
return self.storage_memo[swr]
# This function assumes that it's possible to do the conversion
@ -374,7 +374,7 @@ class MetaConverter:
# format here
r = r.clone(memory_format=torch.preserve_format)
s = t._storage()
s = t.untyped_storage()
swr = StorageWeakRef(s)
if (
swr not in self.storage_memo
@ -382,7 +382,7 @@ class MetaConverter:
and r.storage_offset() == storage_offset
):
# You're normal and happy, install the fresh storage into the memo
self.storage_memo[swr] = r._storage()
self.storage_memo[swr] = r.untyped_storage()
else:
# You're in crazy town; somehow you gave us a tensor
# that wasn't a view, but had nonzero storage offset,

View File

@ -218,9 +218,15 @@ class Tensor(torch._C._TensorBase):
def storage(self):
r"""
storage() -> torch.Storage
storage() -> torch.TypedStorage
Returns the underlying storage.
Returns the underlying :class:`TypedStorage`.
.. warning::
:class:`TypedStorage` is deprecated. It will be removed in the future, and
:class:`UntypedStorage` will be the only storage class. To access the
:class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
"""
if has_torch_function_unary(self):
return handle_torch_function(Tensor.storage, (self,), self)
@ -230,11 +236,9 @@ class Tensor(torch._C._TensorBase):
# For internal use only, to avoid raising deprecation warning
def _typed_storage(self):
_storage = self._storage()
if isinstance(_storage, torch.TypedStorage):
_storage = _storage._untyped_storage
untyped_storage = self.untyped_storage()
return torch.TypedStorage(
wrap_storage=_storage, dtype=self.dtype, _internal=True
wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
)
def _reduce_ex_internal(self, proto):

View File

@ -4810,6 +4810,15 @@ Example::
""",
)
add_docstr_all(
"untyped_storage",
r"""
untyped_storage() -> torch.UntypedStorage
Returns the underlying :class:`UntypedStorage`.
""",
)
add_docstr_all(
"stride",
r"""

View File

@ -131,8 +131,8 @@ at::Storage createStorageGetType(
if (is_typed_storage) {
// NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and
// `_storage`, so we must decrement them. The refcounts will still stay
// nonzero since the `TypedStorage` maintains a reference.
// `_untyped_storage`, so we must decrement them. The refcounts will still
// stay nonzero since the `TypedStorage` maintains a reference.
PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
TORCH_INTERNAL_ASSERT(dtype_obj);
Py_DECREF(dtype_obj);

View File

@ -1317,7 +1317,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
Tensor.storage: lambda self: -1,
Tensor._storage: lambda self: -1,
Tensor.untyped_storage: lambda self: -1,
Tensor.storage_offset: lambda self: -1,
Tensor.storage_type: lambda self: -1,
Tensor.sum_to_size: lambda self, size: -1,

View File

@ -311,7 +311,7 @@ def _warn_typed_storage_removal(stacklevel=2):
"TypedStorage is deprecated. It will be removed in the future and "
"UntypedStorage will be the only storage class. This should only matter "
"to you if you are using storages directly. To access UntypedStorage "
"directly, use tensor._storage() instead of tensor.storage()"
"directly, use tensor.untyped_storage() instead of tensor.storage()"
)
warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)