Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[inductor] Multi-kernel + cooperative reductions (#138893)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138893
Approved by: https://github.com/shunting314
ghstack dependencies: #138533

Commit a762dc0357 (parent 77b0ae832d), committed by PyTorch MergeBot.
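This PR teaches inductor's multi-kernel mechanism (runtime benchmarking between generated kernel variants) to also cover cooperative reductions, so one reduction can be autotuned across persistent, looped, and cooperative forms. A minimal sketch of exercising the combination; the flag names are assumptions based on the stack this PR sits on (#138533 and the cooperative-reductions work), not something stated in the diff itself:

import torch
import torch._inductor.config as inductor_config

# triton.multi_kernel is an integer flag: nonzero makes inductor emit several
# reduction variants and pick the fastest at runtime
inductor_config.triton.multi_kernel = 1
# assumed flag name from the cooperative-reductions stack this PR depends on
inductor_config.triton.cooperative_reductions = True

@torch.compile
def fn(x):
    return torch.softmax(x, dim=-1)

if torch.cuda.is_available():
    fn(torch.randn(4, 100000, device="cuda"))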
@@ -31,9 +31,10 @@ class CooperativeReductionTests(TestCase):
         result, (source_code,) = run_and_get_code(fn, *args)
         self.assertEqual(result, expected)
         self.assertIn("@triton_heuristics.cooperative_reduction", source_code)
-        self.assertEqual(
-            torch._inductor.metrics.generated_kernel_count, expect_kernel_count
-        )
+        if "async_compile.multi_kernel" not in source_code:
+            self.assertEqual(
+                torch._inductor.metrics.generated_kernel_count, expect_kernel_count
+            )
         return source_code

     @parametrize(
@@ -75,6 +76,8 @@ class CooperativeReductionTests(TestCase):

         args = [torch.randn(1024, device="cuda") for _ in range(2)]
         source_code = self.run_and_check(fn, args)
+        if "async_compile.multi_kernel" in source_code:
+            return
         before, after = source_code.split("triton_helpers.x_grid_barrier")
         self.assertEqual(before.count("if rsplit_id == ("), 0)
         self.assertEqual(after.count("if rsplit_id == ("), 6)
@@ -98,6 +101,8 @@ class CooperativeReductionTests(TestCase):

         args = [torch.randn(4, 100000, device="cuda")]
         source_code = self.run_and_check(fn, args)
+        if "async_compile.multi_kernel" in source_code:
+            return
         self.assertEqual(source_code.count("triton_helpers.x_grid_barrier"), 16)
         self.assertEqual(source_code.count("empty_strided_cuda"), 8)

@@ -120,6 +125,11 @@ class NoPersistCooperativeReductionTests(CooperativeReductionTests):
     pass


+@config.patch("triton.multi_kernel", int(not config.triton.multi_kernel))
+class MultiKernelCooperativeReductionTests(CooperativeReductionTests):
+    pass
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

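When multi-kernel is active, the generated wrapper contains an `async_compile.multi_kernel(...)` dispatcher and the per-variant kernel count differs, which is why the updated tests return early on that marker. A standalone sketch of the detection pattern, using `run_and_get_code` from `torch._inductor.utils` as the test suite does:

import torch
from torch._inductor.utils import run_and_get_code

def check_cooperative(fn, *args):
    result, (source_code,) = run_and_get_code(torch.compile(fn), *args)
    if "async_compile.multi_kernel" in source_code:
        return result  # several variants emitted; single-kernel checks don't apply
    assert "@triton_heuristics.cooperative_reduction" in source_code
    return result

if torch.cuda.is_available():
    check_cooperative(lambda x: x.sum(-1), torch.randn(4, 100000, device="cuda"))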
@@ -114,15 +114,11 @@ class WorkspaceArg:
     @staticmethod
     def maximum(a, b):
         assert (
-            a.zero_mode == b.zero_mode
-            and a.dtype == b.dtype
-            and a.device == b.device
-            and a.inner_name == b.inner_name
-            and a.outer_name == b.outer_name
+            a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name
         )
         return WorkspaceArg(
             count=sympy.Max(a.count, b.count),
-            zero_mode=a.zero_mode,
+            zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode),
             dtype=a.dtype,
             device=a.device,
             inner_name=a.inner_name,
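The relaxed assert plus `WorkspaceZeroMode.combine` let two variants with different zeroing requirements share one workspace: the merged arg takes the symbolic max of the sizes and the stricter zero mode. A simplified, self-contained sketch of those semantics (`Workspace` and `needs_zero` are stand-ins invented here, not the real `WorkspaceArg` API):

import sympy
from dataclasses import dataclass

@dataclass(frozen=True)
class Workspace:  # invented stand-in for WorkspaceArg
    count: object       # size, possibly a sympy expression
    needs_zero: bool    # stand-in for zero_mode; real code uses WorkspaceZeroMode.combine
    inner_name: str

def maximum(a: Workspace, b: Workspace) -> Workspace:
    assert a.inner_name == b.inner_name
    return Workspace(
        count=sympy.Max(a.count, b.count),        # symbolic max, as in the diff
        needs_zero=a.needs_zero or b.needs_zero,  # stricter zeroing wins
        inner_name=a.inner_name,
    )

u = sympy.Symbol("u", positive=True)
print(maximum(Workspace(1024, False, "ws"), Workspace(u, True, "ws")))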
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import functools
 import logging
 import os
 import pathlib
@@ -8,11 +9,16 @@ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
 from torch.utils._ordered_set import OrderedSet

 from .. import config
-from ..codecache import get_path, TritonFuture
+from ..codecache import code_hash, get_path, TritonFuture
 from ..runtime.benchmarking import benchmarker
+from ..runtime.triton_heuristics import (
+    cooperative_reduction_grid,
+    grid,
+    maybe_cooperative_reduction_grid,
+)
 from ..utils import cache_on_self, IndentedBuffer
 from ..virtualized import V
-from .common import TensorArg
+from .common import TensorArg, WorkspaceArg


 log = logging.getLogger(__name__)
@@ -114,6 +120,7 @@ class MultiKernelState:
             return multi_kernel_name

         buf = IndentedBuffer()
+        buf.writeline("")
         buf.writeline(
             f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
         )
@@ -155,6 +162,46 @@ class MultiKernel:
         # attribute to decide if it's a non-null kernel.
         self.args = object()

+    @staticmethod
+    def _merge_workspace_args(left: List[WorkspaceArg], right: List[WorkspaceArg]):
+        if left == right:
+            return left
+        result = {x.inner_name: x for x in left}
+        for arg in right:
+            if arg.inner_name in result:
+                result[arg.inner_name] = WorkspaceArg.maximum(
+                    result[arg.inner_name], arg
+                )
+            else:
+                result[arg.inner_name] = arg
+        return [*result.values()]
+
+    @staticmethod
+    def merge_workspaces_inplace(kernels):
+        if len(kernels) < 2:
+            return
+        # All kernels must share the same workspace
+        workspace_args = functools.reduce(
+            MultiKernel._merge_workspace_args,
+            [kernel.args.workspace_args for kernel in kernels],
+        )
+        for kernel in kernels:
+            kernel.args.workspace_args = workspace_args
+        return workspace_args
+
+    def get_grid_fn(self):
+        fns = {kernel._get_grid_fn() for kernel in self.kernels}
+        if len(fns) == 1:
+            return next(iter(fns))
+        elif len(fns) == 2:
+            assert fns == {cooperative_reduction_grid, grid}
+            V.graph.wrapper_code.add_import_once(
+                f"from {maybe_cooperative_reduction_grid.__module__} import maybe_cooperative_reduction_grid"
+            )
+            return maybe_cooperative_reduction_grid
+        else:
+            raise NotImplementedError(fns)
+
     def call_kernel(self, kernel_name):
         """
         Collect the union of arguments from all subkernels as the arguments
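`merge_workspaces_inplace` folds every subkernel's workspace list into one shared list via `functools.reduce`, so all variants can be invoked with identical arguments; `get_grid_fn` then picks a single grid callable, falling back to `maybe_cooperative_reduction_grid` when the variants mix cooperative and normal grids. A self-contained sketch of the reduce-based merge, with `(inner_name, nbytes)` tuples standing in for `WorkspaceArg`:

import functools

def merge(left, right):
    # mirrors _merge_workspace_args; colliding names keep the larger size
    if left == right:
        return left
    result = {name: size for name, size in left}
    for name, size in right:
        result[name] = max(result.get(name, 0), size)
    return [*result.items()]

per_kernel = [
    [("ws_0", 1024)],               # persistent variant
    [("ws_0", 4096), ("ws_1", 8)],  # cooperative variant needs more scratch
    [],                             # looped variant needs none
]
shared = functools.reduce(merge, per_kernel)
print(shared)  # [('ws_0', 4096), ('ws_1', 8)] -- one workspace list for all variants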
@@ -165,7 +212,7 @@ class MultiKernel:
         _, call_args, _, arg_types = self.kernels[0].args.python_argdefs()
         for kernel in self.kernels[1:]:
             _, other_call_args, _, other_arg_types = kernel.args.python_argdefs()
-            assert call_args == other_call_args
+            assert call_args == other_call_args, (call_args, other_call_args)
             assert arg_types == other_arg_types

         grid: List[Any] = []
@@ -181,14 +228,24 @@ class MultiKernel:
             kernel_name, call_args, arg_types, grid
         )

-        grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
+        for ws in self.kernels[0].args.workspace_args:
+            V.graph.wrapper_code.generate_workspace_allocation(ws)
+
+        grid_fn = self.get_grid_fn()
+        grid = V.graph.wrapper_code.generate_default_grid(
+            kernel_name, grid, grid_callable=grid_fn
+        )
         V.graph.wrapper_code.generate_kernel_call(
             kernel_name,
             call_args,
             grid,
             arg_types=arg_types,
+            grid_fn=grid_fn.__name__,
         )

+        for ws in reversed(self.kernels[0].args.workspace_args):
+            V.graph.wrapper_code.generate_workspace_deallocation(ws)
+
     def codegen_nan_check(self):
         wrapper = V.graph.wrapper_code
         seen = set()
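In `call_kernel`, workspace allocations are emitted before the launch and deallocations after it in reverse order, so generated scratch buffers nest like a stack around the call. A small runnable sketch of that emission order (the emitted strings are illustrative placeholders, not inductor's exact wrapper output):

workspace_args = ["workspace_0", "workspace_1"]
lines = []
for ws in workspace_args:
    lines.append(f"{ws} = empty_strided_cuda(...)")  # allocate scratch
lines.append("multi_kernel_0.run(args, grid=grid_fn, stream=stream0)")
for ws in reversed(workspace_args):
    lines.append(f"del {ws}")  # free in reverse (stack) order
print("\n".join(lines))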
@@ -252,7 +309,8 @@ class MultiKernelCall:
         self._recorded = False

     def cache_file_path(self):
-        _, _, path = get_path(self.kernels[0].fn.cache_key, "picked_kernel")
+        key = code_hash(",".join([k.fn.cache_key for k in self.kernels]))
+        _, _, path = get_path(key, "picked_kernel")
         return pathlib.Path(path)

     def load_cache(self):
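The picked-kernel cache is now keyed on a hash of every subkernel's cache key instead of only the first one, so two multi-kernels that share a first variant but differ elsewhere can no longer collide. A sketch of the keying; `code_hash` is inductor's content-hash helper in `torch._inductor.codecache` (assumed to accept a string and return a stable hash), and the keys below are invented:

from torch._inductor.codecache import code_hash

# hypothetical per-variant cache keys for a 3-way multi-kernel
cache_keys = ["cqz3k_persistent", "c7ab0_looped", "cf19d_cooperative"]
key = code_hash(",".join(cache_keys))
print(key)  # stable for the same set of variants, distinct otherwise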
@@ -359,22 +417,9 @@ class MultiKernelCall:
                 k0.inductor_meta.get("reduction_hint"),
                 timings,
             )

-        def get_kernel_path(k):
-            return k.fn.fn.__code__.co_filename
-
         get_metric_table("persistent_red_perf").add_row(
-            lambda: {
-                "kernel1_name": get_kernel_path(self.kernels[0]),
-                "kernel2_name": get_kernel_path(self.kernels[1]),
-                "kernel1_latency": timings[0],
-                "kernel2_latency": timings[1],
-                "size_hints": k0.size_hints,
-                "reduction_hint": k0.inductor_meta.get("reduction_hint"),
-                "speedup": timings[1] / timings[0],
-            }
+            functools.partial(self._metrics_table_row, timings)
         )

         if not self.disable_cache:
             self.store_cache()
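`add_row` takes a zero-argument callable and defers building the row until the metric table is actually consulted; `functools.partial` binds `timings` to the bound method exactly as the removed lambda did, while supporting any number of kernels. A minimal illustration:

import functools

def metrics_table_row(timings):
    # trivial stand-in for MultiKernelCall._metrics_table_row
    return {f"kernel{i}_latency": t for i, t in enumerate(timings)}

row_fn = functools.partial(metrics_table_row, [0.12, 0.09])
print(row_fn())  # called with no args, like the lambda it replaces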
@@ -383,3 +428,23 @@ class MultiKernelCall:
         self.record_choice(self.multi_kernel_name, self.picked_kernel)
         self.run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
         self.run(*args, **kwargs)
+
+    def _metrics_table_row(self, timings):
+        def get_kernel_path(k):
+            return k.fn.fn.__code__.co_filename
+
+        k0 = self.kernels[0]
+        row = {
+            "size_hints": k0.size_hints,
+            "reduction_hint": k0.inductor_meta.get("reduction_hint"),
+        }
+        max_kernels = 4
+        assert len(timings) <= max_kernels
+        for i in range(max_kernels):
+            if i < len(self.kernels):
+                row[f"kernel{i}_path"] = get_kernel_path(self.kernels[i])
+                row[f"kernel{i}_latency"] = timings[i]
+            else:
+                row[f"kernel{i}_path"] = ""
+                row[f"kernel{i}_latency"] = ""
+        return row
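The row builder pads out to a fixed four-kernel schema with empty strings, so the `persistent_red_perf` table keeps a stable column set whether two, three, or four variants were benchmarked. A runnable sketch with invented paths and timings:

def metrics_table_row(paths, timings, max_kernels=4):
    # fixed schema: pad missing kernels with empty strings (example data below)
    assert len(timings) <= max_kernels
    row = {"size_hints": [4, 131072], "reduction_hint": "OUTER"}
    for i in range(max_kernels):
        row[f"kernel{i}_path"] = paths[i] if i < len(paths) else ""
        row[f"kernel{i}_latency"] = timings[i] if i < len(timings) else ""
    return row

print(metrics_table_row(["/tmp/a.py", "/tmp/b.py"], [0.12, 0.09]))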
@@ -321,6 +321,7 @@ class SIMDKernel(Kernel):
     sexpr = pexpr
     kexpr: Callable[[sympy.Expr], str]
     allow_block_ptr = False
+    kernel_name: str

     def __init__(
         self,
@@ -1366,7 +1367,7 @@ class SIMDScheduling(BaseScheduling):

         # ops.sort only works with persistent reduction, and is not bandwidth bound anyway
         # so taking the hit of non-coalesced loads is okay
-        if has_sort := schedule_contains_op(node_schedule, "sort"):
+        if schedule_contains_op(node_schedule, "sort"):
             kernel_kwargs["override_persistent_reduction"] = True

         kernel = kernel_type(
@@ -1375,35 +1376,26 @@ class SIMDScheduling(BaseScheduling):
         )
         kernel.buf_accesses = buf_accesses

-        kernel2: Optional[SIMDKernel] = None
-        if kernel.persistent_reduction and config.triton.multi_kernel and not has_sort:
-            kernel2 = self.kernel_type(
-                *kernel_args,
-                **kernel_kwargs,
-                override_persistent_reduction=False,
-            )
-            self.codegen_node_schedule_with_kernel(node_schedule, kernel2)
-            with V.set_kernel_handler(kernel2):
-                src_code2 = kernel2.codegen_kernel()
-                kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel)
-                kernel2.kernel_name = kernel_name2
-                kernel2.code_hash = code_hash(src_code2)
+        kernels = self.add_multi_kernel_choices(
+            kernel, kernel_args, kernel_kwargs, node_schedule
+        )
+        for kernel in kernels:
+            self.codegen_node_schedule_with_kernel(node_schedule, kernel)
+        MultiKernel.merge_workspaces_inplace(kernels)
+        for kernel in kernels:
+            with V.set_kernel_handler(kernel):
+                src_code = kernel.codegen_kernel()
+                kernel_name = self.define_kernel(src_code, node_schedule, kernel)
+                log.debug("Generating kernel code with kernel_name: %s", kernel_name)
+                kernel.kernel_name = kernel_name
+                kernel.code_hash = code_hash(src_code)
+        del kernel
-
-            # Keep buffers needed by the non-persistent reduction so both
-            # kernels have the same arguments
-            kernel.must_keep_buffers = set(kernel2.must_keep_buffers)
-
-        self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-
-        with V.set_kernel_handler(kernel):
-            src_code = kernel.codegen_kernel()
-
-        kernel_name = self.define_kernel(src_code, node_schedule, kernel)
-        log.debug("Generating kernel code with kernel_name: %s", kernel_name)
-        kernel.kernel_name = kernel_name
-        kernel.code_hash = code_hash(src_code)
-
-        final_kernel = MultiKernel([kernel, kernel2]) if kernel2 is not None else kernel
+
+        final_kernel: Union[SIMDKernel, MultiKernel]
+        if len(kernels) > 1:
+            final_kernel = MultiKernel(kernels)
+        else:
+            (final_kernel,) = kernels

         with V.set_kernel_handler(final_kernel):
             for node in node_schedule:
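The scheduling path is generalized from a hard-coded persistent/non-persistent pair to an arbitrary list of candidates: `add_multi_kernel_choices` proposes variants, each is codegenned, workspaces are unified, and the result is wrapped in a `MultiKernel` only when there is more than one choice. The control flow as a stripped-down runnable sketch (stub classes invented here; the real code operates on SIMDKernel and inductor's graph wrapper):

from dataclasses import dataclass
from typing import List, Union

@dataclass
class Kern:  # invented stub standing in for SIMDKernel
    name: str
    persistent: bool
    source: str = ""

def add_choices(k: Kern) -> List[Kern]:
    # stands in for add_multi_kernel_choices: propose extra variants
    return [Kern(k.name + "_looped", False), k] if k.persistent else [k]

class Multi:  # stands in for MultiKernel
    def __init__(self, kernels: List[Kern]):
        self.kernels = kernels

def schedule(k: Kern) -> Union[Kern, Multi]:
    kernels = add_choices(k)
    # (real code: codegen each variant, merge_workspaces_inplace, define_kernel)
    for kern in kernels:
        kern.source = f"def {kern.name}(): ..."
    return Multi(kernels) if len(kernels) > 1 else kernels[0]

final = schedule(Kern("triton_per_fused_sum_0", persistent=True))
print(type(final).__name__, [k.name for k in getattr(final, "kernels", [final])])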
@@ -1416,7 +1408,7 @@ class SIMDScheduling(BaseScheduling):
         if config.nan_asserts:
             final_kernel.codegen_nan_check()
         if config.warn_mix_layout:
-            final_kernel.warn_mix_layout(kernel_name)
+            final_kernel.warn_mix_layout(kernels[0].kernel_name)

         V.graph.removed_buffers |= final_kernel.removed_buffers
         V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove
@@ -1427,7 +1419,7 @@ class SIMDScheduling(BaseScheduling):
         ):
             # Not every node in the schedule will actually be live on output;
             # we can't check dead buffers.
-            live_outs = kernel.args.live_output_buffers()
+            live_outs = kernels[0].args.live_output_buffers()
             for node in node_schedule:
                 if not isinstance(node, scheduler.BaseSchedulerNode):
                     continue
@@ -1444,6 +1436,11 @@ class SIMDScheduling(BaseScheduling):

         self.scheduler.free_buffers()

+    def add_multi_kernel_choices(
+        self, kernel, kernel_args, kernel_kwargs, node_schedule
+    ) -> List[SIMDKernel]:
+        return [kernel]
+
     def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
         def current_reduction_nodes(nodes):
             return itertools.takewhile(lambda n: n is not DisableReduction, nodes)
@@ -3613,6 +3613,60 @@ class TritonScheduling(SIMDScheduling):
             store_cache()
             return ms, mod.__file__

+    def add_multi_kernel_choices(
+        self,
+        kernel: SIMDKernel,
+        kernel_args: List[Any],
+        kernel_kwargs: Dict[str, Any],
+        node_schedule: List[BaseSchedulerNode],
+    ) -> List[SIMDKernel]:
+        kernels: List[SIMDKernel] = [kernel]
+        if not config.triton.multi_kernel:
+            return kernels
+
+        optional_persistent = kernel.persistent_reduction and not kernel_kwargs.get(
+            "override_persistent_reduction"
+        )
+        optional_cooperative = kernel.cooperative_reduction and not kernel_kwargs.get(
+            "override_cooperative_reduction"
+        )
+        if optional_persistent:
+            kernels.append(
+                self.kernel_type(
+                    *kernel_args,
+                    **kernel_kwargs,
+                    override_persistent_reduction=False,
+                )
+            )
+        if optional_cooperative:
+            _, rnumel = kernel.numels
+            # for larger sizes non-cooperative gets very slow
+            if V.graph.sizevars.statically_known_leq(rnumel, 65536):
+                kernels.append(
+                    other := self.kernel_type(
+                        *kernel_args,
+                        **kernel_kwargs,
+                        override_cooperative_reduction=False,
+                    )
+                )
+                if optional_persistent and other.persistent_reduction:
+                    kernels.append(
+                        self.kernel_type(
+                            *kernel_args,
+                            **kernel_kwargs,
+                            override_cooperative_reduction=False,
+                            override_persistent_reduction=False,
+                        )
+                    )
+
+        if len(kernels) > 1:
+            for kernel2 in kernels[1:]:
+                # Keep buffers needed by the non-persistent reduction so both kernels have the same arguments
+                kernel2.must_keep_buffers = kernel.must_keep_buffers
+            # persistent kernels must be generated last so must_keep_buffers works right
+            kernels.sort(key=lambda k: k.persistent_reduction)
+        return kernels
+
     def benchmark_combo_kernel(self, node_list):
         def cache_file_path():
             assert mod.__file__ is not None
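For a kernel that is both persistent and cooperative, this yields up to four candidates: the original, a non-persistent one, a non-cooperative one, and a plain looped one. The final sort (ascending on the `persistent_reduction` bool, so False sorts first) makes persistent kernels codegen last, after `must_keep_buffers` is settled. A standalone sketch of that enumeration and ordering, with `(persistent, cooperative)` flag tuples standing in for kernel variants:

def choice_flags(persistent: bool, cooperative: bool, rnumel: int, limit: int = 65536):
    # each tuple is (persistent_reduction, cooperative_reduction) of one variant
    variants = [(persistent, cooperative)]
    if persistent:
        variants.append((False, cooperative))
    if cooperative and rnumel <= limit:  # stand-in for statically_known_leq
        variants.append((persistent, False))
        if persistent and variants[-1][0]:  # mirrors the other.persistent_reduction check
            variants.append((False, False))
    variants.sort(key=lambda v: v[0])  # persistent variants generated last
    return variants

print(choice_flags(True, True, 32768))
# [(False, True), (False, False), (True, True), (True, False)]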
@@ -212,13 +212,16 @@ MetricTable.register_table(
 MetricTable.register_table(
     "persistent_red_perf",
     [
-        "kernel1_name",
-        "kernel2_name",
+        "kernel0_path",
+        "kernel1_path",
+        "kernel2_path",
+        "kernel3_path",
+        "kernel0_latency",
         "kernel1_latency",
         "kernel2_latency",
+        "kernel3_latency",
         "size_hints",
         "reduction_hint",
-        "speedup",
     ],
 )
@@ -1889,6 +1889,19 @@ def cooperative_reduction_grid(xnumel):
     return grid_fn


+def maybe_cooperative_reduction_grid(xnumel):
+    def grid_fn(meta):
+        if "RSPLIT" in meta:
+            return coop_grid(meta)
+        return normal_grid(meta)
+
+    coop_grid = cooperative_reduction_grid(xnumel)
+    normal_grid = grid(xnumel)
+    grid_fn_str = f"maybe_cooperative_reduction_grid({xnumel})"
+    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010
+    return grid_fn
+
+
 def split_scan_grid(xnumel, rnumel):
     def grid_fn(meta):
         assert meta.get("XBLOCK", 1) == 1
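The dispatching grid closes over both grid callables and picks one per launch based on whether the autotuner's chosen config carries an RSPLIT (cooperative split) parameter, so one wrapper call site serves whichever subkernel wins the benchmark. A usage sketch, assuming a PyTorch build containing this change; the meta dicts are hand-written stand-ins for Triton launch configs:

from torch._inductor.runtime.triton_heuristics import (
    cooperative_reduction_grid,
    grid,
    maybe_cooperative_reduction_grid,
)

grid_fn = maybe_cooperative_reduction_grid(8192)
normal = {"XBLOCK": 1}              # config of a non-cooperative variant
coop = {"XBLOCK": 1, "RSPLIT": 16}  # cooperative variant splits the reduction
assert grid_fn(normal) == grid(8192)(normal)
assert grid_fn(coop) == cooperative_reduction_grid(8192)(coop)
print(grid_fn.grid_fn_str)  # maybe_cooperative_reduction_grid(8192)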