[Inductor] Fix ComboKernels failing due to missing helper functions (#162759)

Fixes: #162756

Differential Revision: [D82257359](https://our.internmc.facebook.com/intern/diff/D82257359)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162759
Approved by: https://github.com/eellison, https://github.com/mlazos
This commit is contained in:
karthickai
2025-09-11 15:05:32 -07:00
committed by PyTorch MergeBot
parent 38afeb2ba2
commit fa4d5e76ea
2 changed files with 27 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import unittest
import torch
import torch._inductor
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TestCase,
@ -554,6 +555,24 @@ class ComboKernelDynamicShapesTests(TestCase):
self.assertEqual(out_eager, out_compiled)
@requires_cuda_and_triton
def test_helper_fn_defined(self):
def fn(x, y, z):
return x.sum(1), y.mean(1), z.cumsum(1)
inps = (
torch.rand(16, 128, device="cuda"),
torch.rand(32, 128, device="cuda"),
torch.rand(32, 256, device="cuda"),
)
out_eager = fn(*inps)
fn_c = torch.compile(fn)
out_compiled, code = run_and_get_code(fn_c, *inps)
code = " ".join(code)
self.assertEqual(out_eager, out_compiled)
self.assertEqual(code.count("def _triton_helper_fn_add0(arg0_0, arg1_0):"), 1)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -764,6 +764,14 @@ class ComboKernel(Kernel):
if config.benchmark_combo_kernel:
code.splice(self.imports_for_benchmark_kernel())
seen_helpers: OrderedSet[str] = OrderedSet()
for sub_kernel in self.sub_kernels:
for helper in sub_kernel.helper_functions:
if helper not in seen_helpers:
code.writeline("")
code.splice(helper)
seen_helpers.add(helper)
argdefs, _, signature, _ = self.args.python_argdefs()
argdefs = self.add_numel_to_args(argdefs, signature)
block_args = self.get_block_args()