Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Rename Tensor._storage to Tensor.untyped_storage and update docs (#91414)

Fixes #89224
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91414
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent commit: 5b223c43ec
This commit: 08a47549af
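In practical terms, this commit removes the private Tensor._storage() spelling in favor of the public Tensor.untyped_storage(), which returns the raw byte-level UntypedStorage backing a tensor. A minimal sketch of the resulting API, assuming a PyTorch build that includes this rename:

import torch

t = torch.arange(6, dtype=torch.float32)

# New public spelling: the raw, byte-addressed storage behind the tensor.
u = t.untyped_storage()
assert isinstance(u, torch.UntypedStorage)

# UntypedStorage is sized in bytes, not elements.
assert u.nbytes() == t.numel() * t.element_size()

# The deprecated TypedStorage path is still reachable via .storage().
s = t.storage()  # may emit the TypedStorage deprecation warning
assert s.size() == t.numel()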
@@ -637,6 +637,7 @@ Tensor class reference
     Tensor.std
     Tensor.stft
     Tensor.storage
+    Tensor.untyped_storage
     Tensor.storage_offset
     Tensor.storage_type
     Tensor.stride
@@ -211,7 +211,7 @@ class LinearMixedPrecision(nn.Module):
                # Shard is never allocated if param_dtype mixed precision is not
                # enabled.
                if mp_config.param_dtype is not None:
-                    cls.assertEqual(0, param._mp_shard._storage().size())
+                    cls.assertEqual(0, param._mp_shard.untyped_storage().size())
                else:
                    cls.assertFalse(hasattr(param, "_mp_shard"))
            elif param_is_sharded:
@@ -274,7 +274,7 @@ class TestFSDPMixedPrecision(FSDPTest):
        fsdp_units = FSDP.fsdp_modules(fsdp_model)
        for fsdp in fsdp_units:
            for param in fsdp.params:
-                self.assertEqual(0, param._mp_shard._storage().size())
+                self.assertEqual(0, param._mp_shard.untyped_storage().size())

    def _reduce_scatter_validate_mp(
        self, orig_reduce_scatter, mp_config, *args, **kwargs
@@ -875,7 +875,7 @@ class TestSparseCompressed(TestCase):
                                 base.device != other.device)):
                return False
            if base.device.type == 'cpu' or base.device.type == 'cuda':
-                if base._storage().data_ptr() != other._storage().data_ptr():
+                if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                    return False
            return True

@@ -102,7 +102,7 @@ class TestViewOps(TestCase):
        # Note: only validates storage on native device types
        # because some accelerators, like XLA, do not expose storage
        if base.device.type == 'cpu' or base.device.type == 'cuda':
-            if base._storage().data_ptr() != other._storage().data_ptr():
+            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                return False

        return True
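The data_ptr() comparison used by this test generalizes to a quick storage-sharing check between a view and its base; a small standalone sketch (plain torch, outside the test harness):

import torch

base = torch.arange(12, dtype=torch.float32)
view = base.view(3, 4)
copy = base.clone()

# A view shares the base tensor's allocation ...
assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
# ... while a clone gets its own storage.
assert copy.untyped_storage().data_ptr() != base.untyped_storage().data_ptr()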
@@ -963,7 +963,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
 {
   HANDLE_TH_ERRORS
   if (check_has_torch_function(self)) {
-    return handle_torch_function(self, "_storage");
+    return handle_torch_function(self, "untyped_storage");
   }
   auto& self_ = THPVariable_Unpack(self);
   return createPyObject(self_.storage());
@@ -1280,7 +1280,7 @@ PyMethodDef variable_methods[] = {
    {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL},
    {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL},
    {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL},
-    {"_storage", THPVariable_storage, METH_NOARGS, NULL},
+    {"untyped_storage", THPVariable_storage, METH_NOARGS, NULL},
    {"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL},
    {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL},
    {"to", castPyCFunctionWithKeywords(THPVariable_to), METH_VARARGS | METH_KEYWORDS, NULL},
@@ -665,7 +665,7 @@ def gen_pyi(
        "map2_": [
            "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
        ],
-        "storage": ["def _storage(self) -> Storage: ..."],
+        "storage": ["def untyped_storage(self) -> Storage: ..."],
        "storage_type": ["def storage_type(self) -> Storage: ..."],
        "type": [
            "def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...",
@@ -161,7 +161,7 @@ class DDPOptimizer:
                for name, p in target.named_parameters():
                    param = target.get_parameter(name)
                    if p.requires_grad and not self._ignore_parameter(param):
-                        buckets[0].size += p._storage().nbytes()
+                        buckets[0].size += p.untyped_storage().nbytes()
                        buckets[0].params.append(f"{node.target}_{name}")
                        buckets[0].param_ids.append(id(param))
            elif node.op == "get_attr":
@@ -169,7 +169,7 @@ class DDPOptimizer:
                if maybe_param.requires_grad and not self._ignore_parameter(
                    maybe_param
                ):
-                    buckets[0].size += maybe_param._storage().nbytes()
+                    buckets[0].size += maybe_param.untyped_storage().nbytes()
                    buckets[0].params.append(node.target)
                    buckets[0].param_ids.append(id(maybe_param))

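The bucket accounting above simply sums raw storage bytes per parameter. A hedged standalone sketch of the same idea; the model and the 25 MiB cap are illustrative, not DDPOptimizer's actual configuration:

import torch.nn as nn

bucket_cap = 25 * 1024 * 1024  # illustrative 25 MiB cap
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 10))

buckets, current = [[]], 0
for name, p in model.named_parameters():
    nbytes = p.untyped_storage().nbytes()  # raw allocation size in bytes
    if buckets[-1] and current + nbytes > bucket_cap:
        buckets.append([])
        current = 0
    buckets[-1].append(name)
    current += nbytes

print([len(b) for b in buckets])  # here everything fits in one bucket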
@@ -353,7 +353,7 @@ def run_functionalized_fw_and_collect_metadata(f):
            # _x_updated_metadata = CompiledFunction.fw_metadata.metadata_mutation_input_info[0]
            # x.as_strided_(_x_updated_metadata.size(), _x_updated_metadata.stride(), _x_updated_metadata.storage_offset())
            # return out
-            if StorageWeakRef(arg._storage()) == StorageWeakRef(new_arg._storage()):
+            if StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()):
                # We can use the storage aliasing of the inputs and updated inputs
                # to detect when an input was actually updated, or just inplace-viewed.
                collect_mutated_input_info.append(MutationType.metadata_only)
@@ -429,8 +429,10 @@ def run_functionalized_fw_and_collect_metadata(f):
        # This will be more complicated when you have multiple _base tensors aliasing the same
        # underlying storage, when we eventually handle that.
        # We'll need to ensure that we generate the view off of the right base.
-        inp_storage_refs = {
-            StorageWeakRef(inpt._storage()): idx for idx, inpt in enumerate(flat_f_args) if isinstance(inpt, torch.Tensor)}
+        inp_storage_refs = {}
+        for idx, inpt in enumerate(flat_f_args):
+            if isinstance(inpt, torch.Tensor):
+                inp_storage_refs[StorageWeakRef(inpt.untyped_storage())] = idx
        inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)}
        inp_storage_refs_set = set(inp_storage_refs)

@@ -453,8 +455,8 @@ def run_functionalized_fw_and_collect_metadata(f):
            # Note: When detecting input/output aliasing, we NEED to do it using the outer FunctionalTensorWrapper objects.
            # In the case where we mutate an input *and* return a view of it, the outer wrappers will still alias,
            # but the inner tensors no longer alias.
-            if isinstance(o, torch.Tensor) and StorageWeakRef(o._storage()) in inp_storage_refs:
-                aliased_inp_idx = inp_storage_refs[StorageWeakRef(o._storage())]
+            if isinstance(o, torch.Tensor) and StorageWeakRef(o.untyped_storage()) in inp_storage_refs:
+                aliased_inp_idx = inp_storage_refs[StorageWeakRef(o.untyped_storage())]
                is_exact_input = id(o) in inp_tensor_ids
                aliases_intermediate_and_not_input = False
                aliased_out_idx[o] = (
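The two hunks above key a dict by StorageWeakRef of each input's untyped storage and then probe it with the outputs to detect aliasing. The same pattern can be exercised on its own; a sketch assuming StorageWeakRef from torch.multiprocessing.reductions, which hashes by storage identity just as the code above relies on:

import torch
from torch.multiprocessing.reductions import StorageWeakRef

def output_alias_map(inputs, outputs):
    # Map each tensor input's storage to its positional index.
    inp_refs = {
        StorageWeakRef(t.untyped_storage()): i
        for i, t in enumerate(inputs)
        if isinstance(t, torch.Tensor)
    }
    # For every output, record which input (if any) shares its storage.
    return [
        inp_refs.get(StorageWeakRef(o.untyped_storage()))
        if isinstance(o, torch.Tensor)
        else None
        for o in outputs
    ]

x, y = torch.randn(4), torch.randn(4)
print(output_alias_map([x, y], [x.view(2, 2), x + y]))  # [0, None]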
@@ -1071,7 +1073,7 @@ def merge_view_inputs(
    other_args = []
    for i, inpt in enumerate(fwd_inputs):
        if isinstance(inpt, Tensor):
-            storage_ref = StorageWeakRef(inpt._storage())
+            storage_ref = StorageWeakRef(inpt.untyped_storage())
            storage_ref_to_idx[storage_ref].append(i)
        else:
            other_args.append(inpt)
@@ -1118,7 +1120,7 @@ def merge_view_inputs(
        if len(non_none_bases) == 0:
            # Case where none of the aliases require gradients
            example_idx = aliased_input_indices[0]
-            synthetic_base = torch.Tensor(fwd_inputs[example_idx]._storage())
+            synthetic_base = torch.Tensor(fwd_inputs[example_idx].untyped_storage())
        else:
            # Case where all of the aliases require gradients, and have the same _base.
            synthetic_base = non_none_bases[0]
@@ -149,7 +149,7 @@ class MetaConverter:
        if swr not in self.storage_memo:
            self.storage_memo[swr] = callback(
                lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
-            )._storage()
+            ).untyped_storage()
        return self.storage_memo[swr]

    # This function assumes that it's possible to do the conversion
@@ -374,7 +374,7 @@ class MetaConverter:
                    # format here
                    r = r.clone(memory_format=torch.preserve_format)

-                s = t._storage()
+                s = t.untyped_storage()
                swr = StorageWeakRef(s)
                if (
                    swr not in self.storage_memo
@@ -382,7 +382,7 @@ class MetaConverter:
                    and r.storage_offset() == storage_offset
                ):
                    # You're normal and happy, install the fresh storage into the memo
-                    self.storage_memo[swr] = r._storage()
+                    self.storage_memo[swr] = r.untyped_storage()
                else:
                    # You're in crazy town; somehow you gave us a tensor
                    # that wasn't a view, but had nonzero storage offset,
@@ -218,9 +218,15 @@ class Tensor(torch._C._TensorBase):

    def storage(self):
        r"""
-        storage() -> torch.Storage
+        storage() -> torch.TypedStorage

-        Returns the underlying storage.
+        Returns the underlying :class:`TypedStorage`.
+
+        .. warning::
+
+            :class:`TypedStorage` is deprecated. It will be removed in the future, and
+            :class:`UntypedStorage` will be the only storage class. To access the
+            :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.storage, (self,), self)
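As the new warning text says, TypedStorage is a deprecated wrapper; in practice the visible difference is that it is sized in elements of the tensor's dtype while UntypedStorage is sized in bytes. A small sketch, assuming a build where both classes are still available:

import torch

t = torch.zeros(10, dtype=torch.float64)

typed = t.storage()            # TypedStorage (deprecated)
untyped = t.untyped_storage()  # UntypedStorage (byte-addressed)

assert typed.size() == 10                       # elements
assert untyped.size() == 10 * t.element_size()  # bytes (80 here)
assert typed.data_ptr() == untyped.data_ptr()   # same underlying memory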
@@ -230,11 +236,9 @@ class Tensor(torch._C._TensorBase):

    # For internal use only, to avoid raising deprecation warning
    def _typed_storage(self):
-        _storage = self._storage()
-        if isinstance(_storage, torch.TypedStorage):
-            _storage = _storage._untyped_storage
+        untyped_storage = self.untyped_storage()
        return torch.TypedStorage(
-            wrap_storage=_storage, dtype=self.dtype, _internal=True
+            wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
        )

    def _reduce_ex_internal(self, proto):
@@ -4810,6 +4810,15 @@ Example::
 """,
 )

+add_docstr_all(
+    "untyped_storage",
+    r"""
+untyped_storage() -> torch.UntypedStorage
+
+Returns the underlying :class:`UntypedStorage`.
+""",
+)
+
 add_docstr_all(
     "stride",
     r"""
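The new docstring is terse, so one hedged illustration of what "underlying" means here: the UntypedStorage is the tensor's actual allocation, and writes through it are visible in the tensor (a sketch assuming the default contiguous layout and that fill_ is available on UntypedStorage, as on the other storage classes):

import torch

t = torch.ones(4, dtype=torch.float32)
t.untyped_storage().fill_(0)  # zero every byte backing the tensor
print(t)                      # tensor([0., 0., 0., 0.])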
@@ -131,8 +131,8 @@ at::Storage createStorageGetType(

  if (is_typed_storage) {
    // NOTE: `PyObject_GetAttrString` increments the refcounts to `dtype` and
-    // `_storage`, so we must decrement them. The refcounts will still stay
-    // nonzero since the `TypedStorage` maintains a reference.
+    // `_untyped_storage`, so we must decrement them. The refcounts will still
+    // stay nonzero since the `TypedStorage` maintains a reference.
    PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
    TORCH_INTERNAL_ASSERT(dtype_obj);
    Py_DECREF(dtype_obj);
@@ -1317,7 +1317,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
        Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
        Tensor.storage: lambda self: -1,
-        Tensor._storage: lambda self: -1,
+        Tensor.untyped_storage: lambda self: -1,
        Tensor.storage_offset: lambda self: -1,
        Tensor.storage_type: lambda self: -1,
        Tensor.sum_to_size: lambda self, size: -1,
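Because Tensor.untyped_storage is registered in get_testing_overrides, it participates in the __torch_function__ protocol like any other overridable Tensor method; a hedged sketch (the subclass and the name check are illustrative, not part of this commit):

import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if getattr(func, "__name__", "") == "untyped_storage":
            print("untyped_storage() intercepted")
        return super().__torch_function__(func, types, args, kwargs or {})

x = torch.randn(2).as_subclass(LoggingTensor)
storage = x.untyped_storage()  # prints the log line, then returns the storage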
@@ -311,7 +311,7 @@ def _warn_typed_storage_removal(stacklevel=2):
        "TypedStorage is deprecated. It will be removed in the future and "
        "UntypedStorage will be the only storage class. This should only matter "
        "to you if you are using storages directly. To access UntypedStorage "
-        "directly, use tensor._storage() instead of tensor.storage()"
+        "directly, use tensor.untyped_storage() instead of tensor.storage()"
    )
    warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)

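The updated warning message describes the same migration applied throughout this diff. As a before/after sketch for user code (the exact release in which TypedStorage is removed is not pinned down by this commit):

import torch

t = torch.randn(3)

# Before: goes through the deprecated TypedStorage path and may trigger
# the UserWarning built by _warn_typed_storage_removal.
legacy = t.storage()

# After: returns the UntypedStorage directly, with no deprecation warning.
raw = t.untyped_storage()

assert legacy.data_ptr() == raw.data_ptr()  # same underlying allocation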