mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dim_order] raised runtime error when tensor has ambiguous dim order (#141632)
This diff makes tensor.dim_order() raise error when tensor's dim order is ambiguous. Detail discussion can be found https://fb.workplace.com/groups/894363187646754/permalink/2039987243084337/ Differential Revision: [D65133579](https://our.internmc.facebook.com/intern/diff/D65133579/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141632 Approved by: https://github.com/larryliu0820
This commit is contained in:
committed by
PyTorch MergeBot
parent
e1196dfe51
commit
29e985b7b0
@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
||||
TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON,
|
||||
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
|
||||
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
|
||||
TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
|
||||
skipIfCrossRef, TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
|
||||
skipCUDAMemoryLeakCheckIf, BytesIOContext,
|
||||
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
|
||||
wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
|
||||
@ -8693,11 +8693,13 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
|
||||
test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)
|
||||
|
||||
@skipIfCrossRef
|
||||
def test_dim_order(self):
|
||||
shape = (2, 3, 5, 7)
|
||||
|
||||
t = torch.empty(shape)
|
||||
self.assertSequenceEqual(t.dim_order(), (0, 1, 2, 3), seq_type=tuple)
|
||||
self.assertSequenceEqual(t.dim_order(ambiguity_check=True), (0, 1, 2, 3), seq_type=tuple)
|
||||
# transpose doesn't really change the underlying physical memory
|
||||
# so expecting dim_order change to reflect that (like strides)
|
||||
self.assertSequenceEqual(t.transpose(0, 1).dim_order(), (1, 0, 2, 3))
|
||||
@ -8713,15 +8715,36 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
dim_order, torch.empty_permuted(shape, dim_order).dim_order()
|
||||
)
|
||||
|
||||
for shape in [(2, 2, 2, 2), (2, 1, 2, 2), (2, 2, 1, 2), (2, 2, 2, 1), (2, 2, 1, 1), (2, 1, 1, 2)]:
|
||||
target_shapes = [[2, 2, 1, 2], [1, 2, 2, 2], [2, 2, 2, 1], [1, 2, 2, 1], [1, 2, 1, 2]]
|
||||
|
||||
for shape in target_shapes:
|
||||
for memory_format in (torch.contiguous_format, torch.channels_last):
|
||||
t = torch.empty(shape).to(memory_format=memory_format)
|
||||
with self.assertRaises(RuntimeError):
|
||||
t.dim_order(ambiguity_check=True)
|
||||
|
||||
if memory_format == torch.contiguous_format:
|
||||
dim_order_target = list(range(len(shape)))
|
||||
elif memory_format == torch.channels_last:
|
||||
dim_order_target = [0, *list(range(2, len(shape))), 1]
|
||||
|
||||
self.assertSequenceEqual(dim_order_target, t.dim_order())
|
||||
self.assertSequenceEqual(
|
||||
dim_order_target, t.dim_order(ambiguity_check=[torch.contiguous_format, torch.channels_last])
|
||||
)
|
||||
|
||||
|
||||
ambiguous_shapes = [[2, 1, 2, 2], [2, 2, 1, 1], [1, 2, 1, 1], [2, 1, 1, 2], [2, 1, 2, 1],
|
||||
[1, 1, 1, 2], [1, 1, 2, 2], [1, 1, 1, 1], [2, 1, 1, 1], [1, 1, 2, 1]]
|
||||
|
||||
for shape in ambiguous_shapes:
|
||||
for memory_format in (torch.contiguous_format, torch.channels_last):
|
||||
t = torch.empty(shape).to(memory_format=memory_format)
|
||||
with self.assertRaises(RuntimeError):
|
||||
t.dim_order(ambiguity_check=True)
|
||||
t.dim_order(ambiguity_check=[torch.contiguous_format, torch.channels_last])
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL_STR")
|
||||
|
||||
def test_subclass_tensors(self):
|
||||
# raise an error when trying to subclass FloatTensor
|
||||
|
115
torch/_tensor.py
115
torch/_tensor.py
@ -6,7 +6,7 @@ import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch._C as _C
|
||||
@ -1490,32 +1490,123 @@ class Tensor(torch._C.TensorBase):
|
||||
"""
|
||||
return self.to_sparse()
|
||||
|
||||
def dim_order(self):
|
||||
def dim_order(
|
||||
self, *, ambiguity_check: Union[bool, List[torch.memory_format]] = False
|
||||
):
|
||||
"""
|
||||
dim_order(ambiguity_check=False) -> tuple
|
||||
|
||||
dim_order() -> tuple
|
||||
Returns the uniquely determined tuple of int describing the dim order or
|
||||
physical layout of :attr:`self`.
|
||||
|
||||
Returns a tuple of int describing the dim order or physical layout of :attr:`self`.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Dim order represents how dimensions are laid out in memory,
|
||||
The dim order represents how dimensions are laid out in memory,
|
||||
starting from the outermost to the innermost dimension.
|
||||
|
||||
Example::
|
||||
Note that the dim order may not always be uniquely determined.
|
||||
If `ambiguity_check` is True, this function raises a RuntimeError when the dim order cannot be uniquely determined;
|
||||
If `ambiguity_check` is a list of memory formats, this function raises a RuntimeError when tensor can not be interpreted
|
||||
into exactly one of the given memory formats, or it cannot be uniquely determined.
|
||||
If `ambiguity_check` is False, it will return one of legal dim order(s) without checking its uniqueness.
|
||||
Otherwise, it will raise TypeError.
|
||||
|
||||
Args:
|
||||
ambiguity_check (bool or List[torch.memory_format]): The check method for ambiguity of dim order.
|
||||
|
||||
>>> torch.empty((2, 3, 5, 7)).dim_order()
|
||||
(0, 1, 2, 3)
|
||||
>>> torch.empty((2, 3, 5, 7)).transpose(1, 2).dim_order()
|
||||
(0, 2, 1, 3)
|
||||
>>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
|
||||
(0, 2, 3, 1)
|
||||
|
||||
>>> torch.empty((1, 2, 3, 4)).dim_order()
|
||||
(0, 1, 2, 3)
|
||||
>>> try:
|
||||
... torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check=True)
|
||||
... except RuntimeError as e:
|
||||
... print(e)
|
||||
The tensor does not have unique dim order, or cannot map to exact one of the given memory formats.
|
||||
>>> torch.empty((1, 2, 3, 4)).dim_order(
|
||||
... ambiguity_check=[torch.contiguous_format, torch.channels_last]
|
||||
... ) # It can be mapped to contiguous format
|
||||
(0, 1, 2, 3)
|
||||
>>> try:
|
||||
... torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL")
|
||||
... except TypeError as e:
|
||||
... print(e)
|
||||
The ambiguity_check argument must be a bool or a list of memory formats.
|
||||
.. warning::
|
||||
The dim_order tensor API is experimental and subject to change.
|
||||
|
||||
"""
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.dim_order, (self,), self)
|
||||
|
||||
# Sanity check ambiguity_check data types
|
||||
if not isinstance(ambiguity_check, bool):
|
||||
if not isinstance(ambiguity_check, list):
|
||||
raise TypeError(
|
||||
"The ambiguity_check argument must be a bool or a list of memory formats."
|
||||
)
|
||||
for memory_format in ambiguity_check:
|
||||
if not isinstance(memory_format, torch.memory_format):
|
||||
raise TypeError(
|
||||
"The ambiguity_check argument must be a bool or a list of memory formats."
|
||||
)
|
||||
|
||||
def invalid_unique_memory_format(tensor, valid_memory_formats):
|
||||
"""
|
||||
Returns True if the tensor cannot be uniquely mapped to any of the given memory formats, False otherwise.
|
||||
"""
|
||||
|
||||
n_legality = 0
|
||||
|
||||
for memory_format in valid_memory_formats:
|
||||
if tensor.is_contiguous(memory_format=memory_format):
|
||||
n_legality += 1
|
||||
|
||||
return n_legality != 1
|
||||
|
||||
def has_multiple_dim_order(tensor):
|
||||
"""
|
||||
Returns True if there're multiple legal dim orders for given tensor, False otherwise.
|
||||
|
||||
The tensor is considered to have multiple legal dim orders if either of the following conditions is met:
|
||||
|
||||
* Singleton Dimensions: There's at least one singleteon dimension in the tensor.
|
||||
Since their size is 1, they don't affect the memory offset (stride * index
|
||||
is zero because index is always zero). Therefore, they can be placed anywhere
|
||||
in the dimension order without changing how data is accessed.
|
||||
* Same strides: Strides reflect how the tensor is stored in memory.
|
||||
If any two dimensions have the same stride, swapping these dimensions won't
|
||||
change how data is accessed, leading to multiple correct dimension orders.
|
||||
"""
|
||||
|
||||
sizes = tensor.size()
|
||||
strides = tensor.stride()
|
||||
|
||||
# Check if there are any duplicate strides
|
||||
has_duplicate_strides = any(
|
||||
earlier == later for earlier, later in zip(strides, strides[1:])
|
||||
)
|
||||
|
||||
# Check if there are any singleton dimensions
|
||||
has_singleton_dims = any(size == 1 for size in sizes)
|
||||
|
||||
return has_duplicate_strides or has_singleton_dims
|
||||
|
||||
valid_memory_formats = (
|
||||
ambiguity_check if isinstance(ambiguity_check, list) else []
|
||||
)
|
||||
check_multiple_dim_order = (
|
||||
ambiguity_check if isinstance(ambiguity_check, bool) else True
|
||||
)
|
||||
|
||||
if (
|
||||
check_multiple_dim_order and has_multiple_dim_order(self)
|
||||
) and invalid_unique_memory_format(self, valid_memory_formats):
|
||||
raise RuntimeError(
|
||||
"The tensor does not have unique dim order, or cannot map to exact one of the given memory formats."
|
||||
)
|
||||
|
||||
import torch._prims_common as utils
|
||||
|
||||
return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
|
||||
|
@ -1417,7 +1417,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
Tensor.dense_dim: lambda self: -1,
|
||||
Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
|
||||
Tensor.dim: lambda self: -1,
|
||||
Tensor.dim_order: lambda self: -1,
|
||||
Tensor.dim_order: lambda self, ambiguity_check=False: -1,
|
||||
Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
|
||||
Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
|
||||
Tensor.element_size: lambda self: -1,
|
||||
|
Reference in New Issue
Block a user