Update (base update)

[ghstack-poisoned]
This commit is contained in:
eellison
2025-11-17 14:01:20 -08:00
parent 4991994196
commit be42494d92
2 changed files with 8 additions and 3 deletions

View File

@ -1191,14 +1191,16 @@ class TestTiling(TestCase):
"""Test broadcast variable detection for tiling improvements."""
from torch._inductor import tiling_utils
i, j = sympy.symbols("i j", integer=True)
i, j, k = sympy.symbols("i j k", integer=True)
# Test broadcast pattern detection: FloorDiv creates broadcast
result = tiling_utils.find_broadcast_var(FloorDiv(i, 10), {i: 100, j: 50})
result = tiling_utils.find_broadcast_var(
FloorDiv(i, 10), {i: 100, j: 50, k: 20}
)
self.assertEqual(result, i)
# Test non-broadcast: linear access pattern
result = tiling_utils.find_broadcast_var(i + j * 10, {i: 10, j: 8})
result = tiling_utils.find_broadcast_var(i + j * 10, {i: 10, j: 8, k: 20})
self.assertEqual(result, None)

View File

@ -162,6 +162,9 @@ def find_broadcast_var(
zero_index = sympy_subs(index, variables)
for v in var_ranges.keys():
if v not in index.free_symbols:
continue
variables[v] = 1
try:
new_val = sympy_subs(index, variables)