mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432)
when a tensor has unbacked symbols it can be general enough to represent both contiguous and non contiguous tensors. in that case we cant really evaluate is_contiguous. In many places in the code base, we check for is_contiguous to take a fast path. but the general path usually works for both contiguous and not contiguous in that case we probably want to use definitely _contiguous API. This is appleid for reshape in this PR and also to tensor meta data computation, the meta data now will have an attribute that says that its contiguous when its always contiguous. We would store that only if definitely _contiguous is true now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
54f1f29fed
commit
39df901b2a
@ -259,47 +259,64 @@ def check_all_strides(
|
||||
|
||||
|
||||
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
|
||||
def is_contiguous(a: TensorLikeType) -> bool:
|
||||
def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
|
||||
"""
|
||||
Tests whether a tensor is contiguous or not.
|
||||
|
||||
Tensors are contiguous when they have no elements,
|
||||
one element, or when they have "nested" strides.
|
||||
"""
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_or_true,
|
||||
guard_size_oblivious,
|
||||
)
|
||||
|
||||
if guard_size_oblivious(a.numel() < 2):
|
||||
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
|
||||
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
|
||||
|
||||
if maybe_guard_or_false(a.numel() < 2):
|
||||
return True
|
||||
|
||||
expected_stride = 1
|
||||
for x, y in reversed(tuple(zip(a.shape, a.stride()))):
|
||||
# Skips checking strides when a dimension has length 1
|
||||
if guard_size_oblivious(x == 1):
|
||||
if maybe_guard_or_false(x == 1):
|
||||
continue
|
||||
|
||||
if guard_size_oblivious(y != expected_stride):
|
||||
if maybe_guard_or_true(y != expected_stride):
|
||||
return False
|
||||
expected_stride = expected_stride * x
|
||||
|
||||
# if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can
|
||||
# can assume x is not 0 in expected_stride equation. This is also consistent with make_contiguous_strides_for.
|
||||
expected_stride = expected_stride * sym_max(x, 1)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
|
||||
def is_channels_last_contiguous_2d(a: Tensor) -> bool:
|
||||
def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
|
||||
# NHWC or not channels last 2D contiguous
|
||||
if a.ndim != 4:
|
||||
return False
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_or_true,
|
||||
guard_size_oblivious,
|
||||
)
|
||||
|
||||
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
|
||||
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
|
||||
|
||||
expected_stride = 1
|
||||
for idx in (1, 3, 2, 0):
|
||||
length = a.shape[idx]
|
||||
if guard_size_oblivious(length == 1):
|
||||
if maybe_guard_or_false(length == 1):
|
||||
continue
|
||||
|
||||
stride = a.stride()[idx]
|
||||
if guard_size_oblivious(stride != expected_stride):
|
||||
if maybe_guard_or_true(stride != expected_stride):
|
||||
return False
|
||||
|
||||
expected_stride *= length
|
||||
@ -307,21 +324,28 @@ def is_channels_last_contiguous_2d(a: Tensor) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def is_channels_last_contiguous_3d(a: Tensor) -> bool:
|
||||
def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
|
||||
# NDHWC or not channels last 3D contiguous
|
||||
if a.ndim != 5:
|
||||
return False
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_or_true,
|
||||
guard_size_oblivious,
|
||||
)
|
||||
|
||||
maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
|
||||
maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
|
||||
|
||||
expected_stride = 1
|
||||
for idx in (1, 4, 3, 2, 0):
|
||||
length = a.shape[idx]
|
||||
if guard_size_oblivious(length == 1):
|
||||
if maybe_guard_or_false(length == 1):
|
||||
continue
|
||||
|
||||
stride = a.stride()[idx]
|
||||
if guard_size_oblivious(stride != expected_stride):
|
||||
if maybe_guard_or_true(stride != expected_stride):
|
||||
return False
|
||||
|
||||
expected_stride *= length
|
||||
@ -345,16 +369,16 @@ def validate_memory_format(memory_format: torch.memory_format):
|
||||
|
||||
|
||||
def is_contiguous_for_memory_format( # type: ignore[return]
|
||||
a: Tensor, *, memory_format: torch.memory_format
|
||||
a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
|
||||
) -> bool:
|
||||
validate_memory_format(memory_format)
|
||||
|
||||
if memory_format == torch.contiguous_format:
|
||||
return is_contiguous(a)
|
||||
return is_contiguous(a, false_if_dde)
|
||||
if memory_format == torch.channels_last:
|
||||
return is_channels_last_contiguous_2d(a)
|
||||
return is_channels_last_contiguous_2d(a, false_if_dde)
|
||||
if memory_format == torch.channels_last_3d:
|
||||
return is_channels_last_contiguous_3d(a)
|
||||
return is_channels_last_contiguous_3d(a, false_if_dde)
|
||||
|
||||
torch._check(
|
||||
False,
|
||||
@ -362,6 +386,29 @@ def is_contiguous_for_memory_format( # type: ignore[return]
|
||||
)
|
||||
|
||||
|
||||
def definitely_contiguous(a: TensorLikeType) -> bool:
|
||||
return is_contiguous(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous_2d but return false on data dependency.
|
||||
def is_known_channels_last_contiguous_2d(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_2d(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous_3d but return false on data dependency.
|
||||
def is_known_channels_last_contiguous_3d(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_3d(a, false_if_dde=True)
|
||||
|
||||
|
||||
# similar to is_contiguous_for_memory_format but return false on data dependency.
|
||||
def definitely_contiguous_for_memory_format( # type: ignore[return]
|
||||
a: Tensor, *, memory_format: torch.memory_format
|
||||
) -> bool:
|
||||
return is_contiguous_for_memory_format(
|
||||
a, memory_format=memory_format, false_if_dde=True
|
||||
)
|
||||
|
||||
|
||||
# NOTE: that tensors with no elements and channels last is ???
|
||||
def is_channels_last_contiguous(a: Tensor) -> bool:
|
||||
"""
|
||||
@ -379,6 +426,13 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
|
||||
return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
|
||||
|
||||
|
||||
# similar to is_channels_last_contiguous but return false on data dependency.
|
||||
def is_known_channels_last_contiguous(a: Tensor) -> bool:
|
||||
return is_known_channels_last_contiguous_2d(
|
||||
a
|
||||
) or is_known_channels_last_contiguous_3d(a)
|
||||
|
||||
|
||||
def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
||||
"""
|
||||
True when a tensor is non-overlapping and dense.
|
||||
|
Reference in New Issue
Block a user