mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor][triton] Block ptr analysis fix assert on matched index expression (#148446)
If dynamic shapes are enabled, then block analysis may create new precomputed size replacements from the index which can lead to an assertion failure when the matched index is compared with the original index. For example the below assertion fails, despite the expressions being equivalent (ps2 = 3 * ps0). This can be resolved by updating the original index with the replacements, or simply removing the replacements when the expressions are tested to be equal - the latter option is implemented in this PR. ``` torch._inductor.exc.InductorError: AssertionError: E Invalid match! E Index: 3*ps0*((yindex//3)) + (ModularIndexing(yindex, 1, 3)) E Matched expression: ps2*((yindex//3)) + (ModularIndexing(yindex, 1, 3)) E ``` This PR fixes the test below when `config.triton.use_block_ptr=True`: ``` python test/inductor/test_torchinductor_dynamic_shapes.py DynamicShapesCpuTests.test_conv3d_channels_last_dynamic_shapes_cpu ``` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/148446 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
3680e666d8
commit
00199acdb8
@ -4,6 +4,7 @@ import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
|
||||
from torch._inductor.utils import sympy_dot
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -97,6 +98,40 @@ class BlockAnalysisTest(TestCase):
|
||||
matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
|
||||
self.assertEqual(matched_subexpr, subexpr)
|
||||
|
||||
def test_index_with_dynamic_shapes(self):
|
||||
s0 = sympy.var("s0", integer=True)
|
||||
s1 = sympy.var("s1", integer=True)
|
||||
|
||||
dims = [s1, sympy.Integer(3)]
|
||||
num_dims = len(dims)
|
||||
numel = dims[0] * dims[1]
|
||||
strides = [sympy.Integer(3) * s0, sympy.Integer(1)]
|
||||
block_index_exprs = [
|
||||
FloorDiv(y, sympy.Integer(3)),
|
||||
ModularIndexing(y, sympy.Integer(1), sympy.Integer(3)),
|
||||
]
|
||||
index = sympy_dot(strides, block_index_exprs)
|
||||
|
||||
with V.set_graph_handler(self.graph):
|
||||
match = BlockPatternMatcher.match_mod_div_block_expr(
|
||||
index, y, numel, num_dims
|
||||
)
|
||||
sizevars = V.graph.sizevars
|
||||
for expected, actual in zip((dims, strides, block_index_exprs), match):
|
||||
assert isinstance(expected, (list, tuple)) and isinstance(
|
||||
actual, (list, tuple)
|
||||
)
|
||||
for expected_expr, actual_expr in zip(expected, actual):
|
||||
assert isinstance(expected_expr, sympy.Expr) and isinstance(
|
||||
actual_expr, sympy.Expr
|
||||
)
|
||||
self.assertTrue(
|
||||
sizevars.statically_known_equals(
|
||||
sizevars.remove_precomputed_replacements(expected_expr),
|
||||
sizevars.remove_precomputed_replacements(actual_expr),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -4603,7 +4603,11 @@ class CommonTemplate:
|
||||
check_lowp=False,
|
||||
)
|
||||
|
||||
def test_conv3d_channels_last(self):
|
||||
@parametrize(
|
||||
"use_block_ptr",
|
||||
[subtest(False), subtest(True, decorators=[skip_if_not_triton])],
|
||||
)
|
||||
def test_conv3d_channels_last(self, use_block_ptr: bool):
|
||||
if self.device == GPU_TYPE:
|
||||
raise unittest.SkipTest("only support cpu conv3d channels_last")
|
||||
|
||||
@ -4611,21 +4615,30 @@ class CommonTemplate:
|
||||
torch.nn.Conv3d(3, 3, 1, 1),
|
||||
ToTuple(),
|
||||
)
|
||||
# only weight is channels_last
|
||||
self.common(
|
||||
m.to(memory_format=torch.channels_last_3d),
|
||||
(torch.randn([2, 3, 16, 16, 16]),),
|
||||
)
|
||||
# only activation is channels_last
|
||||
self.common(
|
||||
m,
|
||||
(torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
|
||||
)
|
||||
# activation and weight are all channels_last
|
||||
self.common(
|
||||
m.to(memory_format=torch.channels_last_3d),
|
||||
(torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),),
|
||||
)
|
||||
with config.patch({"triton.use_block_ptr": use_block_ptr}):
|
||||
# only weight is channels_last
|
||||
self.common(
|
||||
m.to(memory_format=torch.channels_last_3d),
|
||||
(torch.randn([2, 3, 16, 16, 16]),),
|
||||
)
|
||||
# only activation is channels_last
|
||||
self.common(
|
||||
m,
|
||||
(
|
||||
torch.randn([2, 3, 16, 16, 16]).to(
|
||||
memory_format=torch.channels_last_3d
|
||||
),
|
||||
),
|
||||
)
|
||||
# activation and weight are all channels_last
|
||||
self.common(
|
||||
m.to(memory_format=torch.channels_last_3d),
|
||||
(
|
||||
torch.randn([2, 3, 16, 16, 16]).to(
|
||||
memory_format=torch.channels_last_3d
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@skip_if_gpu_halide # slow
|
||||
@xfail_if_mps # Non-divisible input sizes are not implemented on MPS device
|
||||
|
@ -113,7 +113,9 @@ test_failures = {
|
||||
"test_clamp_type_promotion_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_conv3d_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_conv3d_channels_last_use_block_ptr_False_dynamic_shapes": TestFailure(
|
||||
("cpu",)
|
||||
),
|
||||
"test_expand_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_full_boolean_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_glu_dynamic_shapes": TestFailure(("cpu",)),
|
||||
|
@ -139,7 +139,14 @@ class BlockPatternMatcher:
|
||||
|
||||
# Sanity check that we can recover the index from the matched subexpressions.
|
||||
matched_index = sympy_dot(strides, block_index_exprs)
|
||||
assert sizevars.statically_known_equals(matched_index, index), textwrap.dedent(
|
||||
assert sizevars.statically_known_equals(
|
||||
# New precomputed replacements may be generated when the `get_match` function
|
||||
# above is called, but the `index` that is being matched has not been updated.
|
||||
# So remove them when checking for equivalence e.g. if ps0=3*s0 and
|
||||
# index=3*s0*expr, matched_index=ps0*expr, then index == matched_index
|
||||
sizevars.remove_precomputed_replacements(matched_index),
|
||||
sizevars.remove_precomputed_replacements(index),
|
||||
), textwrap.dedent(
|
||||
f"""
|
||||
Invalid match!
|
||||
Index: {index}
|
||||
|
Reference in New Issue
Block a user