[inductor] make multi-kernel work with cpp-wrapper (#117813)
Make multi-kernel work with cpp-wrapper. Multi-kernel generates two equivalent variants for a reduction and picks the faster one at runtime, while cpp-wrapper needs to save the cubin file during codegen, so initially the two features did not work together. Thanks to Jason for suggesting a neat way to integrate them. cpp-wrapper currently does two codegen passes: in the first pass we still generate the multi-kernel Python code and run it; in the second pass we load the cubin file for the faster kernel directly. No multi-kernel Python code is generated in the second pass, since it is no longer needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117813
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 54668ad6dc
Commit: 20484a1936
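In outline, pass one still emits and runs the Python multi-kernel, which benchmarks the variants and records the winner; pass two only emits a direct call to the recorded winner, whose cubin can then be saved. Below is a minimal, self-contained sketch of that record-then-lookup scheme; the names and sleep-based "kernels" are illustrative stand-ins, not Inductor's actual API:

import time

# Stand-in for V.graph.multi_kernel_to_choice, which the real record_choice
# and lookup_choice write and read across the two cpp-wrapper codegen passes.
multi_kernel_to_choice = {}

def kernel_variant_a():
    time.sleep(0.002)  # e.g. a regular (looped) reduction

def kernel_variant_b():
    time.sleep(0.001)  # e.g. a persistent reduction

def first_pass_run(multi_kernel_name, kernels):
    # First pass: the generated Python multi-kernel benchmarks each
    # equivalent variant, records the winner, and runs it.
    timings = []
    for kernel in kernels:
        start = time.perf_counter()
        kernel()
        timings.append(time.perf_counter() - start)
    picked = timings.index(min(timings))
    multi_kernel_to_choice[multi_kernel_name] = picked  # record_choice
    kernels[picked]()
    return picked

def second_pass_codegen(multi_kernel_name, kernel_names):
    # Second pass: cpp-wrapper codegen looks the choice up and emits a
    # direct call to the winning kernel; a miss would be a bug.
    picked = multi_kernel_to_choice[multi_kernel_name]  # lookup_choice
    return f"call {kernel_names[picked]} directly"

first_pass_run("multi_kernel_0", [kernel_variant_a, kernel_variant_b])
print(second_pass_codegen("multi_kernel_0", ["variant_a", "variant_b"]))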
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Any, Dict, List
+from typing import Any, List
 
 from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
 
@@ -18,15 +18,21 @@ def get_kernel_argdefs(kernel):
     return arg_defs
 
 
-def get_all_kernel_argdefs(kernels):
-    argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]
-    all_argdefs: Dict[
-        Any, None
-    ] = {}  # use a dict rather than set to maintain insertion order
-    for argdefs in argdefs_list:
-        all_argdefs.update(dict.fromkeys(argdefs))
+def _get_all_args(args_list):
+    all_args = max(args_list, key=len)[:]
+    for args in args_list:
+        assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}"
 
-    return list(all_argdefs.keys())
+    return all_args
+
+
+def get_all_kernel_argdefs(kernels):
+    """
+    The logic here must match with `get_all_call_args`.
+    """
+    argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]
+
+    return _get_all_args(argdefs_list)
 
 
 def get_all_call_args(call_args_list):
@@ -50,12 +56,7 @@ def get_all_call_args(call_args_list):
     Instead, we pick the longest call args and assert that otehr call args are
     a subset of it.
     """
-    all_call_args = max(call_args_list, key=len)[:]
-    for call_args in call_args_list:
-        assert set(call_args).issubset(
-            set(all_call_args)
-        ), f"{call_args} v.s. {all_call_args}"
-    return all_call_args
+    return _get_all_args(call_args_list)
 
 
 def get_numel_argdefs(kernel):
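The refactored `_get_all_args` helper above unifies the argument lists of the sub-kernels: it takes the longest list as canonical and asserts that every other list is a subset of it. A runnable illustration with made-up argument names (e.g. one variant may not need rnumel):

def _get_all_args(args_list):
    # Same logic as the helper introduced in the hunk above.
    all_args = max(args_list, key=len)[:]
    for args in args_list:
        assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}"
    return all_args

print(_get_all_args([
    ["in_ptr0", "out_ptr0", "xnumel"],
    ["in_ptr0", "out_ptr0", "xnumel", "rnumel"],
]))
# -> ['in_ptr0', 'out_ptr0', 'xnumel', 'rnumel']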
@@ -101,6 +102,11 @@ class MultiKernelState:
         multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
         self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name
 
+        if V.graph.cpp_wrapper:
+            # we should not generate any python code for multi-kernel during
+            # the second pass of cpp-wrapper.
+            return multi_kernel_name
+
         wrapper = V.graph.wrapper_code
 
         kernel_call_def_code = "\n".join(
@@ -141,7 +147,7 @@ def run(multi_kernel_call, {', '.join(get_all_kernel_argdefs(kernels))}, {', '.j
 """  # noqa: B950 line too long
         wrapper.header.splice(
             f"""
-{multi_kernel_name} = async_compile.multi_kernel([
+{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [
     {", ".join(kernel_names)},
 ],
 '''
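A small detail in the hunk above: the `!r` conversion embeds the multi-kernel's name into the generated source as a quoted string literal, which is what lets the runtime MultiKernelCall know its own name for record_choice. A runnable illustration (the generated line is paraphrased and the names are hypothetical):

multi_kernel_name = "multi_kernel_0"

# !r applies repr(), so the generated source receives a quoted literal:
line = f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
print(line)
# -> multi_kernel_0 = async_compile.multi_kernel('multi_kernel_0', [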
@@ -194,15 +200,25 @@ class MultiKernel:
         all_call_args = get_all_call_args(call_args_list)
         grid: List[Any] = []
+
+        if V.graph.cpp_wrapper:
+            # for the second pass of cpp-wrapper codegen, we should call
+            # the fast kernel directly
+            picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
+            kernel_name = self.kernels[picked_kernel].kernel_name
+            final_call_args = call_args_list[picked_kernel]
+        else:
+            final_call_args = all_call_args
+
         # numels for all subkernels should be the same. Use kernels[0] here
         self.kernels[0].add_numel_to_call_args_and_grid(
-            kernel_name, all_call_args, grid
+            kernel_name, final_call_args, grid
         )
 
         grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
 
         V.graph.wrapper_code.generate_kernel_call(
-            self.kernel_name,
-            all_call_args,
+            kernel_name,
+            final_call_args,
             grid,
             V.graph.scheduler.current_device.index,
         )
@@ -249,9 +265,10 @@ class MultiKernelCall:
     This class is called at run time to actually run the kernel
     """
 
-    def __init__(self, kernels, src_code):
+    def __init__(self, multi_kernel_name, kernels, src_code):
         assert len(kernels) >= 2
         self._kernels = kernels
+        self.multi_kernel_name = multi_kernel_name
 
         self._run = PyCodeCache.load(src_code).run
         self.disable_cache = os.environ.get(
@@ -267,6 +284,8 @@ class MultiKernelCall:
         elif not self.disable_cache:
             self.load_cache()
 
+        self._recorded = False
+
     def cache_file_path(self):
         py_file_path = self._run.__globals__["__file__"]
         return os.path.splitext(py_file_path)[0] + ".picked_kernel"
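For context on the rationale comment in the next hunk about not reusing this cache: the picked-kernel cache lives next to the generated Python module, so its path is only known once that module exists, which is awkward during the cpp-wrapper second pass. A runnable sketch of the path derivation from cache_file_path above (the module path is hypothetical):

import os

# Hypothetical generated-module path; at runtime it comes from
# self._run.__globals__["__file__"].
py_file_path = "/tmp/torchinductor_user/ab/cabc123.py"

# Mirrors MultiKernelCall.cache_file_path in the hunk above.
cache_path = os.path.splitext(py_file_path)[0] + ".picked_kernel"
print(cache_path)  # -> /tmp/torchinductor_user/ab/cabc123.picked_kernel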
@@ -323,6 +342,39 @@ class MultiKernelCall:
             for kernel_call in kernel_calls
         ]
 
+    # record_choice and lookup_choice are helper functions for cpp-wrapper
+    # codegen. The first pass use record_choice to keep the choice and
+    # the second pass do lookup by calling lookup_choice.
+    #
+    # An alternative that reused the multi-kernel cache does not work well
+    # since during codegen of the second pass, it's very hard to know the
+    # path for the cache file. Also reading the cache file need do some IO
+    # which can be slower.
+    @staticmethod
+    def record_choice(multi_kernel_name, choice):
+        """
+        Record the multi-kernel choice for cpp-wrapper first pass codegen
+        for the second pass.
+
+        We should do nothing if this function is not called during codegen.
+        """
+        from torch._inductor.graph import GraphLowering
+
+        if not isinstance(V.graph, GraphLowering):
+            return
+
+        if not V.graph.record_multi_kernel_choice:
+            return
+
+        V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
+
+    @staticmethod
+    def lookup_choice(multi_kernel_name):
+        # this should always been done during cpp-wrapper codegen
+        assert V.graph.record_multi_kernel_choice
+        # there should be no miss
+        return V.graph.multi_kernel_to_choice[multi_kernel_name]
+
     def run_with_argless_kernels(self, kernel_calls):
         if self.picked_kernel is None:
             timings = self.benchmark_sub_kernels(kernel_calls)
@@ -351,6 +403,11 @@ class MultiKernelCall:
                     "speedup": timings[1] / timings[0],
                 }
             )
+
             if not self.disable_cache:
                 self.store_cache()
+
+        if not self._recorded:
+            self._recorded = True
+            self.record_choice(self.multi_kernel_name, self.picked_kernel)
         kernel_calls[self.picked_kernel]()