From 00199acdb85a4355612bff28e1018b035e0e46b9 Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Mon, 10 Mar 2025 05:26:55 +0000 Subject: [PATCH] [inductor][triton] Block ptr analysis fix assert on matched index expression (#148446) If dynamic shapes are enabled, then block analysis may create new precomputed size replacements from the index which can lead to an assertion failure when the matched index is compared with the original index. For example the below assertion fails, despite the expressions being equivalent (ps2 = 3 * ps0). This can be resolved by updating the original index with the replacements, or simply removing the replacements when the expressions are tested to be equal - the latter option is implemented in this PR. ``` torch._inductor.exc.InductorError: AssertionError: E Invalid match! E Index: 3*ps0*((yindex//3)) + (ModularIndexing(yindex, 1, 3)) E Matched expression: ps2*((yindex//3)) + (ModularIndexing(yindex, 1, 3)) E ``` This PR fixes the test below when `config.triton.use_block_ptr=True`: ``` python test/inductor/test_torchinductor_dynamic_shapes.py DynamicShapesCpuTests.test_conv3d_channels_last_dynamic_shapes_cpu ``` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/148446 Approved by: https://github.com/jansel --- test/inductor/test_block_analysis.py | 35 +++++++++++++++ test/inductor/test_torchinductor.py | 45 ++++++++++++------- ...st_torchinductor_codegen_dynamic_shapes.py | 4 +- torch/_inductor/codegen/block_analysis.py | 9 +++- 4 files changed, 75 insertions(+), 18 deletions(-) diff --git a/test/inductor/test_block_analysis.py b/test/inductor/test_block_analysis.py index 5cf932d52e89..3d2cb0373c43 100644 --- a/test/inductor/test_block_analysis.py +++ b/test/inductor/test_block_analysis.py @@ -4,6 +4,7 @@ import sympy import torch from torch._inductor.codegen.block_analysis import BlockPatternMatcher +from torch._inductor.utils import sympy_dot from torch._inductor.virtualized import V from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -97,6 +98,40 @@ class BlockAnalysisTest(TestCase): matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol) self.assertEqual(matched_subexpr, subexpr) + def test_index_with_dynamic_shapes(self): + s0 = sympy.var("s0", integer=True) + s1 = sympy.var("s1", integer=True) + + dims = [s1, sympy.Integer(3)] + num_dims = len(dims) + numel = dims[0] * dims[1] + strides = [sympy.Integer(3) * s0, sympy.Integer(1)] + block_index_exprs = [ + FloorDiv(y, sympy.Integer(3)), + ModularIndexing(y, sympy.Integer(1), sympy.Integer(3)), + ] + index = sympy_dot(strides, block_index_exprs) + + with V.set_graph_handler(self.graph): + match = BlockPatternMatcher.match_mod_div_block_expr( + index, y, numel, num_dims + ) + sizevars = V.graph.sizevars + for expected, actual in zip((dims, strides, block_index_exprs), match): + assert isinstance(expected, (list, tuple)) and isinstance( + actual, (list, tuple) + ) + for expected_expr, actual_expr in zip(expected, actual): + assert isinstance(expected_expr, sympy.Expr) and isinstance( + actual_expr, sympy.Expr + ) + self.assertTrue( + sizevars.statically_known_equals( + sizevars.remove_precomputed_replacements(expected_expr), + sizevars.remove_precomputed_replacements(actual_expr), + ) + ) + if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 67beeb832f59..ca9a693bf97f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4603,7 +4603,11 @@ class CommonTemplate: check_lowp=False, ) - def test_conv3d_channels_last(self): + @parametrize( + "use_block_ptr", + [subtest(False), subtest(True, decorators=[skip_if_not_triton])], + ) + def test_conv3d_channels_last(self, use_block_ptr: bool): if self.device == GPU_TYPE: raise unittest.SkipTest("only support cpu conv3d channels_last") @@ -4611,21 +4615,30 @@ class CommonTemplate: torch.nn.Conv3d(3, 3, 1, 1), ToTuple(), ) - # only weight is channels_last - self.common( - m.to(memory_format=torch.channels_last_3d), - (torch.randn([2, 3, 16, 16, 16]),), - ) - # only activation is channels_last - self.common( - m, - (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),), - ) - # activation and weight are all channels_last - self.common( - m.to(memory_format=torch.channels_last_3d), - (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),), - ) + with config.patch({"triton.use_block_ptr": use_block_ptr}): + # only weight is channels_last + self.common( + m.to(memory_format=torch.channels_last_3d), + (torch.randn([2, 3, 16, 16, 16]),), + ) + # only activation is channels_last + self.common( + m, + ( + torch.randn([2, 3, 16, 16, 16]).to( + memory_format=torch.channels_last_3d + ), + ), + ) + # activation and weight are all channels_last + self.common( + m.to(memory_format=torch.channels_last_3d), + ( + torch.randn([2, 3, 16, 16, 16]).to( + memory_format=torch.channels_last_3d + ), + ), + ) @skip_if_gpu_halide # slow @xfail_if_mps # Non-divisible input sizes are not implemented on MPS device diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 87fad2a0068e..3704a1211c70 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -113,7 +113,9 @@ test_failures = { "test_clamp_type_promotion_dynamic_shapes": TestFailure(("cpu",)), "test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)), "test_conv3d_dynamic_shapes": TestFailure(("cpu",)), - "test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)), + "test_conv3d_channels_last_use_block_ptr_False_dynamic_shapes": TestFailure( + ("cpu",) + ), "test_expand_dynamic_shapes": TestFailure(("cpu",)), "test_full_boolean_dynamic_shapes": TestFailure(("cpu",)), "test_glu_dynamic_shapes": TestFailure(("cpu",)), diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py index 1c816eb8e293..b99f7f786cff 100644 --- a/torch/_inductor/codegen/block_analysis.py +++ b/torch/_inductor/codegen/block_analysis.py @@ -139,7 +139,14 @@ class BlockPatternMatcher: # Sanity check that we can recover the index from the matched subexpressions. matched_index = sympy_dot(strides, block_index_exprs) - assert sizevars.statically_known_equals(matched_index, index), textwrap.dedent( + assert sizevars.statically_known_equals( + # New precomputed replacements may be generated when the `get_match` function + # above is called, but the `index` that is being matched has not been updated. + # So remove them when checking for equivalence e.g. if ps0=3*s0 and + # index=3*s0*expr, matched_index=ps0*expr, then index == matched_index + sizevars.remove_precomputed_replacements(matched_index), + sizevars.remove_precomputed_replacements(index), + ), textwrap.dedent( f""" Invalid match! Index: {index}