[inductor][refactor] Unify the use of generate_kernel_call (#128467)

Summary: Refactor TritonTemplateKernel.call_kernel and ForeachKernel.call_kernel to use wrapper.generate_kernel_call to generate kernel calls instead of composing the kernel-call string by hand. This consolidates the entry point of generate_kernel_call and simplifies later changes in this PR stack.
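
The shape of the change is roughly the following. This is a minimal sketch using the argument order visible in the diff below; the method bodies, and the grid/stream placeholders, are illustrative, not the exact code from this PR.

# Before: a call_kernel implementation composed the launch string by hand
# (illustrative; grid and stream stand in for values the kernel class
# already computes).
def call_kernel(self, kernel_name):
    _, call_args, _, arg_types = self.args.python_argdefs()
    V.graph.wrapper_code.writeline(
        f"{kernel_name}.run({', '.join(map(str, call_args))}, grid={grid}, stream={stream})"
    )

# After: route the launch through the single wrapper entry point, matching
# the generate_kernel_call signature used in the diff below.
def call_kernel(self, kernel_name):
    _, call_args, _, arg_types = self.args.python_argdefs()
    V.graph.wrapper_code.generate_kernel_call(
        kernel_name,
        call_args,
        grid,
        V.graph.scheduler.get_current_device_or_throw().index,
        arg_types=arg_types,
    )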

Differential Revision: [D58733631](https://our.internmc.facebook.com/intern/diff/D58733631)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128467
Approved by: https://github.com/shunting314
Author: Bin Bao
Date: 2024-06-18 08:30:26 -07:00
Committed by: PyTorch MergeBot
Parent: 3a185778ed
Commit: ba92f5277f

7 changed files with 56 additions and 78 deletions

@@ -198,11 +198,12 @@ class MultiKernel:
         for the multi-kernel.
         """
         assert kernel_name == self.kernel_name
-        call_args_list, arg_types = zip(
-            *[kernel.get_call_args() for kernel in self.kernels]
-        )
-        call_args_list = list(call_args_list)
-        arg_types_list = list(arg_types)
+        call_args_list = []
+        arg_types_list = []
+        for kernel in self.kernels:
+            _, call_args, _, arg_types = kernel.args.python_argdefs()
+            call_args_list.append(call_args)
+            arg_types_list.append(arg_types)
         all_call_args, arg_types = get_all_call_args(call_args_list, arg_types_list)
         grid: List[Any] = []
@@ -223,12 +224,10 @@ class MultiKernel:
        )
        grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
        current_device = V.graph.scheduler.get_current_device_or_throw()
        V.graph.wrapper_code.generate_kernel_call(
            kernel_name,
            final_call_args,
            grid,
            current_device.index,
            arg_types=arg_types,
        )
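
Why a single entry point simplifies later changes in the stack can be sketched as follows. This is an assumed illustration of the dispatch pattern, not PyTorch's actual WrapperCodeGen implementation: every launch flows through one method that wrapper subclasses, such as the C++-emitting wrapper, can override, instead of each kernel class formatting its own launch string.

class WrapperCodeGen:
    # Single entry point for emitting a kernel launch (illustrative body).
    def generate_kernel_call(
        self, kernel_name, call_args, grid, device_index, *, arg_types=None
    ):
        args = ", ".join(map(str, call_args))
        self.writeline(
            f"{kernel_name}.run({args}, grid={grid}, stream=stream{device_index})"
        )

class CppWrapperCodeGen(WrapperCodeGen):
    # A C++-emitting wrapper overrides the same entry point; arg_types lets
    # it cast arguments for the generated launcher (illustrative body).
    def generate_kernel_call(
        self, kernel_name, call_args, grid, device_index, *, arg_types=None
    ):
        ...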