mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Inductor] Fix ComboKernels failing due to missing helper functions (#162759)
Fixes: #162756 Differential Revision: [D82257359](https://our.internmc.facebook.com/intern/diff/D82257359) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162759 Approved by: https://github.com/eellison, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
38afeb2ba2
commit
fa4d5e76ea
@ -6,6 +6,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
TestCase,
|
||||
@ -554,6 +555,24 @@ class ComboKernelDynamicShapesTests(TestCase):
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_helper_fn_defined(self):
|
||||
def fn(x, y, z):
|
||||
return x.sum(1), y.mean(1), z.cumsum(1)
|
||||
|
||||
inps = (
|
||||
torch.rand(16, 128, device="cuda"),
|
||||
torch.rand(32, 128, device="cuda"),
|
||||
torch.rand(32, 256, device="cuda"),
|
||||
)
|
||||
|
||||
out_eager = fn(*inps)
|
||||
fn_c = torch.compile(fn)
|
||||
out_compiled, code = run_and_get_code(fn_c, *inps)
|
||||
code = " ".join(code)
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
self.assertEqual(code.count("def _triton_helper_fn_add0(arg0_0, arg1_0):"), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -764,6 +764,14 @@ class ComboKernel(Kernel):
|
||||
if config.benchmark_combo_kernel:
|
||||
code.splice(self.imports_for_benchmark_kernel())
|
||||
|
||||
seen_helpers: OrderedSet[str] = OrderedSet()
|
||||
for sub_kernel in self.sub_kernels:
|
||||
for helper in sub_kernel.helper_functions:
|
||||
if helper not in seen_helpers:
|
||||
code.writeline("")
|
||||
code.splice(helper)
|
||||
seen_helpers.add(helper)
|
||||
|
||||
argdefs, _, signature, _ = self.args.python_argdefs()
|
||||
argdefs = self.add_numel_to_args(argdefs, signature)
|
||||
block_args = self.get_block_args()
|
||||
|
Reference in New Issue
Block a user