mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make should_swap more dde friendly (#162099)
unblock customers for common cases with DDE ,until @pianpwk land the change to should_swap https://github.com/pytorch/pytorch/pull/160473. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162099 Approved by: https://github.com/aorenste ghstack dependencies: #162084
This commit is contained in:
committed by
PyTorch MergeBot
parent
fecd9686f5
commit
85fe94e933
@ -542,7 +542,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
||||
def compute_elementwise_output_logical_to_physical_perm(
|
||||
*tensors, _skip_checks=False
|
||||
) -> list[int]:
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_size_oblivious,
|
||||
)
|
||||
|
||||
if not _skip_checks and len(tensors) == 0:
|
||||
msg = "Can't compute elementwise output strides for zero tensors!"
|
||||
@ -595,12 +598,23 @@ def compute_elementwise_output_logical_to_physical_perm(
|
||||
for tensor in tensors:
|
||||
stride_a = tensor.stride()[idx_a]
|
||||
stride_b = tensor.stride()[idx_b]
|
||||
|
||||
if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
|
||||
stride_b == 0
|
||||
):
|
||||
continue
|
||||
|
||||
if guard_or_false(stride_a == stride_b):
|
||||
if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
|
||||
return 1
|
||||
|
||||
# when stride_a = 1, we want stride_a < stride_b to be TRUE
|
||||
# when stride_b = 1, we want stride_a < stride_b to be FALSE
|
||||
elif guard_or_false(stride_a == 1):
|
||||
return -1
|
||||
|
||||
elif guard_or_false(stride_b == 1):
|
||||
return 1
|
||||
|
||||
if guard_size_oblivious(stride_a < stride_b):
|
||||
return -1
|
||||
|
||||
|
Reference in New Issue
Block a user