mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
(should_fold) gso to guard_or_false when checking folding whether to 3d bmm into 2d mm (#159184)
Switch from guard_size_oblivious to guard_or_false if you encounter a DDE, this would then avoid folding this 3d bmm into a mm.
806d9e3fe7/torch/_decomp/decompositions.py (L4506-L4512)
## DDE
```
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
elif should_fold(tensor1, tensor2, is_out):
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4472, in should_fold
if guard_size_oblivious(t1.numel() == 0):
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(12*((u0//2)), 0) (unhinted: Eq(12*((u0//2)), 0)). (Size-like symbols: none)
Caused by: (_decomp/decompositions.py:4472 in should_fold)
```
```
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
elif should_fold(tensor1, tensor2, is_out):
File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4483, in should_fold
return all(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*((u0//2)), 3) (unhinted: Eq(3*((u0//2)), 3)). (Size-like symbols: none)
Caused by: (_decomp/decompositions.py:4483 in should_fold)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159184
Approved by: https://github.com/ezyang
ghstack dependencies: #158894
This commit is contained in:
committed by
PyTorch MergeBot
parent
880249adbc
commit
46d34d6766
@ -4363,6 +4363,20 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
self.assertTrue(torch.allclose(ref[0], actual[0]))
|
||||
self.assertTrue(torch.allclose(ref[1], actual[1]))
|
||||
|
||||
def test_unbacked_3d_matmul(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, repeat):
|
||||
u0 = repeat.item()
|
||||
t1 = x.unsqueeze(1).expand(x.size(0), u0 // 2, x.size(-1))
|
||||
t2 = torch.ones(3)
|
||||
return torch.matmul(t1, t2)
|
||||
|
||||
model = Model()
|
||||
inputs = (torch.randn(4, 3), torch.scalar_tensor(2, dtype=torch.int))
|
||||
|
||||
exported = export(model, inputs).module()
|
||||
self.assertEqual(model(*inputs), exported(*inputs))
|
||||
|
||||
def test_dynamic_shapes_builder_basic(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y, z):
|
||||
|
@ -4461,7 +4461,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
|
||||
|
||||
t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1)
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if not (t1.ndim >= 3 and t2.ndim <= 2):
|
||||
return False
|
||||
@ -4469,7 +4469,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
|
||||
return True
|
||||
if tensor1.ndim == 2:
|
||||
return False
|
||||
if guard_size_oblivious(t1.numel() == 0):
|
||||
if guard_or_false(t1.numel() == 0):
|
||||
return True
|
||||
|
||||
t1_shape = t1.shape
|
||||
@ -4481,7 +4481,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
|
||||
for size in reversed(t1_shape[1:]):
|
||||
expected_stride.append(size * expected_stride[-1])
|
||||
return all(
|
||||
guard_size_oblivious(size == 1) or left == right
|
||||
guard_or_false(size == 1) or guard_or_false(left == right)
|
||||
for left, right, size in zip(
|
||||
t1_stride, list(reversed(expected_stride)), t1_shape
|
||||
)
|
||||
|
Reference in New Issue
Block a user