[AOTI] Fix multi-kernel codegen when using one-pass (#142333)

Summary: Update multi-kernel codegen to one-pass, following https://github.com/pytorch/pytorch/pull/141980.

Differential Revision: [D66936717](https://our.internmc.facebook.com/intern/diff/D66936717)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142333
Approved by: https://github.com/chenyang78
ghstack dependencies: #141980
This commit is contained in:
Bin Bao
2024-12-08 18:44:41 -08:00
committed by PyTorch MergeBot
parent 4d43ec2189
commit 5fc9f419ef
4 changed files with 34 additions and 22 deletions

View File

@ -115,7 +115,7 @@ class MultiKernelState:
multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name
if V.graph.cpp_wrapper:
if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
# we should not generate any python code for multi-kernel during
# the second pass of cpp-wrapper.
return multi_kernel_name
@ -131,9 +131,11 @@ class MultiKernelState:
buf.writeline("])")
wrapper = V.graph.wrapper_code
wrapper.header.splice(buf)
if config.triton.autotune_at_compile_time:
wrapper.kernel_autotune_defs.splice(buf)
wrapper.src_to_kernel["\n".join(kernel_names)] = multi_kernel_name
else:
wrapper.header.splice(buf)
return multi_kernel_name
@ -218,11 +220,10 @@ class MultiKernel:
grid: List[Any] = []
if V.graph.cpp_wrapper:
if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
# for the second pass of cpp-wrapper codegen, we should call
# the fast kernel directly
picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
kernel_name = self.kernels[picked_kernel].kernel_name
kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
# numels for all subkernels should be the same. Use kernels[0] here
self.kernels[0].add_numel_to_call_args_and_grid(
@ -381,10 +382,9 @@ class MultiKernelCall:
# path for the cache file. Also reading the cache file need do some IO
# which can be slower.
@staticmethod
def record_choice(multi_kernel_name, choice):
def record_choice(multi_kernel_name: str, picked_kernel_name: str):
"""
Record the multi-kernel choice for cpp-wrapper first pass codegen
for the second pass.
Record the multi-kernel choice for cpp-wrapper after autotuning
We should do nothing if this function is not called during codegen.
"""
@ -396,12 +396,15 @@ class MultiKernelCall:
if not V.graph.record_multi_kernel_choice:
return
V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
V.graph.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name
@staticmethod
def lookup_choice(multi_kernel_name):
def lookup_choice(multi_kernel_name: str) -> str:
# this should always been done during cpp-wrapper codegen
assert V.graph.record_multi_kernel_choice
assert (
V.graph.record_multi_kernel_choice
and multi_kernel_name in V.graph.multi_kernel_to_choice
)
# there should be no miss
return V.graph.multi_kernel_to_choice[multi_kernel_name]
@ -426,7 +429,11 @@ class MultiKernelCall:
if not self._recorded:
self._recorded = True
self.record_choice(self.multi_kernel_name, self.picked_kernel)
picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get(
"kernel_name"
)
assert picked_kernel_name is not None
self.record_choice(self.multi_kernel_name, picked_kernel_name)
self.run = self.kernels[self.picked_kernel].run # type: ignore[method-assign]
self.run(*args, **kwargs)