[inductor] Multi-kernel + cooperative reductions (#138893)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138893
Approved by: https://github.com/shunting314
ghstack dependencies: #138533
Jason Ansel
2024-10-28 18:23:38 -07:00
committed by PyTorch MergeBot
parent 77b0ae832d
commit a762dc0357
7 changed files with 200 additions and 62 deletions

test/inductor/test_cooperative_reductions.py (View File)

@@ -31,9 +31,10 @@ class CooperativeReductionTests(TestCase):
         result, (source_code,) = run_and_get_code(fn, *args)
         self.assertEqual(result, expected)
         self.assertIn("@triton_heuristics.cooperative_reduction", source_code)
-        self.assertEqual(
-            torch._inductor.metrics.generated_kernel_count, expect_kernel_count
-        )
+        if "async_compile.multi_kernel" not in source_code:
+            self.assertEqual(
+                torch._inductor.metrics.generated_kernel_count, expect_kernel_count
+            )
         return source_code
 
     @parametrize(
@@ -75,6 +76,8 @@ class CooperativeReductionTests(TestCase):
         args = [torch.randn(1024, device="cuda") for _ in range(2)]
         source_code = self.run_and_check(fn, args)
+        if "async_compile.multi_kernel" in source_code:
+            return
         before, after = source_code.split("triton_helpers.x_grid_barrier")
         self.assertEqual(before.count("if rsplit_id == ("), 0)
         self.assertEqual(after.count("if rsplit_id == ("), 6)
@@ -98,6 +101,8 @@ class CooperativeReductionTests(TestCase):
         args = [torch.randn(4, 100000, device="cuda")]
         source_code = self.run_and_check(fn, args)
+        if "async_compile.multi_kernel" in source_code:
+            return
         self.assertEqual(source_code.count("triton_helpers.x_grid_barrier"), 16)
         self.assertEqual(source_code.count("empty_strided_cuda"), 8)
@@ -120,6 +125,11 @@ class NoPersistCooperativeReductionTests(CooperativeReductionTests):
     pass
 
 
+@config.patch("triton.multi_kernel", int(not config.triton.multi_kernel))
+class MultiKernelCooperativeReductionTests(CooperativeReductionTests):
+    pass
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
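Note on the test changes above: with multi-kernel enabled, Inductor emits an
async_compile.multi_kernel(...) wrapper that compiles several candidate
sub-kernels, so the exact generated-kernel count and the barrier/split
assertions no longer apply; the tests detect this mode by scanning the
generated source. A minimal sketch of that detection pattern
(run_and_get_code is the real helper from torch._inductor.utils; fn and
uses_multi_kernel are made up for illustration):

    import torch
    from torch._inductor import config
    from torch._inductor.utils import run_and_get_code

    def uses_multi_kernel() -> bool:
        with config.patch("triton.multi_kernel", 1):  # enable multi-kernel codegen
            @torch.compile
            def fn(x):
                return x.sum(dim=-1)  # a reduction, the case multi-kernel targets

            _, (source_code,) = run_and_get_code(fn, torch.randn(4, 1024, device="cuda"))
        # multi-kernel wrappers appear as async_compile.multi_kernel(...) calls
        return "async_compile.multi_kernel" in source_code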

torch/_inductor/codegen/common.py (View File)

@@ -114,15 +114,11 @@ class WorkspaceArg:
     @staticmethod
     def maximum(a, b):
         assert (
-            a.zero_mode == b.zero_mode
-            and a.dtype == b.dtype
-            and a.device == b.device
-            and a.inner_name == b.inner_name
-            and a.outer_name == b.outer_name
+            a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name
         )
         return WorkspaceArg(
             count=sympy.Max(a.count, b.count),
-            zero_mode=a.zero_mode,
+            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
             dtype=a.dtype,
             device=a.device,
             inner_name=a.inner_name,
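The assert above is relaxed on purpose: maximum() may now merge workspace args
whose zero_mode differs, delegating the reconciliation to
WorkspaceZeroMode.combine. A plausible sketch of such a combine, assuming an
enum where UNINITIALIZED acts as the identity (the members and semantics below
are illustrative stand-ins, not the exact torch._inductor definitions):

    import enum

    class WorkspaceZeroMode(enum.Enum):
        UNINITIALIZED = 0    # contents never read before being written
        ZERO_ON_CALL = 1     # zeroed before each kernel launch
        ZERO_PER_GRAPH = 2   # zeroed once; kernels restore it after use

        @staticmethod
        def combine(a, b):
            # matching modes pass through; UNINITIALIZED defers to the other side
            if a == b or b == WorkspaceZeroMode.UNINITIALIZED:
                return a
            if a == WorkspaceZeroMode.UNINITIALIZED:
                return b
            raise NotImplementedError(f"combine({a!r}, {b!r})")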

torch/_inductor/codegen/multi_kernel.py (View File)

@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import functools
 import logging
 import os
 import pathlib
@@ -8,11 +9,16 @@ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
 from torch.utils._ordered_set import OrderedSet
 
 from .. import config
-from ..codecache import get_path, TritonFuture
+from ..codecache import code_hash, get_path, TritonFuture
 from ..runtime.benchmarking import benchmarker
+from ..runtime.triton_heuristics import (
+    cooperative_reduction_grid,
+    grid,
+    maybe_cooperative_reduction_grid,
+)
 from ..utils import cache_on_self, IndentedBuffer
 from ..virtualized import V
-from .common import TensorArg
+from .common import TensorArg, WorkspaceArg
 
 log = logging.getLogger(__name__)
@@ -114,6 +120,7 @@ class MultiKernelState:
             return multi_kernel_name
 
         buf = IndentedBuffer()
+        buf.writeline("")
         buf.writeline(
             f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
         )
@@ -155,6 +162,46 @@ class MultiKernel:
         # attribute to decide if it's a non-null kernel.
         self.args = object()
 
+    @staticmethod
+    def _merge_workspace_args(left: List[WorkspaceArg], right: List[WorkspaceArg]):
+        if left == right:
+            return left
+        result = {x.inner_name: x for x in left}
+        for arg in right:
+            if arg.inner_name in result:
+                result[arg.inner_name] = WorkspaceArg.maximum(
+                    result[arg.inner_name], arg
+                )
+            else:
+                result[arg.inner_name] = arg
+        return [*result.values()]
+
+    @staticmethod
+    def merge_workspaces_inplace(kernels):
+        if len(kernels) < 2:
+            return
+        # All kernels must share the same workspace
+        workspace_args = functools.reduce(
+            MultiKernel._merge_workspace_args,
+            [kernel.args.workspace_args for kernel in kernels],
+        )
+        for kernel in kernels:
+            kernel.args.workspace_args = workspace_args
+        return workspace_args
+
+    def get_grid_fn(self):
+        fns = {kernel._get_grid_fn() for kernel in self.kernels}
+        if len(fns) == 1:
+            return next(iter(fns))
+        elif len(fns) == 2:
+            assert fns == {cooperative_reduction_grid, grid}
+            V.graph.wrapper_code.add_import_once(
+                f"from {maybe_cooperative_reduction_grid.__module__} import maybe_cooperative_reduction_grid"
+            )
+            return maybe_cooperative_reduction_grid
+        else:
+            raise NotImplementedError(fns)
+
     def call_kernel(self, kernel_name):
         """
         Collect the union of arguments from all subkernels as the arguments
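The merge above keys workspace args by inner_name, taking WorkspaceArg.maximum
on collisions so every sub-kernel sees one shared allocation large enough for
all of them. A self-contained illustration of that keying strategy using a
simplified stand-in type (WS is hypothetical; the real code uses WorkspaceArg
and sympy.Max):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class WS:  # stand-in for WorkspaceArg, keeping only the merged fields
        inner_name: str
        count: int

    def merge(left: list, right: list) -> list:
        result = {x.inner_name: x for x in left}
        for arg in right:
            if arg.inner_name in result:
                prev = result[arg.inner_name]
                # same reconciliation as WorkspaceArg.maximum: take the max count
                result[arg.inner_name] = WS(arg.inner_name, max(prev.count, arg.count))
            else:
                result[arg.inner_name] = arg
        return [*result.values()]

    assert merge([WS("ws_0", 1024)], [WS("ws_0", 4096), WS("ws_1", 8)]) == [
        WS("ws_0", 4096),
        WS("ws_1", 8),
    ]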
@@ -165,7 +212,7 @@
         _, call_args, _, arg_types = self.kernels[0].args.python_argdefs()
         for kernel in self.kernels[1:]:
             _, other_call_args, _, other_arg_types = kernel.args.python_argdefs()
-            assert call_args == other_call_args
+            assert call_args == other_call_args, (call_args, other_call_args)
             assert arg_types == other_arg_types
 
         grid: List[Any] = []
@@ -181,14 +228,24 @@
             kernel_name, call_args, arg_types, grid
         )
 
-        grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
+        for ws in self.kernels[0].args.workspace_args:
+            V.graph.wrapper_code.generate_workspace_allocation(ws)
+
+        grid_fn = self.get_grid_fn()
+        grid = V.graph.wrapper_code.generate_default_grid(
+            kernel_name, grid, grid_callable=grid_fn
+        )
         V.graph.wrapper_code.generate_kernel_call(
             kernel_name,
             call_args,
             grid,
             arg_types=arg_types,
+            grid_fn=grid_fn.__name__,
         )
 
+        for ws in reversed(self.kernels[0].args.workspace_args):
+            V.graph.wrapper_code.generate_workspace_deallocation(ws)
+
     def codegen_nan_check(self):
         wrapper = V.graph.wrapper_code
         seen = set()
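The new workspace handling brackets the launch: allocations are emitted before
the kernel call and released afterwards via reversed(...), so teardown runs in
LIFO order. An illustrative shape of the wrapper code this produces (buffer
names and the allocation call are made up, not verbatim Inductor output):

    # workspace_0 = <allocate + optionally zero>   # generate_workspace_allocation
    # workspace_1 = <allocate + optionally zero>
    # multi_kernel_0.run(..., workspace_0, workspace_1,
    #                    grid=maybe_cooperative_reduction_grid(xnumel))
    # del workspace_1                               # reversed(): LIFO teardown
    # del workspace_0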
@@ -252,7 +309,8 @@ class MultiKernelCall:
         self._recorded = False
 
     def cache_file_path(self):
-        _, _, path = get_path(self.kernels[0].fn.cache_key, "picked_kernel")
+        key = code_hash(",".join([k.fn.cache_key for k in self.kernels]))
+        _, _, path = get_path(key, "picked_kernel")
         return pathlib.Path(path)
 
     def load_cache(self):
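Keying the picked-kernel cache off every sub-kernel matters now that a
multi-kernel can hold more than two candidates: two different candidate sets
that happen to share kernels[0] must not collide on the same cache file. A
short sketch of the combined key (code_hash and get_path are the real
torch._inductor.codecache helpers; the cache_key strings are invented):

    cache_keys = ["key_coop_persistent", "key_coop", "key_plain"]
    key = code_hash(",".join(cache_keys))  # one hash covering the whole set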
@@ -359,22 +417,9 @@
                 k0.inductor_meta.get("reduction_hint"),
                 timings,
             )
-
-            def get_kernel_path(k):
-                return k.fn.fn.__code__.co_filename
-
             get_metric_table("persistent_red_perf").add_row(
-                lambda: {
-                    "kernel1_name": get_kernel_path(self.kernels[0]),
-                    "kernel2_name": get_kernel_path(self.kernels[1]),
-                    "kernel1_latency": timings[0],
-                    "kernel2_latency": timings[1],
-                    "size_hints": k0.size_hints,
-                    "reduction_hint": k0.inductor_meta.get("reduction_hint"),
-                    "speedup": timings[1] / timings[0],
-                }
+                functools.partial(self._metrics_table_row, timings)
             )
 
             if not self.disable_cache:
                 self.store_cache()
@@ -383,3 +428,23 @@
             self.record_choice(self.multi_kernel_name, self.picked_kernel)
         self.run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
         self.run(*args, **kwargs)
+
+    def _metrics_table_row(self, timings):
+        def get_kernel_path(k):
+            return k.fn.fn.__code__.co_filename
+
+        k0 = self.kernels[0]
+        row = {
+            "size_hints": k0.size_hints,
+            "reduction_hint": k0.inductor_meta.get("reduction_hint"),
+        }
+        max_kernels = 4
+        assert len(timings) <= max_kernels
+        for i in range(max_kernels):
+            if i < len(self.kernels):
+                row[f"kernel{i}_path"] = get_kernel_path(self.kernels[i])
+                row[f"kernel{i}_latency"] = timings[i]
+            else:
+                row[f"kernel{i}_path"] = ""
+                row[f"kernel{i}_latency"] = ""
+        return row
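With the fixed-width schema, a two-kernel benchmark fills the first two column
pairs and pads the rest, e.g. (paths and timings invented for illustration):

    # {"size_hints": [4096], "reduction_hint": "INNER",
    #  "kernel0_path": "/tmp/.../triton_per_0.py", "kernel0_latency": 0.081,
    #  "kernel1_path": "/tmp/.../triton_red_1.py", "kernel1_latency": 0.112,
    #  "kernel2_path": "", "kernel2_latency": "",
    #  "kernel3_path": "", "kernel3_latency": ""}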

torch/_inductor/codegen/simd.py (View File)

@@ -321,6 +321,7 @@ class SIMDKernel(Kernel):
     sexpr = pexpr
     kexpr: Callable[[sympy.Expr], str]
     allow_block_ptr = False
+    kernel_name: str
 
     def __init__(
         self,
@@ -1366,7 +1367,7 @@
         # ops.sort only works with persistent reduction, and is not bandwidth bound anyway
         # so taking the hit of non-coalesced loads is okay
-        if has_sort := schedule_contains_op(node_schedule, "sort"):
+        if schedule_contains_op(node_schedule, "sort"):
             kernel_kwargs["override_persistent_reduction"] = True
 
         kernel = kernel_type(
@@ -1375,35 +1376,26 @@
         )
         kernel.buf_accesses = buf_accesses
 
-        kernel2: Optional[SIMDKernel] = None
-        if kernel.persistent_reduction and config.triton.multi_kernel and not has_sort:
-            kernel2 = self.kernel_type(
-                *kernel_args,
-                **kernel_kwargs,
-                override_persistent_reduction=False,
-            )
-            self.codegen_node_schedule_with_kernel(node_schedule, kernel2)
-            with V.set_kernel_handler(kernel2):
-                src_code2 = kernel2.codegen_kernel()
-                kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel)
-                kernel2.kernel_name = kernel_name2
-                kernel2.code_hash = code_hash(src_code2)
+        kernels = self.add_multi_kernel_choices(
+            kernel, kernel_args, kernel_kwargs, node_schedule
+        )
+        for kernel in kernels:
+            self.codegen_node_schedule_with_kernel(node_schedule, kernel)
+        MultiKernel.merge_workspaces_inplace(kernels)
+        for kernel in kernels:
+            with V.set_kernel_handler(kernel):
+                src_code = kernel.codegen_kernel()
+                kernel_name = self.define_kernel(src_code, node_schedule, kernel)
+                log.debug("Generating kernel code with kernel_name: %s", kernel_name)
+                kernel.kernel_name = kernel_name
+                kernel.code_hash = code_hash(src_code)
+        del kernel
 
-            # Keep buffers needed by the non-persistent reduction so both
-            # kernels have the same arguments
-            kernel.must_keep_buffers = set(kernel2.must_keep_buffers)
-
-        self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-        with V.set_kernel_handler(kernel):
-            src_code = kernel.codegen_kernel()
-            kernel_name = self.define_kernel(src_code, node_schedule, kernel)
-            log.debug("Generating kernel code with kernel_name: %s", kernel_name)
-            kernel.kernel_name = kernel_name
-            kernel.code_hash = code_hash(src_code)
-
-        final_kernel = MultiKernel([kernel, kernel2]) if kernel2 is not None else kernel
+        final_kernel: Union[SIMDKernel, MultiKernel]
+        if len(kernels) > 1:
+            final_kernel = MultiKernel(kernels)
+        else:
+            (final_kernel,) = kernels
 
         with V.set_kernel_handler(final_kernel):
             for node in node_schedule:
@@ -1416,7 +1408,7 @@
             if config.nan_asserts:
                 final_kernel.codegen_nan_check()
             if config.warn_mix_layout:
-                final_kernel.warn_mix_layout(kernel_name)
+                final_kernel.warn_mix_layout(kernels[0].kernel_name)
 
         V.graph.removed_buffers |= final_kernel.removed_buffers
         V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove
@@ -1427,7 +1419,7 @@
         ):
             # Not every node in the schedule will actually be live on output;
             # we can't check dead buffers.
-            live_outs = kernel.args.live_output_buffers()
+            live_outs = kernels[0].args.live_output_buffers()
             for node in node_schedule:
                 if not isinstance(node, scheduler.BaseSchedulerNode):
                     continue
@@ -1444,6 +1436,11 @@
         self.scheduler.free_buffers()
 
+    def add_multi_kernel_choices(
+        self, kernel, kernel_args, kernel_kwargs, node_schedule
+    ) -> List[SIMDKernel]:
+        return [kernel]
+
     def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
         def current_reduction_nodes(nodes):
             return itertools.takewhile(lambda n: n is not DisableReduction, nodes)
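This base-class hook is the heart of the refactor: codegen_node_schedule now
asks the backend for a list of candidate kernels, generates code for each, and
wraps them in a MultiKernel only when there is more than one. A minimal sketch
of how a backend opts in (MyScheduling and wants_variant are hypothetical; the
real override is TritonScheduling.add_multi_kernel_choices below):

    class MyScheduling(SIMDScheduling):
        def add_multi_kernel_choices(
            self, kernel, kernel_args, kernel_kwargs, node_schedule
        ) -> List[SIMDKernel]:
            kernels = [kernel]
            if self.wants_variant(kernel):  # hypothetical backend-specific test
                # same node schedule, different codegen strategy
                kernels.append(self.kernel_type(*kernel_args, **kernel_kwargs))
            return kernels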

torch/_inductor/codegen/triton.py (View File)

@@ -3613,6 +3613,60 @@ class TritonScheduling(SIMDScheduling):
             store_cache()
             return ms, mod.__file__
 
+    def add_multi_kernel_choices(
+        self,
+        kernel: SIMDKernel,
+        kernel_args: List[Any],
+        kernel_kwargs: Dict[str, Any],
+        node_schedule: List[BaseSchedulerNode],
+    ) -> List[SIMDKernel]:
+        kernels: List[SIMDKernel] = [kernel]
+        if not config.triton.multi_kernel:
+            return kernels
+        optional_persistent = kernel.persistent_reduction and not kernel_kwargs.get(
+            "override_persistent_reduction"
+        )
+        optional_cooperative = kernel.cooperative_reduction and not kernel_kwargs.get(
+            "override_cooperative_reduction"
+        )
+        if optional_persistent:
+            kernels.append(
+                self.kernel_type(
+                    *kernel_args,
+                    **kernel_kwargs,
+                    override_persistent_reduction=False,
+                )
+            )
+        if optional_cooperative:
+            _, rnumel = kernel.numels
+            # for larger sizes non-cooperative gets very slow
+            if V.graph.sizevars.statically_known_leq(rnumel, 65536):
+                kernels.append(
+                    other := self.kernel_type(
+                        *kernel_args,
+                        **kernel_kwargs,
+                        override_cooperative_reduction=False,
+                    )
+                )
+                if optional_persistent and other.persistent_reduction:
+                    kernels.append(
+                        self.kernel_type(
+                            *kernel_args,
+                            **kernel_kwargs,
+                            override_cooperative_reduction=False,
+                            override_persistent_reduction=False,
+                        )
+                    )
+        if len(kernels) > 1:
+            for kernel2 in kernels[1:]:
+                # Keep buffers needed by the non-persistent reduction so both kernels have the same arguments
+                kernel2.must_keep_buffers = kernel.must_keep_buffers
+
+            # persistent kernels must be generated last so must_keep_buffers works right
+            kernels.sort(key=lambda k: k.persistent_reduction)
+
+        return kernels
+
     def benchmark_combo_kernel(self, node_list):
         def cache_file_path():
             assert mod.__file__ is not None
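Putting the overrides together: for a reduction that starts out both
cooperative and persistent (and whose rnumel is statically known to be at most
65536), the choice list can grow to four candidates, which is why the metrics
table below reserves kernel0..kernel3 columns:

    # cooperative=True,  persistent=True   -> the original kernel
    # cooperative=True,  persistent=False  -> override_persistent_reduction=False
    # cooperative=False, persistent=True   -> override_cooperative_reduction=False
    # cooperative=False, persistent=False  -> both overridden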

torch/_inductor/metrics.py (View File)

@@ -212,13 +212,16 @@ MetricTable.register_table(
 MetricTable.register_table(
     "persistent_red_perf",
     [
-        "kernel1_name",
-        "kernel2_name",
+        "kernel0_path",
+        "kernel1_path",
+        "kernel2_path",
+        "kernel3_path",
+        "kernel0_latency",
         "kernel1_latency",
         "kernel2_latency",
+        "kernel3_latency",
         "size_hints",
         "reduction_hint",
-        "speedup",
     ],
 )

torch/_inductor/runtime/triton_heuristics.py (View File)

@@ -1889,6 +1889,19 @@ def cooperative_reduction_grid(xnumel):
     return grid_fn
 
 
+def maybe_cooperative_reduction_grid(xnumel):
+    def grid_fn(meta):
+        if "RSPLIT" in meta:
+            return coop_grid(meta)
+        return normal_grid(meta)
+
+    coop_grid = cooperative_reduction_grid(xnumel)
+    normal_grid = grid(xnumel)
+    grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
+    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010
+    return grid_fn
+
+
 def split_scan_grid(xnumel, rnumel):
     def grid_fn(meta):
         assert meta.get("XBLOCK", 1) == 1
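maybe_cooperative_reduction_grid lets one launch site serve both sub-kernel
flavors: at launch time Triton passes the chosen config's meta dict, and the
presence of the RSPLIT constant marks the cooperative variant. A usage sketch
(the meta dicts are invented; real ones come from the autotuner):

    g = maybe_cooperative_reduction_grid(4096)
    g({"XBLOCK": 1024})                # no RSPLIT -> ordinary grid(4096) rule
    g({"XBLOCK": 1024, "RSPLIT": 16})  # RSPLIT present -> cooperative grid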