mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
(should_fold) gso to guard_or_false when checking folding whether to 3d bmm into 2d mm (#159184)
Switch from guard_size_oblivious to guard_or_false if you encounter a DDE, this would then avoid folding this 3d bmm into a mm.
806d9e3fe7/torch/_decomp/decompositions.py (L4506-L4512)
## DDE
```
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
elif should_fold(tensor1, tensor2, is_out):
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4472, in should_fold
if guard_size_oblivious(t1.numel() == 0):
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(12*((u0//2)), 0) (unhinted: Eq(12*((u0//2)), 0)). (Size-like symbols: none)
Caused by: (_decomp/decompositions.py:4472 in should_fold)
```
```
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
elif should_fold(tensor1, tensor2, is_out):
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4483, in should_fold
return all(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*((u0//2)), 3) (unhinted: Eq(3*((u0//2)), 3)). (Size-like symbols: none)
Caused by: (_decomp/decompositions.py:4483 in should_fold)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159184
Approved by: https://github.com/ezyang
ghstack dependencies: #158894
This commit is contained in:
committed by
PyTorch MergeBot
parent
880249adbc
commit
46d34d6766
@ -4461,7 +4461,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
|
||||
|
||||
t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1)
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if not (t1.ndim >= 3 and t2.ndim <= 2):
|
||||
return False
|
||||
@ -4469,7 +4469,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
|
||||
return True
|
||||
if tensor1.ndim == 2:
|
||||
return False
|
||||
if guard_size_oblivious(t1.numel() == 0):
|
||||
if guard_or_false(t1.numel() == 0):
|
||||
return True
|
||||
|
||||
t1_shape = t1.shape
|
||||
@ -4481,7 +4481,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
|
||||
for size in reversed(t1_shape[1:]):
|
||||
expected_stride.append(size * expected_stride[-1])
|
||||
return all(
|
||||
guard_size_oblivious(size == 1) or left == right
|
||||
guard_or_false(size == 1) or guard_or_false(left == right)
|
||||
for left, right, size in zip(
|
||||
t1_stride, list(reversed(expected_stride)), t1_shape
|
||||
)
|
||||
|
Reference in New Issue
Block a user