mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Fix ComboKernels failing due to missing helper functions (#162759)
Fixes: #162756
Differential Revision: [D82257359](https://our.internmc.facebook.com/intern/diff/D82257359)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162759
Approved by: https://github.com/eellison, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
38afeb2ba2
commit
fa4d5e76ea
@ -6,6 +6,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._inductor
|
import torch._inductor
|
||||||
|
from torch._inductor.utils import run_and_get_code
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
TestCase,
|
TestCase,
|
||||||
@ -554,6 +555,24 @@ class ComboKernelDynamicShapesTests(TestCase):
|
|||||||
|
|
||||||
self.assertEqual(out_eager, out_compiled)
|
self.assertEqual(out_eager, out_compiled)
|
||||||
|
|
||||||
|
@requires_cuda_and_triton
def test_helper_fn_defined(self):
    """Check that the generated code defines the Triton add helper once.

    A function combining two row reductions with a cumulative sum is run
    eagerly and through ``torch.compile``.  The two results must match, and
    the source captured by ``run_and_get_code`` must contain exactly one
    definition of ``_triton_helper_fn_add0``.
    """

    def reduce_and_scan(x, y, z):
        # Evaluated left to right, same as a single tuple expression.
        row_sum = x.sum(1)
        row_mean = y.mean(1)
        running_total = z.cumsum(1)
        return row_sum, row_mean, running_total

    # Shapes are consumed in order so the random draws happen in the same
    # sequence for every run of this test.
    shapes = ((16, 128), (32, 128), (32, 256))
    inputs = tuple(torch.rand(*shape, device="cuda") for shape in shapes)

    out_eager = reduce_and_scan(*inputs)
    compiled_fn = torch.compile(reduce_and_scan)
    out_compiled, code_chunks = run_and_get_code(compiled_fn, *inputs)

    # Flatten all captured source chunks into one searchable string.
    generated_source = " ".join(code_chunks)
    helper_signature = "def _triton_helper_fn_add0(arg0_0, arg1_0):"

    self.assertEqual(out_eager, out_compiled)
    # The helper must be emitted, and must not be emitted more than once.
    self.assertEqual(generated_source.count(helper_signature), 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
@ -764,6 +764,14 @@ class ComboKernel(Kernel):
|
|||||||
if config.benchmark_combo_kernel:
|
if config.benchmark_combo_kernel:
|
||||||
code.splice(self.imports_for_benchmark_kernel())
|
code.splice(self.imports_for_benchmark_kernel())
|
||||||
|
|
||||||
|
seen_helpers: OrderedSet[str] = OrderedSet()
|
||||||
|
for sub_kernel in self.sub_kernels:
|
||||||
|
for helper in sub_kernel.helper_functions:
|
||||||
|
if helper not in seen_helpers:
|
||||||
|
code.writeline("")
|
||||||
|
code.splice(helper)
|
||||||
|
seen_helpers.add(helper)
|
||||||
|
|
||||||
argdefs, _, signature, _ = self.args.python_argdefs()
|
argdefs, _, signature, _ = self.args.python_argdefs()
|
||||||
argdefs = self.add_numel_to_args(argdefs, signature)
|
argdefs = self.add_numel_to_args(argdefs, signature)
|
||||||
block_args = self.get_block_args()
|
block_args = self.get_block_args()
|
||||||
|
Reference in New Issue
Block a user