Files
pytorch/test/inductor/test_block_analysis.py
Mwiza Kunda 00199acdb8 [inductor][triton] Block ptr analysis fix assert on matched index expression (#148446)
If dynamic shapes are enabled, block analysis may create new precomputed size replacements from the index, which can cause an assertion failure when the matched index is compared with the original index. For example, the assertion below fails even though the two expressions are equivalent (ps2 = 3 * ps0). This can be resolved either by applying the replacements to the original index, or by removing the replacements from both expressions before testing them for equality; this PR implements the latter option.

```
       torch._inductor.exc.InductorError: AssertionError:
E       Invalid match!
E       Index: 3*ps0*((yindex//3)) + (ModularIndexing(yindex, 1, 3))
E       Matched expression: ps2*((yindex//3)) + (ModularIndexing(yindex, 1, 3))
E
```

This PR fixes the test below when `config.triton.use_block_ptr=True`:
```
python test/inductor/test_torchinductor_dynamic_shapes.py DynamicShapesCpuTests.test_conv3d_channels_last_dynamic_shapes_cpu
```

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148446
Approved by: https://github.com/jansel
2025-03-10 05:26:55 +00:00

138 lines
4.4 KiB
Python

# Owner(s): ["module: inductor"]
import math

import sympy

import torch
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
from torch._inductor.utils import sympy_dot
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
from torch.testing._internal.inductor_utils import dummy_graph
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
# Assumption-free free symbols shared by the parametrized test cases below.
x = sympy.Symbol("x")
y = sympy.Symbol("y")
@instantiate_parametrized_tests
class BlockAnalysisTest(TestCase):
    """Tests for ``BlockPatternMatcher`` on index expressions.

    Covers ``Identity`` wrappers in affine and mod/div indexing, sub-expression
    extraction, and matching of indices built from symbolic (dynamic) sizes.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Create a GraphLowering, so we can access V.graph.
        cls.graph = dummy_graph()

    @parametrize(
        "stride,symbol,expr",
        [
            (5, x, Identity(5 * x)),
            (4, y, 4 * Identity(y)),
            (3, x, Identity(3) * x),
        ],
    )
    def test_affine_identity(
        self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr
    ):
        # Test that we can handle an identity expression in affine indexing.
        matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
        self.assertEqual(matched_stride, stride)

    @parametrize(
        "dims,strides,symbol,expr",
        [
            (
                (2, 4),
                (4, 1),
                x,
                4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
            ),
            (
                (3, 9),
                (5, 2),
                x,
                5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
            ),
            ((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
        ],
    )
    def test_mod_div_identity(
        self,
        dims: tuple[int, ...],
        strides: tuple[int, ...],
        symbol: sympy.Symbol,
        expr: sympy.Expr,
    ):
        # Test that we can handle an identity expression in modular indexing.
        # math.prod keeps numel an exact Python int (no float tensor round-trip).
        numel = math.prod(dims)
        num_dims = len(dims)
        with V.set_graph_handler(self.graph):
            match_result = BlockPatternMatcher.match_mod_div_block_expr(
                expr, symbol, numel, num_dims
            )

        # Check the matched block dimensions.
        self.assertIsNotNone(match_result)
        matched_dims, matched_strides, _matched_block_index_exprs = match_result
        self.assertEqual(matched_dims, dims)
        self.assertEqual(matched_strides, strides)

    @parametrize(
        "symbol,expr,subexpr",
        [
            (x, Identity(x), x),
            (x, Identity(x + 5), x),
            (y, Identity(x + 2 * y) + 5, 2 * y),
        ],
    )
    def test_subexpr_identity(
        self,
        symbol: sympy.Symbol,
        expr: sympy.Expr,
        subexpr: sympy.Expr,
    ):
        # The sub-expression involving `symbol` should be found through Identity.
        matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
        self.assertEqual(matched_subexpr, subexpr)

    def test_index_with_dynamic_shapes(self):
        # Match an index built from symbolic sizes and check that every matched
        # dim / stride / block index expression equals the one used to build it.
        # sympy.Symbol is used instead of sympy.var, which would also inject the
        # names into the module's global namespace as a side effect.
        s0 = sympy.Symbol("s0", integer=True)
        s1 = sympy.Symbol("s1", integer=True)

        dims = [s1, sympy.Integer(3)]
        num_dims = len(dims)
        numel = dims[0] * dims[1]
        strides = [sympy.Integer(3) * s0, sympy.Integer(1)]
        block_index_exprs = [
            FloorDiv(y, sympy.Integer(3)),
            ModularIndexing(y, sympy.Integer(1), sympy.Integer(3)),
        ]
        index = sympy_dot(strides, block_index_exprs)

        with V.set_graph_handler(self.graph):
            match_result = BlockPatternMatcher.match_mod_div_block_expr(
                index, y, numel, num_dims
            )
            # Fail clearly on a non-match instead of a TypeError from zip(None).
            self.assertIsNotNone(match_result)
            expected_result = (dims, strides, block_index_exprs)
            # zip() silently truncates, so compare arities explicitly first.
            self.assertEqual(len(match_result), len(expected_result))

            sizevars = V.graph.sizevars
            for expected, actual in zip(expected_result, match_result):
                self.assertIsInstance(expected, (list, tuple))
                self.assertIsInstance(actual, (list, tuple))
                self.assertEqual(len(actual), len(expected))
                for expected_expr, actual_expr in zip(expected, actual):
                    self.assertIsInstance(expected_expr, sympy.Expr)
                    self.assertIsInstance(actual_expr, sympy.Expr)
                    # Matching may introduce precomputed size replacements
                    # (e.g. ps2 for 3 * ps0); strip them before comparing so
                    # equivalent expressions compare equal.
                    self.assertTrue(
                        sizevars.statically_known_equals(
                            sizevars.remove_precomputed_replacements(expected_expr),
                            sizevars.remove_precomputed_replacements(actual_expr),
                        ),
                        msg=f"expected {expected_expr}, got {actual_expr}",
                    )
if __name__ == "__main__":
    # Executed as a script: hand off to PyTorch's common `run_tests` entry point.
    run_tests()