Introduce definitely_contiguous and use it for reshape and tensor metadata computation (#153432)

When a tensor has unbacked symbols, its shape and strides can be general enough to represent
both contiguous and non-contiguous tensors. In that case we cannot really evaluate is_contiguous.
Many places in the code base check is_contiguous to take a fast path, but the general path usually
works for contiguous and non-contiguous tensors alike; in those places we probably want to use the
definitely_contiguous API instead.
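
As a minimal sketch of the intended pattern (fast_path and general_path are
hypothetical stand-ins, not code from this PR):

    # t has sizes containing unbacked symbols, e.g. produced by nonzero(),
    # so is_contiguous(t) may raise a data-dependent error at trace time.
    if definitely_contiguous(t):
        fast_path(t)  # taken only when contiguity is provable
    else:
        general_path(t)  # handles contiguous and non-contiguous tensors alike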

This PR applies definitely_contiguous to reshape and to tensor metadata computation: the metadata
now carries an attribute that marks the tensor as contiguous only when it is always contiguous,
i.e. we store it only when definitely_contiguous returns true.
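
A rough sketch of the metadata side (the attribute and variable names are
illustrative, not the exact ones used in the PR):

    # Record contiguity in cached tensor metadata only when it is provably
    # contiguous; False here means "unknown or non-contiguous", it is never
    # a guarantee of non-contiguity.
    meta_is_contiguous = definitely_contiguous_for_memory_format(
        t, memory_format=torch.contiguous_format
    )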

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432
Approved by: https://github.com/bobrenjc93
Laith Sakka
2025-05-27 13:24:57 -07:00
committed by PyTorch MergeBot
parent 54f1f29fed
commit 39df901b2a
10 changed files with 178 additions and 53 deletions

torch/_prims_common/__init__.py

@@ -259,47 +259,64 @@ def check_all_strides(
 # This function is equivalent to compute_contiguous() from TensorImpl.cpp
-def is_contiguous(a: TensorLikeType) -> bool:
+def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool:
     """
     Tests whether a tensor is contiguous or not.
 
     Tensors are contiguous when they have no elements,
     one element, or when they have "nested" strides.
     """
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        guard_size_oblivious,
+    )
 
-    if guard_size_oblivious(a.numel() < 2):
+    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
+    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
+
+    if maybe_guard_or_false(a.numel() < 2):
         return True
 
     expected_stride = 1
     for x, y in reversed(tuple(zip(a.shape, a.stride()))):
         # Skips checking strides when a dimension has length 1
-        if guard_size_oblivious(x == 1):
+        if maybe_guard_or_false(x == 1):
             continue
 
-        if guard_size_oblivious(y != expected_stride):
+        if maybe_guard_or_true(y != expected_stride):
             return False
-        expected_stride = expected_stride * x
+
+        # if x is 0 then a is contiguous anyway. So in the non-contiguity check above we can
+        # assume x is not 0 in the expected_stride equation; consistent with make_contiguous_strides_for.
+        expected_stride = expected_stride * sym_max(x, 1)
 
     return True
 
 
 # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
-def is_channels_last_contiguous_2d(a: Tensor) -> bool:
+def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool:
     # NHWC or not channels last 2D contiguous
     if a.ndim != 4:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        guard_size_oblivious,
+    )
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
+    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
 
     expected_stride = 1
     for idx in (1, 3, 2, 0):
         length = a.shape[idx]
-        if guard_size_oblivious(length == 1):
+        if maybe_guard_or_false(length == 1):
             continue
 
         stride = a.stride()[idx]
-        if guard_size_oblivious(stride != expected_stride):
+        if maybe_guard_or_true(stride != expected_stride):
             return False
 
         expected_stride *= length
@@ -307,21 +324,28 @@ def is_channels_last_contiguous_2d(a: Tensor) -> bool:
     return True
 
 
-def is_channels_last_contiguous_3d(a: Tensor) -> bool:
+def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool:
     # NDHWC or not channels last 3D contiguous
     if a.ndim != 5:
         return False
 
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        guard_size_oblivious,
+    )
+
+    maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious
+    maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious
 
     expected_stride = 1
     for idx in (1, 4, 3, 2, 0):
         length = a.shape[idx]
-        if guard_size_oblivious(length == 1):
+        if maybe_guard_or_false(length == 1):
             continue
 
         stride = a.stride()[idx]
-        if guard_size_oblivious(stride != expected_stride):
+        if maybe_guard_or_true(stride != expected_stride):
             return False
 
         expected_stride *= length
@@ -345,16 +369,16 @@ def validate_memory_format(memory_format: torch.memory_format):
 def is_contiguous_for_memory_format(  # type: ignore[return]
-    a: Tensor, *, memory_format: torch.memory_format
+    a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
 ) -> bool:
     validate_memory_format(memory_format)
 
     if memory_format == torch.contiguous_format:
-        return is_contiguous(a)
+        return is_contiguous(a, false_if_dde)
     if memory_format == torch.channels_last:
-        return is_channels_last_contiguous_2d(a)
+        return is_channels_last_contiguous_2d(a, false_if_dde)
     if memory_format == torch.channels_last_3d:
-        return is_channels_last_contiguous_3d(a)
+        return is_channels_last_contiguous_3d(a, false_if_dde)
 
     torch._check(
         False,
@@ -362,6 +386,29 @@ def is_contiguous_for_memory_format(  # type: ignore[return]
     )
 
 
+def definitely_contiguous(a: TensorLikeType) -> bool:
+    return is_contiguous(a, false_if_dde=True)
+
+
+# similar to is_channels_last_contiguous_2d but return false on data dependency.
+def is_known_channels_last_contiguous_2d(a: Tensor) -> bool:
+    return is_channels_last_contiguous_2d(a, false_if_dde=True)
+
+
+# similar to is_channels_last_contiguous_3d but return false on data dependency.
+def is_known_channels_last_contiguous_3d(a: Tensor) -> bool:
+    return is_channels_last_contiguous_3d(a, false_if_dde=True)
+
+
+# similar to is_contiguous_for_memory_format but return false on data dependency.
+def definitely_contiguous_for_memory_format(  # type: ignore[return]
+    a: Tensor, *, memory_format: torch.memory_format
+) -> bool:
+    return is_contiguous_for_memory_format(
+        a, memory_format=memory_format, false_if_dde=True
+    )
+
+
 # NOTE: that tensors with no elements and channels last is ???
 def is_channels_last_contiguous(a: Tensor) -> bool:
     """
@@ -379,6 +426,13 @@ def is_channels_last_contiguous(a: Tensor) -> bool:
     return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
 
 
+# similar to is_channels_last_contiguous but return false on data dependency.
+def is_known_channels_last_contiguous(a: Tensor) -> bool:
+    return is_known_channels_last_contiguous_2d(
+        a
+    ) or is_known_channels_last_contiguous_3d(a)
+
+
 def is_non_overlapping_and_dense(a: Tensor) -> bool:
     """
     True when a tensor is non-overlapping and dense.
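
For context on the guard helpers used throughout the diff (guard_or_false and
guard_or_true are real torch.fx.experimental.symbolic_shapes APIs; the scenario
below is an illustrative sketch):

    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    # u0 stands in for an unbacked SymInt, e.g. a size produced by nonzero().
    # guard_size_oblivious(u0 == 1) may raise a data-dependent error, whereas
    # guard_or_false(u0 == 1) evaluates the condition when it can and otherwise
    # returns False, and guard_or_true(u0 != 1) otherwise returns True. With
    # false_if_dde=True, every undecidable check therefore resolves toward
    # reporting "not definitely contiguous".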