[AOTI] Fix multi-kernel codegen when using one-pass (#142333)
Summary: Update multi-kernel codegen to one-pass, following https://github.com/pytorch/pytorch/pull/141980.

Differential Revision: [D66936717](https://our.internmc.facebook.com/intern/diff/D66936717)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142333
Approved by: https://github.com/chenyang78
ghstack dependencies: #141980
committed by PyTorch MergeBot
parent 4d43ec2189
commit 5fc9f419ef
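For context on the diff below: with one-pass cpp-wrapper codegen there is no second pass in which to resolve the autotune winner by index, so the change records the picked subkernel's *name* during compile-time autotuning and resolves it by name at codegen time. Here is a minimal, self-contained sketch of that record/lookup protocol (hypothetical `ChoiceRegistry` class with illustrative names; the real logic lives in `MultiKernelCall.record_choice`/`lookup_choice`):

```python
# Sketch of the choice-recording protocol this PR moves to (hypothetical
# class; not Inductor's actual API surface).
class ChoiceRegistry:
    def __init__(self) -> None:
        # maps a multi-kernel's name to the picked subkernel's *name*,
        # rather than an integer index as before
        self.multi_kernel_to_choice: dict[str, str] = {}

    def record_choice(self, multi_kernel_name: str, picked_kernel_name: str) -> None:
        self.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name

    def lookup_choice(self, multi_kernel_name: str) -> str:
        # mirrors the tightened assert in the diff: a lookup must never miss
        assert multi_kernel_name in self.multi_kernel_to_choice
        return self.multi_kernel_to_choice[multi_kernel_name]


registry = ChoiceRegistry()
registry.record_choice("multi_kernel_0", "triton_poi_fused_add_0")
assert registry.lookup_choice("multi_kernel_0") == "triton_poi_fused_add_0"
```

Recording a name rather than an index (as the last hunk does via `inductor_meta.get("kernel_name")`) lets the wrapper call the picked kernel directly, without re-deriving it from the kernel list in a second pass.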
@@ -115,7 +115,7 @@ class MultiKernelState:
         multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
         self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name
 
-        if V.graph.cpp_wrapper:
+        if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
             # we should not generate any python code for multi-kernel during
             # the second pass of cpp-wrapper.
             return multi_kernel_name
@@ -131,9 +131,11 @@ class MultiKernelState:
         buf.writeline("])")
 
         wrapper = V.graph.wrapper_code
-        wrapper.header.splice(buf)
+        if config.triton.autotune_at_compile_time:
+            wrapper.kernel_autotune_defs.splice(buf)
+            wrapper.src_to_kernel["\n".join(kernel_names)] = multi_kernel_name
+        else:
+            wrapper.header.splice(buf)
 
         return multi_kernel_name
 
@@ -218,11 +220,10 @@ class MultiKernel:
 
         grid: List[Any] = []
 
-        if V.graph.cpp_wrapper:
+        if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
             # for the second pass of cpp-wrapper codegen, we should call
             # the fast kernel directly
-            picked_kernel = MultiKernelCall.lookup_choice(kernel_name)
-            kernel_name = self.kernels[picked_kernel].kernel_name
+            kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
 
         # numels for all subkernels should be the same. Use kernels[0] here
         self.kernels[0].add_numel_to_call_args_and_grid(
@@ -381,10 +382,9 @@ class MultiKernelCall:
     # path for the cache file. Also reading the cache file need do some IO
     # which can be slower.
     @staticmethod
-    def record_choice(multi_kernel_name, choice):
+    def record_choice(multi_kernel_name: str, picked_kernel_name: str):
         """
-        Record the multi-kernel choice for cpp-wrapper first pass codegen
-        for the second pass.
+        Record the multi-kernel choice for cpp-wrapper after autotuning
 
         We should do nothing if this function is not called during codegen.
         """
@@ -396,12 +396,15 @@ class MultiKernelCall:
         if not V.graph.record_multi_kernel_choice:
             return
 
-        V.graph.multi_kernel_to_choice[multi_kernel_name] = choice
+        V.graph.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name
 
     @staticmethod
-    def lookup_choice(multi_kernel_name):
+    def lookup_choice(multi_kernel_name: str) -> str:
         # this should always been done during cpp-wrapper codegen
-        assert V.graph.record_multi_kernel_choice
+        assert (
+            V.graph.record_multi_kernel_choice
+            and multi_kernel_name in V.graph.multi_kernel_to_choice
+        )
         # there should be no miss
         return V.graph.multi_kernel_to_choice[multi_kernel_name]
 
@@ -426,7 +429,11 @@ class MultiKernelCall:
 
         if not self._recorded:
             self._recorded = True
-            self.record_choice(self.multi_kernel_name, self.picked_kernel)
+            picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get(
+                "kernel_name"
+            )
+            assert picked_kernel_name is not None
+            self.record_choice(self.multi_kernel_name, picked_kernel_name)
         self.run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
         self.run(*args, **kwargs)
 
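The recurring edit across these hunks is the extra `and not config.triton.autotune_at_compile_time` guard: the "second pass" shortcuts only apply to the classic two-pass cpp-wrapper flow, while the one-pass flow resolves the choice through the recorded kernel name. A toy sketch of that gating (assumed, illustrative function; not Inductor's real codegen entry point):

```python
from typing import Callable

# Toy illustration of the guard added throughout the diff (illustrative
# function name and signature; not Inductor's real codegen entry point).
def emit_kernel_call(
    cpp_wrapper: bool,
    autotune_at_compile_time: bool,
    kernel_name: str,
    lookup_choice: Callable[[str], str],
) -> str:
    if cpp_wrapper and not autotune_at_compile_time:
        # two-pass flow: the second pass already knows the winner,
        # so emit a direct call to the picked fast kernel
        kernel_name = lookup_choice(kernel_name)
    # one-pass flow keeps the multi-kernel name here; the recorded choice
    # is applied where the kernel is actually materialized
    return f"{kernel_name}(args)"


# Example: in the two-pass flow the multi-kernel name is swapped for the winner.
choices = {"multi_kernel_0": "triton_per_fused_sum_1"}
print(emit_kernel_call(True, False, "multi_kernel_0", choices.__getitem__))
# -> triton_per_fused_sum_1(args)
```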