[dim_order] Raise runtime error when tensor has ambiguous dim order (#141632)

This diff makes tensor.dim_order() raise an error when the tensor's dim order is ambiguous. A detailed discussion can be found at https://fb.workplace.com/groups/894363187646754/permalink/2039987243084337/
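In short, dim_order() gains a keyword-only ambiguity_check argument. A minimal sketch of the new behavior, mirroring the docstring examples in the diff below (requires a build that includes this change):

    import torch

    t = torch.empty(2, 3, 5, 7)               # no singleton dims: dim order is unique
    print(t.dim_order(ambiguity_check=True))  # (0, 1, 2, 3)

    u = torch.empty(1, 2, 3, 4)               # singleton dim: several legal orders
    print(u.dim_order())                      # still returns one legal order
    try:
        u.dim_order(ambiguity_check=True)     # strict check now raises
    except RuntimeError as e:
        print(e)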

Differential Revision: [D65133579](https://our.internmc.facebook.com/intern/diff/D65133579/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141632
Approved by: https://github.com/larryliu0820
Author: gasoonjia
Date: 2024-12-06 16:02:31 -08:00
Committed by: PyTorch MergeBot
Parent: e1196dfe51
Commit: 29e985b7b0

3 changed files with 130 additions and 16 deletions

test/test_torch.py

@@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
     TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON,
     IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
     IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf,
-    TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
+    skipIfCrossRef, TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype,
     skipCUDAMemoryLeakCheckIf, BytesIOContext,
     skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
     wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
@@ -8693,11 +8693,13 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
         test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)
 
+    @skipIfCrossRef
     def test_dim_order(self):
         shape = (2, 3, 5, 7)
         t = torch.empty(shape)
         self.assertSequenceEqual(t.dim_order(), (0, 1, 2, 3), seq_type=tuple)
+        self.assertSequenceEqual(t.dim_order(ambiguity_check=True), (0, 1, 2, 3), seq_type=tuple)
+
         # transpose doesn't really change the underlying physical memory
         # so expecting dim_order change to reflect that (like strides)
         self.assertSequenceEqual(t.transpose(0, 1).dim_order(), (1, 0, 2, 3))
@@ -8713,15 +8715,36 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
                 dim_order, torch.empty_permuted(shape, dim_order).dim_order()
             )
 
-        for shape in [(2, 2, 2, 2), (2, 1, 2, 2), (2, 2, 1, 2), (2, 2, 2, 1), (2, 2, 1, 1), (2, 1, 1, 2)]:
+        target_shapes = [[2, 2, 1, 2], [1, 2, 2, 2], [2, 2, 2, 1], [1, 2, 2, 1], [1, 2, 1, 2]]
+
+        for shape in target_shapes:
             for memory_format in (torch.contiguous_format, torch.channels_last):
                 t = torch.empty(shape).to(memory_format=memory_format)
+
+                with self.assertRaises(RuntimeError):
+                    t.dim_order(ambiguity_check=True)
+
                 if memory_format == torch.contiguous_format:
                     dim_order_target = list(range(len(shape)))
                 elif memory_format == torch.channels_last:
                     dim_order_target = [0, *list(range(2, len(shape))), 1]
 
                 self.assertSequenceEqual(dim_order_target, t.dim_order())
+                self.assertSequenceEqual(
+                    dim_order_target,
+                    t.dim_order(ambiguity_check=[torch.contiguous_format, torch.channels_last]),
+                )
+
+        ambiguous_shapes = [[2, 1, 2, 2], [2, 2, 1, 1], [1, 2, 1, 1], [2, 1, 1, 2], [2, 1, 2, 1],
+                            [1, 1, 1, 2], [1, 1, 2, 2], [1, 1, 1, 1], [2, 1, 1, 1], [1, 1, 2, 1]]
+
+        for shape in ambiguous_shapes:
+            for memory_format in (torch.contiguous_format, torch.channels_last):
+                t = torch.empty(shape).to(memory_format=memory_format)
+                with self.assertRaises(RuntimeError):
+                    t.dim_order(ambiguity_check=True)
+                    t.dim_order(ambiguity_check=[torch.contiguous_format, torch.channels_last])
+
+        with self.assertRaises(TypeError):
+            torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL_STR")
+
     def test_subclass_tensors(self):
         # raise an error when trying to subclass FloatTensor
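Aside: the split between target_shapes and ambiguous_shapes above follows from the memory-format check. A shape lands in ambiguous_shapes when it cannot be matched to exactly one of the two formats; with C == 1, for example, contiguous and channels_last describe the same memory layout. A small illustration against stock PyTorch (not part of the diff):

    import torch

    # C == 1: the contiguous and channels_last layouts coincide, so the tensor
    # matches both formats and format-based disambiguation fails.
    t = torch.empty(2, 1, 2, 2)
    print(t.is_contiguous(memory_format=torch.contiguous_format))  # True
    print(t.is_contiguous(memory_format=torch.channels_last))      # True

    # N == 1 but C > 1: only the contiguous layout matches, so the shape can
    # still be mapped to exactly one of the two formats.
    u = torch.empty(1, 2, 2, 2)
    print(u.is_contiguous(memory_format=torch.contiguous_format))  # True
    print(u.is_contiguous(memory_format=torch.channels_last))      # False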

torch/_tensor.py

@@ -6,7 +6,7 @@ import warnings
 from collections import OrderedDict
 from copy import deepcopy
 from numbers import Number
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch._C as _C
@@ -1490,32 +1490,123 @@ class Tensor(torch._C.TensorBase):
         """
         return self.to_sparse()
 
-    def dim_order(self):
+    def dim_order(
+        self, *, ambiguity_check: Union[bool, List[torch.memory_format]] = False
+    ):
         """
-        dim_order() -> tuple
+        dim_order(ambiguity_check=False) -> tuple
 
-        Returns a tuple of int describing the dim order or physical layout of :attr:`self`.
+        Returns the uniquely determined tuple of int describing the dim order or
+        physical layout of :attr:`self`.
 
-        Args:
-            None
-
-        Dim order represents how dimensions are laid out in memory,
+        The dim order represents how dimensions are laid out in memory,
         starting from the outermost to the innermost dimension.
 
+        Note that the dim order may not always be uniquely determined.
+        If `ambiguity_check` is True, this function raises a RuntimeError when the dim order cannot be uniquely determined;
+        if `ambiguity_check` is a list of memory formats, it raises a RuntimeError when the tensor cannot be interpreted
+        as exactly one of the given memory formats, or when its dim order cannot be uniquely determined.
+        If `ambiguity_check` is False, it returns one of the legal dim orders without checking uniqueness.
+        Otherwise, it raises TypeError.
+
+        Args:
+            ambiguity_check (bool or List[torch.memory_format]): The check method for the ambiguity of dim order.
+
         Example::
 
             >>> torch.empty((2, 3, 5, 7)).dim_order()
             (0, 1, 2, 3)
+            >>> torch.empty((2, 3, 5, 7)).transpose(1, 2).dim_order()
+            (0, 2, 1, 3)
             >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
             (0, 2, 3, 1)
+            >>> torch.empty((1, 2, 3, 4)).dim_order()
+            (0, 1, 2, 3)
+            >>> try:
+            ...     torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check=True)
+            ... except RuntimeError as e:
+            ...     print(e)
+            The tensor does not have unique dim order, or cannot map to exact one of the given memory formats.
+            >>> torch.empty((1, 2, 3, 4)).dim_order(
+            ...     ambiguity_check=[torch.contiguous_format, torch.channels_last]
+            ... )  # It can be mapped to contiguous format
+            (0, 1, 2, 3)
+            >>> try:
+            ...     torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL")
+            ... except TypeError as e:
+            ...     print(e)
+            The ambiguity_check argument must be a bool or a list of memory formats.
 
         .. warning::
             The dim_order tensor API is experimental and subject to change.
 
         """
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.dim_order, (self,), self)
 
+        # Sanity check the ambiguity_check data type
+        if not isinstance(ambiguity_check, bool):
+            if not isinstance(ambiguity_check, list):
+                raise TypeError(
+                    "The ambiguity_check argument must be a bool or a list of memory formats."
+                )
+            for memory_format in ambiguity_check:
+                if not isinstance(memory_format, torch.memory_format):
+                    raise TypeError(
+                        "The ambiguity_check argument must be a bool or a list of memory formats."
+                    )
+
+        def invalid_unique_memory_format(tensor, valid_memory_formats):
+            """
+            Returns True if the tensor cannot be uniquely mapped to any of the given
+            memory formats, False otherwise.
+            """
+            n_legality = 0
+
+            for memory_format in valid_memory_formats:
+                if tensor.is_contiguous(memory_format=memory_format):
+                    n_legality += 1
+
+            return n_legality != 1
+
+        def has_multiple_dim_order(tensor):
+            """
+            Returns True if there are multiple legal dim orders for the given tensor,
+            False otherwise.
+
+            The tensor is considered to have multiple legal dim orders if either of
+            the following conditions is met:
+
+            * Singleton dimensions: There is at least one singleton dimension in the
+              tensor. Since its size is 1, it doesn't affect the memory offset
+              (stride * index is zero because index is always zero). Therefore, it
+              can be placed anywhere in the dim order without changing how data is
+              accessed.
+            * Same strides: Strides reflect how the tensor is stored in memory.
+              If any two dimensions have the same stride, swapping these dimensions
+              won't change how data is accessed, leading to multiple correct dim
+              orders.
+            """
+            sizes = tensor.size()
+            strides = tensor.stride()
+
+            # Check if there are any duplicate strides
+            has_duplicate_strides = any(
+                earlier == later for earlier, later in zip(strides, strides[1:])
+            )
+
+            # Check if there are any singleton dimensions
+            has_singleton_dims = any(size == 1 for size in sizes)
+
+            return has_duplicate_strides or has_singleton_dims
+
+        valid_memory_formats = (
+            ambiguity_check if isinstance(ambiguity_check, list) else []
+        )
+        check_multiple_dim_order = (
+            ambiguity_check if isinstance(ambiguity_check, bool) else True
+        )
+
+        if (
+            check_multiple_dim_order and has_multiple_dim_order(self)
+        ) and invalid_unique_memory_format(self, valid_memory_formats):
+            raise RuntimeError(
+                "The tensor does not have unique dim order, or cannot map to exact one of the given memory formats."
+            )
+
         import torch._prims_common as utils
 
         return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
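For intuition, the two conditions in has_multiple_dim_order can be observed directly on stock tensors; a short illustrative sketch (not part of the diff):

    import torch

    # Singleton dimension: the index along dim 1 is always 0, so its stride
    # never contributes to a memory offset and the dim can sit anywhere in
    # the order.
    a = torch.empty(2, 1, 2)
    print(a.size(), a.stride())  # torch.Size([2, 1, 2]) (2, 2, 1)

    # Duplicate strides without a singleton: an overlapping view in which
    # dims 0 and 1 share stride 1, so swapping them leaves every memory
    # access unchanged.
    b = torch.empty(4).as_strided((2, 2), (1, 1))
    print(b.stride())  # (1, 1)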

torch/overrides.py

@@ -1417,7 +1417,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.dense_dim: lambda self: -1,
         Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
         Tensor.dim: lambda self: -1,
-        Tensor.dim_order: lambda self: -1,
+        Tensor.dim_order: lambda self, ambiguity_check=False: -1,
         Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
         Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
         Tensor.element_size: lambda self: -1,
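For context, torch.overrides.get_testing_overrides() maps each overridable API to a dummy whose signature is expected to match the real one, which is why the lambda gains the new keyword. A quick illustrative check:

    import inspect
    import torch
    from torch.overrides import get_testing_overrides

    # The dummy override now carries the same signature as Tensor.dim_order.
    dummy = get_testing_overrides()[torch.Tensor.dim_order]
    print(inspect.signature(dummy))  # (self, ambiguity_check=False)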