Revert "[inductor] make multi-kernel work with cpp-wrapper (#117813)"

This reverts commit c24ffc3f66b2270dfc65a404687b91b55ed580e9.

Reverted https://github.com/pytorch/pytorch/pull/117813 on behalf of https://github.com/atalman due to Failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/117813#issuecomment-1927877102))
PyTorch MergeBot
2024-02-05 19:25:39 +00:00
parent b2e0f8d82d
commit b964a1222c
5 changed files with 41 additions and 235 deletions


@@ -1,6 +1,6 @@
import logging
import os
from typing import Any, List
from typing import Any, Dict, List
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
@@ -18,21 +18,15 @@ def get_kernel_argdefs(kernel):
return arg_defs
def _get_all_args(args_list):
all_args = max(args_list, key=len)[:]
for args in args_list:
assert set(args).issubset(set(all_args)), f"{args} v.s. {all_args}"
return all_args
def get_all_kernel_argdefs(kernels):
"""
The logic here must match with `get_all_call_args`.
"""
argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]
all_argdefs: Dict[
Any, None
] = {} # use a dict rather than set to maintain insertion order
for argdefs in argdefs_list:
all_argdefs.update(dict.fromkeys(argdefs))
return _get_all_args(argdefs_list)
return list(all_argdefs.keys())
def get_all_call_args(call_args_list):
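For context on the hunk above: the restored `get_all_kernel_argdefs` merges the per-sub-kernel argdefs by using a dict as an insertion-ordered set, so duplicates are dropped while first-seen order is kept. A minimal standalone sketch of that idiom, with made-up argdef names:

from typing import Any, Dict, List

def merge_argdefs(argdefs_list: List[List[str]]) -> List[str]:
    # A dict keeps insertion order (Python 3.7+), so it doubles as an ordered set.
    all_argdefs: Dict[Any, None] = {}
    for argdefs in argdefs_list:
        all_argdefs.update(dict.fromkeys(argdefs))
    return list(all_argdefs.keys())

# Hypothetical argdefs from two sub-kernels of one multi-kernel:
merge_argdefs([["in_ptr0", "out_ptr0", "xnumel"],
               ["in_ptr0", "out_ptr0", "xnumel", "rnumel"]])
# -> ['in_ptr0', 'out_ptr0', 'xnumel', 'rnumel']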
@@ -56,7 +50,12 @@ def get_all_call_args(call_args_list):
Instead, we pick the longest call args and assert that other call args are
a subset of it.
"""
return _get_all_args(call_args_list)
all_call_args = max(call_args_list, key=len)[:]
for call_args in call_args_list:
assert set(call_args).issubset(
set(all_call_args)
), f"{call_args} v.s. {all_call_args}"
return all_call_args
def get_numel_argdefs(kernel):
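The restored `get_all_call_args` (like the removed `_get_all_args` helper it replaces) merges call args differently: it keeps the longest call-args list and asserts that every other sub-kernel's call args form a subset of it. A short sketch with hypothetical buffer names, showing both the passing case and what the assertion guards against:

def pick_longest_call_args(call_args_list):
    # Keep the longest list; every other sub-kernel may only use a subset of it.
    all_call_args = max(call_args_list, key=len)[:]
    for call_args in call_args_list:
        assert set(call_args).issubset(set(all_call_args)), f"{call_args} vs. {all_call_args}"
    return all_call_args

pick_longest_call_args([["arg0_1", "buf0", "xnumel"],
                        ["arg0_1", "buf0", "xnumel", "rnumel"]])
# -> ['arg0_1', 'buf0', 'xnumel', 'rnumel']

# A sub-kernel taking an argument missing from the longest list trips the assert:
# pick_longest_call_args([["arg0_1", "buf1"], ["arg0_1", "buf0", "xnumel"]])  # AssertionError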
@@ -102,11 +101,6 @@ class MultiKernelState:
multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name
if V.graph.cpp_wrapper:
# we should not generate any python code for multi-kernel during
# the second pass of cpp-wrapper.
return multi_kernel_name
wrapper = V.graph.wrapper_code
kernel_call_def_code = "\n".join(
@@ -147,7 +141,7 @@ def run(multi_kernel_call, {', '.join(get_all_kernel_argdefs(kernels))}, {', '.j
""" # noqa: B950 line too long
wrapper.header.splice(
f"""
{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [
{multi_kernel_name} = async_compile.multi_kernel([
{", ".join(kernel_names)},
],
'''
@@ -200,25 +194,15 @@ class MultiKernel:
all_call_args = get_all_call_args(call_args_list)
grid: List[Any] = []
if V.graph.cpp_wrapper:
# for the second pass of cpp-wrapper codegen, we should call
# the fast kernel directly
picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
kernel_name = self.kernels[picked_kernel].kernel_name
final_call_args = call_args_list[picked_kernel]
else:
final_call_args = all_call_args
# numels for all subkernels should be the same. Use kernels[0] here
self.kernels[0].add_numel_to_call_args_and_grid(
kernel_name, final_call_args, grid
kernel_name, all_call_args, grid
)
grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
V.graph.wrapper_code.generate_kernel_call(
kernel_name,
final_call_args,
self.kernel_name,
all_call_args,
grid,
V.graph.scheduler.current_device.index,
)
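The removed lines above are the codegen half of the reverted feature: during the cpp-wrapper second pass, the multi-kernel call is replaced by a direct call to the sub-kernel that benchmarking picked, using that sub-kernel's own call args. A simplified, hypothetical sketch of that selection (names invented here, not the real inductor API):

def resolve_final_call(kernel_name, kernels, call_args_list, all_call_args,
                       cpp_wrapper, lookup_choice):
    # Hypothetical helper mirroring the removed branch.
    if cpp_wrapper:
        # Second pass: reuse the choice recorded during the first (python-wrapper) pass.
        picked = lookup_choice(kernel_name)
        return kernels[picked].kernel_name, call_args_list[picked]
    # Python-wrapper pass: call the multi-kernel with the merged call args.
    return kernel_name, all_call_args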
@@ -265,10 +249,9 @@ class MultiKernelCall:
This class is called at run time to actually run the kernel
"""
def __init__(self, multi_kernel_name, kernels, src_code):
def __init__(self, kernels, src_code):
assert len(kernels) >= 2
self._kernels = kernels
self.multi_kernel_name = multi_kernel_name
self._run = PyCodeCache.load(src_code).run
self.disable_cache = os.environ.get(
@@ -284,8 +267,6 @@ class MultiKernelCall:
elif not self.disable_cache:
self.load_cache()
self._recorded = False
def cache_file_path(self):
py_file_path = self._run.__globals__["__file__"]
return os.path.splitext(py_file_path)[0] + ".picked_kernel"
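For reference, `cache_file_path` above derives the on-disk location of the benchmark result from the generated module's `__file__`, simply swapping the .py extension for .picked_kernel; a tiny sketch with a hypothetical path:

import os

def picked_kernel_cache_path(py_file_path: str) -> str:
    # Same directory and basename as the generated module, different extension.
    return os.path.splitext(py_file_path)[0] + ".picked_kernel"

# Hypothetical inductor output file:
picked_kernel_cache_path("/tmp/torchinductor_user/abc/cabc123.py")
# -> '/tmp/torchinductor_user/abc/cabc123.picked_kernel'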
@@ -342,39 +323,6 @@ class MultiKernelCall:
for kernel_call in kernel_calls
]
# record_choice and lookup_choice are helper functions for cpp-wrapper
# codegen. The first pass uses record_choice to keep the choice and
# the second pass looks it up by calling lookup_choice.
#
# An alternative that reuses the multi-kernel cache does not work well,
# since during codegen of the second pass it is very hard to know the
# path of the cache file. Reading the cache file also requires extra IO,
# which can be slower.
@staticmethod
def record_choice(multi_kernel_name, choice):
"""
Record the multi-kernel choice during cpp-wrapper first-pass codegen
so that the second pass can look it up.
This is a no-op when not called during codegen.
"""
from torch._inductor.graph import GraphLowering
if not isinstance(V.graph, GraphLowering):
return
if not V.graph.record_multi_kernel_choice:
return
V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
@staticmethod
def lookup_choice(multi_kernel_name):
# this should always be done during cpp-wrapper codegen
assert V.graph.record_multi_kernel_choice
# there should be no miss
return V.graph.multi_kernel_to_choice[multi_kernel_name]
def run_with_argless_kernels(self, kernel_calls):
if self.picked_kernel is None:
timings = self.benchmark_sub_kernels(kernel_calls)
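The removed record_choice / lookup_choice pair implements the two-pass handshake the comments describe: the first (python-wrapper) pass records which sub-kernel won benchmarking on the graph, and the second (cpp-wrapper) pass reads it back instead of re-benchmarking or reading the cache file. A minimal self-contained sketch of that handshake, with a stand-in object in place of V.graph:

class _FakeGraph:
    # Stand-in for the state the removed code keeps on V.graph.
    record_multi_kernel_choice = True
    multi_kernel_to_choice: dict = {}

graph = _FakeGraph()

def record_choice(multi_kernel_name, choice):
    # First pass: remember the index of the winning sub-kernel.
    if graph.record_multi_kernel_choice:
        graph.multi_kernel_to_choice[multi_kernel_name] = choice

def lookup_choice(multi_kernel_name):
    # Second pass: a KeyError here would mean the first pass never recorded a choice.
    assert graph.record_multi_kernel_choice
    return graph.multi_kernel_to_choice[multi_kernel_name]

record_choice("multi_kernel_0", 1)
assert lookup_choice("multi_kernel_0") == 1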
@@ -403,11 +351,6 @@ class MultiKernelCall:
"speedup": timings[1] / timings[0],
}
)
if not self.disable_cache:
self.store_cache()
if not self._recorded:
self._recorded = True
self.record_choice(self.multi_kernel_name, self.picked_kernel)
kernel_calls[self.picked_kernel]()
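Taken together, run_with_argless_kernels benchmarks every sub-kernel on the first call, persists and records the winner, and then dispatches straight to that winner on later calls. A hedged, self-contained sketch of that control flow (timing simplified, caching and choice recording omitted):

import time
from typing import Callable, List, Optional

class TinyMultiKernelCall:
    # Toy model of the runtime selection loop, not the real MultiKernelCall.
    def __init__(self, kernel_calls: List[Callable[[], None]]):
        assert len(kernel_calls) >= 2
        self.kernel_calls = kernel_calls
        self.picked_kernel: Optional[int] = None

    def _benchmark(self, fn: Callable[[], None]) -> float:
        start = time.perf_counter()
        fn()
        return time.perf_counter() - start

    def run(self) -> None:
        if self.picked_kernel is None:
            # First call: time each sub-kernel and keep the fastest one.
            timings = [self._benchmark(fn) for fn in self.kernel_calls]
            self.picked_kernel = timings.index(min(timings))
        # Later calls go straight to the winner.
        self.kernel_calls[self.picked_kernel]()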