[Inductor] Add triton.autotune support for user defined triton kernels with complex grids (#112290)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112290
Approved by: https://github.com/jansel
This commit is contained in:
Oguz Ulgen
2023-10-28 23:19:05 -07:00
committed by PyTorch MergeBot
parent 5a1a9dc354
commit 1250032c2e
6 changed files with 176 additions and 36 deletions

View File

@ -689,31 +689,59 @@ class TritonKernelVariable(VariableTracker):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from triton.runtime.autotuner import Autotuner
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import BaseListVariable
grid = self.grid
if grid is None:
if self.grid is None:
raise Unsupported("Triton kernels should always be called with a grid")
# Both the grid's meta argument and the kernel itself need the args and
# kwargs combined and normalized into a single dict
normalized_args = {**dict(zip(self.kernel.arg_names, args)), **kwargs}
meta = ConstDictVariable(normalized_args, dict)
# If the grid is a function, then lets execute it and convert it to
# a list
if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
# Populate the special "meta" argument to call the grid function
grid = grid.call_function(tx, [meta], {})
configs = (
[config.kwargs for config in self.kernel.configs]
if isinstance(self.kernel, Autotuner)
else [{}]
)
grids = []
for config_args in configs:
# If the grid is a function, then let's execute it and convert it to
# a list
grid = self.grid
if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
# Populate the special "meta" argument to call the grid function
config_args = {
k: ConstantVariable.create(v) for k, v in config_args.items()
}
meta = ConstDictVariable({**normalized_args, **config_args}, dict)
grid = grid.call_function(tx, [meta], {})
# Now, the grid must be a list either originally or through above
# modification
if isinstance(grid, BaseListVariable):
grid = grid.as_proxy()
else:
unimplemented(f"grid for the triton kernel is {type(grid)}")
# Now, the grid must be a list either originally or through above
# modification
if isinstance(grid, BaseListVariable):
grids.append(grid.as_proxy())
else:
unimplemented(f"grid for the triton kernel is {type(grid)}")
for i in range(len(grids)):
if not isinstance(grids[i], tuple):
raise Unsupported("Only tuple grids are supported")
# Inductor expects every grid to be a 3-tuple, so let's pad shorter grids with 1s
if len(grids[i]) == 1:
grids[i] = (grids[i][0], 1, 1)
elif len(grids[i]) == 2:
grids[i] = (grids[i][0], grids[i][1], 1)
elif len(grids[i]) > 3:
raise Unsupported("Grid can have at most rank 3")
assert len(grids) != 0
if len(set(grids)) == 1:
# If there's only one unique grid, let's simplify to a single entry
grids = [grids[0]]
from torch._higher_order_ops.triton_kernel_wrap import (
triton_kernel_wrapper_mutation,
@ -722,13 +750,14 @@ class TritonKernelVariable(VariableTracker):
# Combine args and kwargs and pass them as a dict so that if a user-defined
# triton kernel uses variables named 'grid' or 'kernel', they do not conflict
# with the parameters of the wrapper function
meta = ConstDictVariable(normalized_args, dict)
tx.output.create_proxy(
"call_function",
triton_kernel_wrapper_mutation,
(),
{
"kernel_idx": self.kernel_idx,
"grid": grid,
"grid": grids,
"kwargs": meta.as_proxy(),
},
)