mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
make should_swap more dde friendly (#162099)
unblock customers for common cases with DDE ,until @pianpwk land the change to should_swap https://github.com/pytorch/pytorch/pull/160473. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162099 Approved by: https://github.com/aorenste ghstack dependencies: #162084
This commit is contained in:
committed by
PyTorch MergeBot
parent
fecd9686f5
commit
85fe94e933
@ -542,7 +542,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
|
|||||||
def compute_elementwise_output_logical_to_physical_perm(
|
def compute_elementwise_output_logical_to_physical_perm(
|
||||||
*tensors, _skip_checks=False
|
*tensors, _skip_checks=False
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
from torch.fx.experimental.symbolic_shapes import (
|
||||||
|
guard_or_false,
|
||||||
|
guard_size_oblivious,
|
||||||
|
)
|
||||||
|
|
||||||
if not _skip_checks and len(tensors) == 0:
|
if not _skip_checks and len(tensors) == 0:
|
||||||
msg = "Can't compute elementwise output strides for zero tensors!"
|
msg = "Can't compute elementwise output strides for zero tensors!"
|
||||||
@ -595,12 +598,23 @@ def compute_elementwise_output_logical_to_physical_perm(
|
|||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
stride_a = tensor.stride()[idx_a]
|
stride_a = tensor.stride()[idx_a]
|
||||||
stride_b = tensor.stride()[idx_b]
|
stride_b = tensor.stride()[idx_b]
|
||||||
|
|
||||||
if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
|
if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
|
||||||
stride_b == 0
|
stride_b == 0
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if guard_or_false(stride_a == stride_b):
|
||||||
|
if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# when stride_a = 1, we want stride_a < stride_b to be TRUE
|
||||||
|
# when stride_b = 1, we want stride_a < stride_b to be FALSE
|
||||||
|
elif guard_or_false(stride_a == 1):
|
||||||
|
return -1
|
||||||
|
|
||||||
|
elif guard_or_false(stride_b == 1):
|
||||||
|
return 1
|
||||||
|
|
||||||
if guard_size_oblivious(stride_a < stride_b):
|
if guard_size_oblivious(stride_a < stride_b):
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user