Compare commits

...

1 Commit

SHA1: 917382f096
Message: [multi-kernel] apply size hints atomically (#164863)
Date: 2025-10-08 15:45:42 -07:00

Summary:

Apply size hints atomically when a hint_override is given.

Test Plan: test_multi_kernel

Differential Revision: D84089824
3 changed files with 39 additions and 9 deletions
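In rough terms, "atomically" here means that when a hint_override is given, the override is substituted for every free symbol of a size expression and the expression is then evaluated, instead of the whole expression being replaced by the raw override value. A minimal standalone sketch of that semantics (plain sympy, not the Inductor helper itself; the symbols s0 and s1 are illustrative, see the SizeVarAllocator hunk below for the real code):

import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
expr = s0 * s1      # e.g. the flattened m*n dimension of a dynamic view
hint_override = 8

# Old behavior: any symbolic size is replaced wholesale by the raw hint.
old_hint = hint_override                                              # 8

# New behavior: substitute the hint for every free symbol and evaluate,
# so sizes derived from the same symbols stay mutually consistent.
new_hint = expr.subs({sym: sympy.Integer(hint_override) for sym in expr.free_symbols})
assert isinstance(new_hint, sympy.Integer)
print(old_hint, new_hint)   # 8 64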

@@ -164,6 +164,37 @@ class MultiKernelTest(TestCase):
         self.assertEqual(ref, act)
         self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))
 
+    @requires_triton()
+    @skipIfRocm
+    @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
+    @config.patch({"multi_kernel_hints": [8, 32, 128]})  # 4096 is too large
+    def test_mm_on_view_of_dynamic_base(self):
+        def fn(x, y):
+            x = x.view(-1, 4096)  # [m, n, d] -> [m*n, d]
+            y = y.view(4096, -1)  # [d, k, l] -> [d, k*l]
+            z = x @ y  # [m*n, k*l]
+            return z
+
+        compiled_fn = torch.compile(
+            fn,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            },
+        )
+
+        x = torch.randn(64, 32, 4096, device=GPU_TYPE)
+        y = torch.randn(4096, 16, 8, device=GPU_TYPE)
+        torch._dynamo.mark_dynamic(x, 0)
+        torch._dynamo.mark_dynamic(x, 1)
+        torch._dynamo.mark_dynamic(y, 1)
+        torch._dynamo.mark_dynamic(y, 2)
+        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
+        ref = fn(x, y)  # noqa: F841
+        wrapper_code = wrapper_code[-1]
+        # self.assertEqual(ref, act)  # this actually fails without multi-kernel
+        self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))
+
     @parametrize("force_kernel", (0, 1))
     @unittest.mock.patch.dict(
         os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}

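For orientation, the shapes this test exercises: dims 0 and 1 of x and dims 1 and 2 of y are marked dynamic, so the views inside fn produce matmul operands whose outer sizes are products of two dynamic sizes. A small worked sketch of the concrete values (plain Python; the symbol names s0..s3 are illustrative):

# x: [64, 32, 4096] with dims 0 and 1 dynamic  -> symbolically [s0, s1, 4096]
# y: [4096, 16, 8]  with dims 1 and 2 dynamic  -> symbolically [4096, s2, s3]
m, n, d = 64, 32, 4096
k, l = 16, 8

# After the views inside fn:
#   x.view(-1, 4096) -> [m*n, d]  ~ [s0*s1, 4096]
#   y.view(4096, -1) -> [d, k*l]  ~ [4096, s2*s3]
print((m * n, d))   # (2048, 4096)
print((d, k * l))   # (4096, 128)

With multi_kernel_hints of [8, 32, 128], the hinted value of a product such as s0*s1 becomes 64, 1024, or 16384 under the new atomic substitution, rather than the raw hint itself.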

@@ -1736,12 +1736,7 @@ class SIMDScheduling(BaseScheduling):
         """
         shapes = self._get_multikernel_shapes(node)
         return tuple(
-            tuple(
-                hint
-                if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
-                else s
-                for s in shape
-            )
+            tuple(V.graph.sizevars.size_hint(s, hint_override=hint) for s in shape)
             for shape in shapes
         )
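The effect of this hunk at the shape-tuple level: concrete integer entries pass through size_hint unchanged, while symbolic entries now get the per-symbol substitution instead of being clobbered with the raw hint. A minimal sketch using plain sympy in place of V.graph.sizevars (symbols illustrative):

import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
shape = (s0 * s1, 4096)
hint = 32

# Old: every non-Integer sympy expression becomes the raw hint.
old = tuple(
    hint if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer) else s
    for s in shape
)

# New (approximating size_hint(s, hint_override=hint)): substitute the hint
# for each free symbol of a symbolic entry; plain ints are left alone.
new = tuple(
    int(s.subs({sym: hint for sym in s.free_symbols})) if isinstance(s, sympy.Expr) else s
    for s in shape
)

print(old)  # (32, 4096)
print(new)  # (1024, 4096)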

@@ -573,6 +573,7 @@ class SizeVarAllocator:
             return expr
         # Substitute all hints into expr, but leave unbacked symints alone
         expr = self.simplify(expr)
+        expr = self.remove_precomputed_replacements(expr)
         if not isinstance(expr, Expr):
             assert isinstance(expr, int)
             return expr
@@ -584,9 +585,12 @@
                 return expr  # inf/nan/I
 
         if hint_override:
-            return hint_override
-
-        expr = self.remove_precomputed_replacements(expr)
+            out = sympy_subs(
+                expr,
+                {symbol: sympy.Integer(hint_override) for symbol in free_symbols},
+            )
+            assert isinstance(out, sympy.Integer)
+            return out
 
         if use_user_provided_hint_override:
             expr = sympy_subs(expr, self.var_to_hint_override)
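One detail worth noting across the two SizeVarAllocator hunks: remove_precomputed_replacements now runs right after simplify, before the hint_override branch, so the substitution is applied to the underlying size symbols rather than to any precomputed placeholder symbols. A rough sketch of why the ordering matters, under the purely hypothetical assumption that a placeholder ps0 stands for s0 * s1:

import sympy

s0, s1, ps0 = sympy.symbols("s0 s1 ps0", positive=True, integer=True)
hint = 8

expr = ps0 + 16                        # hypothetical expression holding a placeholder
expanded = expr.subs({ps0: s0 * s1})   # what expanding the placeholder would give

# Substituting the hint before expanding treats ps0 as a single size symbol...
hinted_before = expr.subs({sym: hint for sym in expr.free_symbols})          # 24
# ...while expanding first applies the hint to the real symbols consistently.
hinted_after = expanded.subs({sym: hint for sym in expanded.free_symbols})   # 80

print(hinted_before, hinted_after)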