mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
If dynamic shapes are enabled, then block analysis may create new precomputed size replacements from the index which can lead to an assertion failure when the matched index is compared with the original index. For example the below assertion fails, despite the expressions being equivalent (ps2 = 3 * ps0). This can be resolved by updating the original index with the replacements, or simply removing the replacements when the expressions are tested to be equal - the latter option is implemented in this PR. ``` torch._inductor.exc.InductorError: AssertionError: E Invalid match! E Index: 3*ps0*((yindex//3)) + (ModularIndexing(yindex, 1, 3)) E Matched expression: ps2*((yindex//3)) + (ModularIndexing(yindex, 1, 3)) E ``` This PR fixes the test below when `config.triton.use_block_ptr=True`: ``` python test/inductor/test_torchinductor_dynamic_shapes.py DynamicShapesCpuTests.test_conv3d_channels_last_dynamic_shapes_cpu ``` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/148446 Approved by: https://github.com/jansel
138 lines
4.4 KiB
Python
138 lines
4.4 KiB
Python
# Owner(s): ["module: inductor"]
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
|
|
from torch._inductor.utils import sympy_dot
|
|
from torch._inductor.virtualized import V
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
TestCase,
|
|
)
|
|
from torch.testing._internal.inductor_utils import dummy_graph
|
|
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
|
|
|
|
|
|
# Some useful symbols
|
|
x, y = sympy.symbols("x y")
|
|
|
|
|
|
@instantiate_parametrized_tests
|
|
class BlockAnalysisTest(TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
|
|
# Create a GraphLowering, so we can access V.graph.
|
|
cls.graph = dummy_graph()
|
|
|
|
@parametrize(
|
|
"stride,symbol,expr",
|
|
[
|
|
(5, x, Identity(5 * x)),
|
|
(4, y, 4 * Identity(y)),
|
|
(3, x, Identity(3) * x),
|
|
],
|
|
)
|
|
def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
|
|
# Test that we can handle an identity expression in affine indexing.
|
|
matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
|
|
self.assertEqual(matched_stride, stride)
|
|
|
|
@parametrize(
|
|
"dims,strides,symbol,expr",
|
|
[
|
|
(
|
|
(2, 4),
|
|
(4, 1),
|
|
x,
|
|
4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
|
|
),
|
|
(
|
|
(3, 9),
|
|
(5, 2),
|
|
x,
|
|
5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
|
|
),
|
|
((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
|
|
],
|
|
)
|
|
def test_mod_div_identity(
|
|
self,
|
|
dims: tuple[int],
|
|
strides: tuple[int],
|
|
symbol: sympy.Symbol,
|
|
expr: sympy.Expr,
|
|
):
|
|
# Test that we can handle an identity expression in modular indexing.
|
|
numel = int(torch.prod(torch.Tensor(dims)))
|
|
num_dims = len(dims)
|
|
with V.set_graph_handler(self.graph):
|
|
match_result = BlockPatternMatcher.match_mod_div_block_expr(
|
|
expr, symbol, numel, num_dims
|
|
)
|
|
|
|
# Check the matched block dimensions.
|
|
self.assertNotEqual(match_result, None)
|
|
matched_dims, matched_strides, matched_block_index_exprs = match_result
|
|
self.assertEqual(matched_dims, dims)
|
|
self.assertEqual(matched_strides, strides)
|
|
|
|
@parametrize(
|
|
"symbol,expr,subexpr",
|
|
[
|
|
(x, Identity(x), x),
|
|
(x, Identity(x + 5), x),
|
|
(y, Identity(x + 2 * y) + 5, 2 * y),
|
|
],
|
|
)
|
|
def test_subexpr_identity(
|
|
self,
|
|
symbol: sympy.Symbol,
|
|
expr: sympy.Expr,
|
|
subexpr: sympy.Expr,
|
|
):
|
|
matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
|
|
self.assertEqual(matched_subexpr, subexpr)
|
|
|
|
def test_index_with_dynamic_shapes(self):
|
|
s0 = sympy.var("s0", integer=True)
|
|
s1 = sympy.var("s1", integer=True)
|
|
|
|
dims = [s1, sympy.Integer(3)]
|
|
num_dims = len(dims)
|
|
numel = dims[0] * dims[1]
|
|
strides = [sympy.Integer(3) * s0, sympy.Integer(1)]
|
|
block_index_exprs = [
|
|
FloorDiv(y, sympy.Integer(3)),
|
|
ModularIndexing(y, sympy.Integer(1), sympy.Integer(3)),
|
|
]
|
|
index = sympy_dot(strides, block_index_exprs)
|
|
|
|
with V.set_graph_handler(self.graph):
|
|
match = BlockPatternMatcher.match_mod_div_block_expr(
|
|
index, y, numel, num_dims
|
|
)
|
|
sizevars = V.graph.sizevars
|
|
for expected, actual in zip((dims, strides, block_index_exprs), match):
|
|
assert isinstance(expected, (list, tuple)) and isinstance(
|
|
actual, (list, tuple)
|
|
)
|
|
for expected_expr, actual_expr in zip(expected, actual):
|
|
assert isinstance(expected_expr, sympy.Expr) and isinstance(
|
|
actual_expr, sympy.Expr
|
|
)
|
|
self.assertTrue(
|
|
sizevars.statically_known_equals(
|
|
sizevars.remove_precomputed_replacements(expected_expr),
|
|
sizevars.remove_precomputed_replacements(actual_expr),
|
|
)
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|