diff --git a/test/inductor/test_block_analysis.py b/test/inductor/test_block_analysis.py
index 5cf932d52e89..3d2cb0373c43 100644
--- a/test/inductor/test_block_analysis.py
+++ b/test/inductor/test_block_analysis.py
@@ -4,6 +4,7 @@ import sympy
 
 import torch
 from torch._inductor.codegen.block_analysis import BlockPatternMatcher
+from torch._inductor.utils import sympy_dot
 from torch._inductor.virtualized import V
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -97,6 +98,40 @@ class BlockAnalysisTest(TestCase):
         matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
         self.assertEqual(matched_subexpr, subexpr)
 
+    def test_index_with_dynamic_shapes(self):
+        s0 = sympy.var("s0", integer=True)
+        s1 = sympy.var("s1", integer=True)
+
+        dims = [s1, sympy.Integer(3)]
+        num_dims = len(dims)
+        numel = dims[0] * dims[1]
+        strides = [sympy.Integer(3) * s0, sympy.Integer(1)]
+        block_index_exprs = [
+            FloorDiv(y, sympy.Integer(3)),
+            ModularIndexing(y, sympy.Integer(1), sympy.Integer(3)),
+        ]
+        index = sympy_dot(strides, block_index_exprs)
+
+        with V.set_graph_handler(self.graph):
+            match = BlockPatternMatcher.match_mod_div_block_expr(
+                index, y, numel, num_dims
+            )
+            sizevars = V.graph.sizevars
+            for expected, actual in zip((dims, strides, block_index_exprs), match):
+                assert isinstance(expected, (list, tuple)) and isinstance(
+                    actual, (list, tuple)
+                )
+                for expected_expr, actual_expr in zip(expected, actual):
+                    assert isinstance(expected_expr, sympy.Expr) and isinstance(
+                        actual_expr, sympy.Expr
+                    )
+                    self.assertTrue(
+                        sizevars.statically_known_equals(
+                            sizevars.remove_precomputed_replacements(expected_expr),
+                            sizevars.remove_precomputed_replacements(actual_expr),
+                        )
+                    )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 67beeb832f59..ca9a693bf97f 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -4603,7 +4603,11 @@ class CommonTemplate:
             check_lowp=False,
         )
 
-    def test_conv3d_channels_last(self):
+    @parametrize(
+        "use_block_ptr",
+        [subtest(False), subtest(True, decorators=[skip_if_not_triton])],
+    )
+    def test_conv3d_channels_last(self, use_block_ptr: bool):
         if self.device == GPU_TYPE:
             raise unittest.SkipTest("only support cpu conv3d channels_last")
 
@@ -4611,21 +4615,30 @@ class CommonTemplate:
             torch.nn.Conv3d(3, 3, 1, 1),
             ToTuple(),
         )
-        # only weight is channels_last
-        self.common(
-            m.to(memory_format=torch.channels_last_3d),
-            (torch.randn([2, 3, 16, 16, 16]),),
-        )
-        # only activation is channels_last
-        self.common(
-            m,
-            (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
-        )
-        # activation and weight are all channels_last
-        self.common(
-            m.to(memory_format=torch.channels_last_3d),
-            (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
-        )
+        with config.patch({"triton.use_block_ptr": use_block_ptr}):
+            # only weight is channels_last
+            self.common(
+                m.to(memory_format=torch.channels_last_3d),
+                (torch.randn([2, 3, 16, 16, 16]),),
+            )
+            # only activation is channels_last
+            self.common(
+                m,
+                (
+                    torch.randn([2, 3, 16, 16, 16]).to(
+                        memory_format=torch.channels_last_3d
+                    ),
+                ),
+            )
+            # activation and weight are all channels_last
+            self.common(
+                m.to(memory_format=torch.channels_last_3d),
+                (
+                    torch.randn([2, 3, 16, 16, 16]).to(
+                        memory_format=torch.channels_last_3d
+                    ),
+                ),
+            )
 
     @skip_if_gpu_halide  # slow
     @xfail_if_mps  # Non-divisible input sizes are not implemented on MPS device
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index 87fad2a0068e..3704a1211c70 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -113,7 +113,9 @@ test_failures = {
     "test_clamp_type_promotion_dynamic_shapes": TestFailure(("cpu",)),
     "test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
     "test_conv3d_dynamic_shapes": TestFailure(("cpu",)),
-    "test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
+    "test_conv3d_channels_last_use_block_ptr_False_dynamic_shapes": TestFailure(
+        ("cpu",)
+    ),
     "test_expand_dynamic_shapes": TestFailure(("cpu",)),
     "test_full_boolean_dynamic_shapes": TestFailure(("cpu",)),
     "test_glu_dynamic_shapes": TestFailure(("cpu",)),
diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py
index 1c816eb8e293..b99f7f786cff 100644
--- a/torch/_inductor/codegen/block_analysis.py
+++ b/torch/_inductor/codegen/block_analysis.py
@@ -139,7 +139,14 @@ class BlockPatternMatcher:
         # Sanity check that we can recover the index from the matched subexpressions.
         matched_index = sympy_dot(strides, block_index_exprs)
-        assert sizevars.statically_known_equals(matched_index, index), textwrap.dedent(
+        assert sizevars.statically_known_equals(
+            # New precomputed replacements may be generated when the `get_match` function
+            # above is called, but the `index` being matched has not been updated.
+            # So remove them when checking for equivalence: e.g. if ps0 = 3*s0 and
+            # index = 3*s0*expr, then matched_index = ps0*expr, and index == matched_index.
+            sizevars.remove_precomputed_replacements(matched_index),
+            sizevars.remove_precomputed_replacements(index),
+        ), textwrap.dedent(
             f"""
             Invalid match! Index: {index}