make should_swap more dde friendly (#162099)

unblock customers for common cases with DDE ,until @pianpwk  land the change to should_swap https://github.com/pytorch/pytorch/pull/160473.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162099
Approved by: https://github.com/aorenste
ghstack dependencies: #162084
This commit is contained in:
Laith Sakka
2025-09-08 11:20:46 -07:00
committed by PyTorch MergeBot
parent fecd9686f5
commit 85fe94e933

View File

@ -542,7 +542,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
def compute_elementwise_output_logical_to_physical_perm(
*tensors, _skip_checks=False
) -> list[int]:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_size_oblivious,
)
if not _skip_checks and len(tensors) == 0:
msg = "Can't compute elementwise output strides for zero tensors!"
@ -595,12 +598,23 @@ def compute_elementwise_output_logical_to_physical_perm(
for tensor in tensors:
stride_a = tensor.stride()[idx_a]
stride_b = tensor.stride()[idx_b]
if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
stride_b == 0
):
continue
if guard_or_false(stride_a == stride_b):
if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
return 1
# when stride_a = 1, we want stride_a < stride_b to be TRUE
# when stride_b = 1, we want stride_a < stride_b to be FALSE
elif guard_or_false(stride_a == 1):
return -1
elif guard_or_false(stride_b == 1):
return 1
if guard_size_oblivious(stride_a < stride_b):
return -1