mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor][refactor] Unify the use of generate_kernel_call (#128467)
Summary: Refactor TritonTemplateKernel.call_kernel and ForeachKernel.call_kernel to use wrapper.generate_kernel_call to generate kernel calls instead of explicitly composing the kernel call string. This consolidates the entry point of generate_kernel_call and similifies later changes in this PR stack. Differential Revision: [D58733631](https://our.internmc.facebook.com/intern/diff/D58733631) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128467 Approved by: https://github.com/shunting314
This commit is contained in:
committed by
PyTorch MergeBot
parent
3a185778ed
commit
ba92f5277f
@ -198,11 +198,12 @@ class MultiKernel:
|
||||
for the multi-kernel.
|
||||
"""
|
||||
assert kernel_name == self.kernel_name
|
||||
call_args_list, arg_types = zip(
|
||||
*[kernel.get_call_args() for kernel in self.kernels]
|
||||
)
|
||||
call_args_list = list(call_args_list)
|
||||
arg_types_list = list(arg_types)
|
||||
call_args_list = []
|
||||
arg_types_list = []
|
||||
for kernel in self.kernels:
|
||||
_, call_args, _, arg_types = kernel.args.python_argdefs()
|
||||
call_args_list.append(call_args)
|
||||
arg_types_list.append(arg_types)
|
||||
|
||||
all_call_args, arg_types = get_all_call_args(call_args_list, arg_types_list)
|
||||
grid: List[Any] = []
|
||||
@ -223,12 +224,10 @@ class MultiKernel:
|
||||
)
|
||||
|
||||
grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
|
||||
current_device = V.graph.scheduler.get_current_device_or_throw()
|
||||
V.graph.wrapper_code.generate_kernel_call(
|
||||
kernel_name,
|
||||
final_call_args,
|
||||
grid,
|
||||
current_device.index,
|
||||
arg_types=arg_types,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user