(should_fold) gso to guard_or_false when checking folding whether to 3d bmm into 2d mm (#159184)

Switch from guard_size_oblivious to guard_or_false if you encounter a DDE, this would then avoid folding this 3d bmm into a mm.

806d9e3fe7/torch/_decomp/decompositions.py (L4506-L4512)

## DDE
```
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4472, in should_fold
    if guard_size_oblivious(t1.numel() == 0):
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(12*((u0//2)), 0) (unhinted: Eq(12*((u0//2)), 0)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4472 in should_fold)
```

```
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4483, in should_fold
    return all(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*((u0//2)), 3) (unhinted: Eq(3*((u0//2)), 3)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4483 in should_fold)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159184
Approved by: https://github.com/ezyang
ghstack dependencies: #158894
This commit is contained in:
Colin Peppler
2025-07-25 13:43:06 -07:00
committed by PyTorch MergeBot
parent 880249adbc
commit 46d34d6766
2 changed files with 17 additions and 3 deletions

View File

@ -4363,6 +4363,20 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
self.assertTrue(torch.allclose(ref[0], actual[0]))
self.assertTrue(torch.allclose(ref[1], actual[1]))
def test_unbacked_3d_matmul(self):
class Model(torch.nn.Module):
def forward(self, x, repeat):
u0 = repeat.item()
t1 = x.unsqueeze(1).expand(x.size(0), u0 // 2, x.size(-1))
t2 = torch.ones(3)
return torch.matmul(t1, t2)
model = Model()
inputs = (torch.randn(4, 3), torch.scalar_tensor(2, dtype=torch.int))
exported = export(model, inputs).module()
self.assertEqual(model(*inputs), exported(*inputs))
def test_dynamic_shapes_builder_basic(self):
class M(torch.nn.Module):
def forward(self, x, y, z):

View File

@ -4461,7 +4461,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1)
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import guard_or_false
if not (t1.ndim >= 3 and t2.ndim <= 2):
return False
@ -4469,7 +4469,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
return True
if tensor1.ndim == 2:
return False
if guard_size_oblivious(t1.numel() == 0):
if guard_or_false(t1.numel() == 0):
return True
t1_shape = t1.shape
@ -4481,7 +4481,7 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b
for size in reversed(t1_shape[1:]):
expected_stride.append(size * expected_stride[-1])
return all(
guard_size_oblivious(size == 1) or left == right
guard_or_false(size == 1) or guard_or_false(left == right)
for left, right, size in zip(
t1_stride, list(reversed(expected_stride)), t1_shape
)