mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-19 10:04:58 +08:00
Update (base update)
[ghstack-poisoned]
This commit is contained in:
@ -1191,14 +1191,16 @@ class TestTiling(TestCase):
|
||||
"""Test broadcast variable detection for tiling improvements."""
|
||||
from torch._inductor import tiling_utils
|
||||
|
||||
i, j = sympy.symbols("i j", integer=True)
|
||||
i, j, k = sympy.symbols("i j k", integer=True)
|
||||
|
||||
# Test broadcast pattern detection: FloorDiv creates broadcast
|
||||
result = tiling_utils.find_broadcast_var(FloorDiv(i, 10), {i: 100, j: 50})
|
||||
result = tiling_utils.find_broadcast_var(
|
||||
FloorDiv(i, 10), {i: 100, j: 50, k: 20}
|
||||
)
|
||||
self.assertEqual(result, i)
|
||||
|
||||
# Test non-broadcast: linear access pattern
|
||||
result = tiling_utils.find_broadcast_var(i + j * 10, {i: 10, j: 8})
|
||||
result = tiling_utils.find_broadcast_var(i + j * 10, {i: 10, j: 8, k: 20})
|
||||
self.assertEqual(result, None)
|
||||
|
||||
|
||||
|
||||
@ -162,6 +162,9 @@ def find_broadcast_var(
|
||||
|
||||
zero_index = sympy_subs(index, variables)
|
||||
for v in var_ranges.keys():
|
||||
if v not in index.free_symbols:
|
||||
continue
|
||||
|
||||
variables[v] = 1
|
||||
try:
|
||||
new_val = sympy_subs(index, variables)
|
||||
|
||||
Reference in New Issue
Block a user