diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py
index 90399546d26e..59187c7349a0 100644
--- a/test/inductor/test_combo_kernels.py
+++ b/test/inductor/test_combo_kernels.py
@@ -6,6 +6,7 @@ import unittest
 
 import torch
 import torch._inductor
+from torch._inductor.utils import run_and_get_code
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     TestCase,
@@ -554,6 +555,24 @@ class ComboKernelDynamicShapesTests(TestCase):
 
         self.assertEqual(out_eager, out_compiled)
 
+    @requires_cuda_and_triton
+    def test_helper_fn_defined(self):  # generated code must define the Triton helper exactly once
+        def fn(x, y, z):
+            return x.sum(1), y.mean(1), z.cumsum(1)  # sum / mean / cumsum along dim 1, one per input
+
+        inps = (
+            torch.rand(16, 128, device="cuda"),
+            torch.rand(32, 128, device="cuda"),
+            torch.rand(32, 256, device="cuda"),
+        )
+
+        out_eager = fn(*inps)
+        fn_c = torch.compile(fn)
+        out_compiled, code = run_and_get_code(fn_c, *inps)  # compiled outputs plus generated source strings
+        code = " ".join(code)  # flatten into one string so the helper definition can be counted
+        self.assertEqual(out_eager, out_compiled)
+        self.assertEqual(code.count("def _triton_helper_fn_add0(arg0_0, arg1_0):"), 1)  # exactly one definition
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py
index e3df5bc0363d..c28321923c5e 100644
--- a/torch/_inductor/codegen/triton_combo_kernel.py
+++ b/torch/_inductor/codegen/triton_combo_kernel.py
@@ -764,6 +764,14 @@ class ComboKernel(Kernel):
         if config.benchmark_combo_kernel:
             code.splice(self.imports_for_benchmark_kernel())
 
+        seen_helpers: OrderedSet[str] = OrderedSet()  # helpers already spliced into `code`
+        for sub_kernel in self.sub_kernels:
+            for helper in sub_kernel.helper_functions:
+                if helper not in seen_helpers:  # skip a helper that was already emitted
+                    code.writeline("")  # empty line ahead of the helper
+                    code.splice(helper)
+                    seen_helpers.add(helper)
+
         argdefs, _, signature, _ = self.args.python_argdefs()
         argdefs = self.add_numel_to_args(argdefs, signature)
         block_args = self.get_block_args()