[inductor] make multi-kernel work with cpp-wrapper (#117813)

Make multi-kernel work with cpp-wrapper. Multi-kernel generates two equivalent variants for a reduction, and the faster one is picked at runtime. But cpp-wrapper needs to save the cubin file during codegen, so the two features did not work with each other initially.

Thanks to Jason for suggesting a neat way to integrate the two. cpp-wrapper currently does two codegen passes. In the first pass, we still generate the multi-kernel code and run it; in the second pass, we load the cubin file for the faster kernel directly. The multi-kernel Python code is not generated in the second pass since it is not needed there.
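
For illustration, a minimal, self-contained sketch of the record/lookup handshake between the two passes (a toy stand-in, not the actual Inductor code; the real helpers are MultiKernelCall.record_choice and MultiKernelCall.lookup_choice in the diff below, and the kernel name here is just an example of the multi_kernel_{index} naming scheme):

from typing import Dict

_choice_by_name: Dict[str, int] = {}

def record_choice(multi_kernel_name: str, choice: int) -> None:
    # First pass: the generated multi-kernel wrapper benchmarks its sub-kernels
    # and remembers the index of the fastest one.
    _choice_by_name[multi_kernel_name] = choice

def lookup_choice(multi_kernel_name: str) -> int:
    # Second pass: there should be no miss, so plain indexing is fine.
    return _choice_by_name[multi_kernel_name]

# The first pass records the winner ...
record_choice("multi_kernel_0", 1)
# ... and the second pass emits a direct call to kernels[1] instead of the
# multi-kernel dispatcher, so its cubin can be saved.
assert lookup_choice("multi_kernel_0") == 1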

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117813
Approved by: https://github.com/jansel
Shunting Zhang
2024-01-31 14:21:16 -08:00
committed by PyTorch MergeBot
parent 54668ad6dc
commit 20484a1936
5 changed files with 236 additions and 43 deletions


@ -1,6 +1,6 @@
import logging
import os
from typing import Any, Dict, List
from typing import Any, List
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
@ -18,15 +18,21 @@ def get_kernel_argdefs(kernel):
return arg_defs
def get_all_kernel_argdefs(kernels):
argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]
all_argdefs: Dict[
Any, None
] = {} # use a dict rather than set to maintain insertion order
for argdefs in argdefs_list:
all_argdefs.update(dict.fromkeys(argdefs))
def _get_all_args(args_list):
all_args = max(args_list, key=len)[:]
for args in args_list:
assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}"
return list(all_argdefs.keys())
return all_args
def get_all_kernel_argdefs(kernels):
"""
The logic here must match with `get_all_call_args`.
"""
argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]
return _get_all_args(argdefs_list)
def get_all_call_args(call_args_list):
@ -50,12 +56,7 @@ def get_all_call_args(call_args_list):
Instead, we pick the longest call args and assert that other call args are
a subset of it.
"""
all_call_args = max(call_args_list, key=len)[:]
for call_args in call_args_list:
assert set(call_args).issubset(
set(all_call_args)
), f"{call_args} v.s. {all_call_args}"
return all_call_args
return _get_all_args(call_args_list)
def get_numel_argdefs(kernel):
@ -101,6 +102,11 @@ class MultiKernelState:
multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name
if V.graph.cpp_wrapper:
# we should not generate any python code for multi-kernel during
# the second pass of cpp-wrapper.
return multi_kernel_name
wrapper = V.graph.wrapper_code
kernel_call_def_code = "\n".join(
@ -141,7 +147,7 @@ def run(multi_kernel_call, {', '.join(get_all_kernel_argdefs(kernels))}, {', '.j
""" # noqa: B950 line too long
wrapper.header.splice(
f"""
{multi_kernel_name} = async_compile.multi_kernel([
{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [
{", ".join(kernel_names)},
],
'''
@ -194,15 +200,25 @@ class MultiKernel:
all_call_args = get_all_call_args(call_args_list)
grid: List[Any] = []
if V.graph.cpp_wrapper:
# for the second pass of cpp-wrapper codegen, we should call
# the fast kernel directly
picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
kernel_name = self.kernels[picked_kernel].kernel_name
final_call_args = call_args_list[picked_kernel]
else:
final_call_args = all_call_args
# numels for all subkernels should be the same. Use kernels[0] here
self.kernels[0].add_numel_to_call_args_and_grid(
kernel_name, all_call_args, grid
kernel_name, final_call_args, grid
)
grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
V.graph.wrapper_code.generate_kernel_call(
self.kernel_name,
all_call_args,
kernel_name,
final_call_args,
grid,
V.graph.scheduler.current_device.index,
)
@ -249,9 +265,10 @@ class MultiKernelCall:
This class is called at run time to actually run the kernel
"""
def __init__(self, kernels, src_code):
def __init__(self, multi_kernel_name, kernels, src_code):
assert len(kernels) >= 2
self._kernels = kernels
self.multi_kernel_name = multi_kernel_name
self._run = PyCodeCache.load(src_code).run
self.disable_cache = os.environ.get(
@ -267,6 +284,8 @@ class MultiKernelCall:
elif not self.disable_cache:
self.load_cache()
self._recorded = False
def cache_file_path(self):
py_file_path = self._run.__globals__["__file__"]
return os.path.splitext(py_file_path)[0] + ".picked_kernel"
@ -323,6 +342,39 @@ class MultiKernelCall:
for kernel_call in kernel_calls
]
# record_choice and lookup_choice are helper functions for cpp-wrapper
# codegen. The first pass uses record_choice to keep the choice and
# the second pass looks it up by calling lookup_choice.
#
# An alternative that reuses the multi-kernel cache file does not work well
# since, during codegen of the second pass, it's very hard to know the
# path of the cache file. Also, reading the cache file requires some IO,
# which can be slower.
@staticmethod
def record_choice(multi_kernel_name, choice):
"""
Record the multi-kernel choice during the first pass of cpp-wrapper
codegen so that the second pass can look it up.
Do nothing if this function is not called during codegen.
"""
from torch._inductor.graph import GraphLowering
if not isinstance(V.graph, GraphLowering):
return
if not V.graph.record_multi_kernel_choice:
return
V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
@staticmethod
def lookup_choice(multi_kernel_name):
# this should always be done during cpp-wrapper codegen
assert V.graph.record_multi_kernel_choice
# there should be no miss
return V.graph.multi_kernel_to_choice[multi_kernel_name]
def run_with_argless_kernels(self, kernel_calls):
if self.picked_kernel is None:
timings = self.benchmark_sub_kernels(kernel_calls)
@ -351,6 +403,11 @@ class MultiKernelCall:
"speedup": timings[1] / timings[0],
}
)
if not self.disable_cache:
self.store_cache()
if not self._recorded:
self._recorded = True
self.record_choice(self.multi_kernel_name, self.picked_kernel)
kernel_calls[self.picked_kernel]()