mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00

Compare commits: cslpull89 ... mlazos/ck2 (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 0fef7b0ae2 |  |
test/inductor/test_combo_kernels.py (new file, 262 lines)
@@ -0,0 +1,262 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
import torch._inductor
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||
from torch.testing._internal.triton_utils import requires_cuda
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
try:
|
||||
try:
|
||||
from .test_torchinductor import check_model, check_model_cuda
|
||||
except ImportError:
|
||||
from test_torchinductor import check_model, check_model_cuda
|
||||
except (unittest.SkipTest, ImportError) as e:
|
||||
sys.stderr.write(f"{type(e)}: {e}\n")
|
||||
if __name__ == "__main__":
|
||||
sys.exit(0)
|
||||
raise
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class ComboKernelTests(TestCase):
|
||||
check_model_cuda = check_model_cuda
|
||||
check_model_cpu = check_model
|
||||
check_kernel_count = True
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch._inductor.metrics.reset()
|
||||
torch._inductor.config.combo_kernels = True
|
||||
torch._inductor.config.benchmark_combo_kernel = False
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
torch._inductor.metrics.reset()
|
||||
|
||||
@requires_cuda
|
||||
def test_activation_functions(self):
|
||||
def test_activations(a, b, c):
|
||||
a1 = torch.nn.functional.relu(a)
|
||||
b1 = torch.nn.functional.sigmoid(b)
|
||||
c1 = torch.nn.functional.tanh(c)
|
||||
return a1, b1, c1
|
||||
|
||||
inps = [
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(20, 20, device="cuda"),
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
]
|
||||
|
||||
out_eager = test_activations(*inps)
|
||||
out_compiled = torch.compile(test_activations)(*inps)
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
|
||||
|
||||
@requires_cuda
|
||||
def test_reduce_functions(self):
|
||||
def test_reduce(a, b, c, d):
|
||||
a1 = torch.sum(a, dim=0)
|
||||
b1 = torch.max(b, dim=0)
|
||||
c1 = torch.min(c, dim=0)
|
||||
d1 = torch.nn.functional.tanh(d)
|
||||
|
||||
return a1, b1, c1, d1
|
||||
|
||||
inps = [
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(20, 20, device="cuda"),
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(30, 8, device="cuda"),
|
||||
]
|
||||
|
||||
out_eager = test_reduce(*inps)
|
||||
out_compiled = torch.compile(test_reduce)(*inps)
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2)
|
||||
|
||||
@requires_cuda
|
||||
def test_mutated_args(self):
|
||||
def test_mutated(a, b, c, d):
|
||||
a.add_(1)
|
||||
b.sigmoid_()
|
||||
c = torch.add(c, 5)
|
||||
d.tanh_()
|
||||
|
||||
return a, b, c, d
|
||||
|
||||
inps = [
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(20, 20, device="cuda"),
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(30, 8, device="cuda"),
|
||||
]
|
||||
|
||||
out_eager = test_mutated(*inps)
|
||||
out_compiled = torch.compile(test_mutated)(*inps)
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
|
||||
|
||||
@requires_cuda
|
||||
def test_reduce_split(self):
|
||||
def fn(a, b):
|
||||
a1 = torch.linalg.vector_norm(a)
|
||||
b1 = torch.sum(b, dim=0)
|
||||
return a1, b1
|
||||
|
||||
inps = [
|
||||
torch.rand(2048, 512, device="cuda"),
|
||||
torch.rand(20, 20, device="cuda"),
|
||||
]
|
||||
out_eager = fn(*inps)
|
||||
out_compiled = torch.compile(fn)(*inps)
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
|
||||
@requires_cuda
|
||||
def test_2d_blocking_partitioning(self):
|
||||
def fn(a0, a1, a2, b0, b1, b2):
|
||||
c0 = torch.add(a0, b0)
|
||||
c1 = torch.add(a1, b1)
|
||||
c2 = torch.add(a2, b2)
|
||||
return c0, c1, c2
|
||||
|
||||
self.check_model_cuda(
|
||||
fn,
|
||||
(
|
||||
torch.rand(30, 20, device="cuda"),
|
||||
torch.rand(40, 30, device="cuda"),
|
||||
torch.rand(36, 40, device="cuda"),
|
||||
torch.rand(30, 20, device="cuda"),
|
||||
torch.rand(30, 40, device="cuda").t(),
|
||||
torch.rand(40, 36, device="cuda").t(),
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class ComboKernelBenchmarkTests(TestCase):
|
||||
check_model_cuda = check_model_cuda
|
||||
check_model_cpu = check_model
|
||||
check_kernel_count = True
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch._inductor.metrics.reset()
|
||||
torch._inductor.config.combo_kernels = True
|
||||
torch._inductor.config.benchmark_combo_kernel = True
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
torch._inductor.metrics.reset()
|
||||
|
||||
@requires_cuda
|
||||
def test_activation_benchmark(self):
|
||||
def test_activations(a, b, c):
|
||||
a1 = torch.nn.functional.relu(a)
|
||||
b1 = torch.nn.functional.sigmoid(b)
|
||||
c1 = torch.nn.functional.tanh(c)
|
||||
return a1, b1, c1
|
||||
|
||||
inps = [
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(20, 20, device="cuda"),
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
]
|
||||
|
||||
out_eager = test_activations(*inps)
|
||||
out_compiled = torch.compile(test_activations)(*inps)
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)
|
||||
|
||||
@requires_cuda
|
||||
def test_reduce_benchmark(self):
|
||||
def test_reduce(a, b, c, d):
|
||||
a1 = torch.sum(a, dim=0)
|
||||
b1 = torch.max(b, dim=0)
|
||||
c1 = torch.min(c, dim=0)
|
||||
d1 = torch.nn.functional.tanh(d)
|
||||
|
||||
return a1, b1, c1, d1
|
||||
|
||||
inps = [
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(20, 20, device="cuda"),
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(30, 8, device="cuda"),
|
||||
]
|
||||
|
||||
out_eager = test_reduce(*inps)
|
||||
out_compiled = torch.compile(test_reduce)(*inps)
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)
|
||||
|
||||
@requires_cuda
|
||||
def test_mutated_benchmark(self):
|
||||
def test_mutated(a, b, c, d):
|
||||
a.add_(1)
|
||||
b.sigmoid_()
|
||||
c = torch.add(c, 5)
|
||||
d.tanh_()
|
||||
|
||||
return a, b, c, d
|
||||
|
||||
inps = [
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(20, 20, device="cuda"),
|
||||
torch.rand(10, 10, device="cuda"),
|
||||
torch.rand(30, 8, device="cuda"),
|
||||
]
|
||||
|
||||
out_eager = test_mutated(*inps)
|
||||
out_compiled = torch.compile(test_mutated)(*inps)
|
||||
|
||||
self.assertEqual(out_eager, out_compiled)
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)
|
||||
|
||||
@requires_cuda
|
||||
def test_2d_blocking_benchmark(self):
|
||||
def fn(a0, a1, a2, b0, b1, b2):
|
||||
c0 = torch.add(a0, b0)
|
||||
c1 = torch.add(a1, b1)
|
||||
c2 = torch.add(a2, b2)
|
||||
return c0, c1, c2
|
||||
|
||||
self.check_model_cuda(
|
||||
fn,
|
||||
(
|
||||
torch.rand(30, 20, device="cuda"),
|
||||
torch.rand(40, 30, device="cuda"),
|
||||
torch.rand(36, 40, device="cuda"),
|
||||
torch.rand(30, 20, device="cuda"),
|
||||
torch.rand(30, 40, device="cuda").t(),
|
||||
torch.rand(40, 36, device="cuda").t(),
|
||||
),
|
||||
)
|
||||
|
||||
self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
if HAS_CPU or HAS_CUDA:
|
||||
run_tests(needs="filelock")
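For orientation, here is a hedged sketch (not part of the diff) of how the feature exercised by these tests is switched on in user code; the flag names are the ones this patch adds to torch/_inductor/config.py further down.

```python
# Sketch only: assumes a CUDA build with this patch applied.
import torch
import torch._inductor.config as inductor_config

inductor_config.combo_kernels = True            # fuse independent kernels horizontally
inductor_config.benchmark_combo_kernel = False  # set True to keep only profitable combos

def fn(a, b, c):
    # three data-independent pointwise ops, as in test_activation_functions above
    return a.relu(), b.sigmoid(), c.tanh()

inps = [torch.rand(10, 10, device="cuda") for _ in range(3)]
out = torch.compile(fn)(*inps)
```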
|
||||
@ -622,7 +622,7 @@ class ForeachTests(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
|
||||
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
|
||||
|
||||
@requires_cuda
|
||||
@inplace_bin_ops
|
||||
|
||||
@ -132,10 +132,10 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):
|
||||
|
||||
args = (x, y)
|
||||
|
||||
events = self._test_profiling_kernel_names(fn, args, "_for_")
|
||||
events = self._test_profiling_kernel_names(fn, args, "_poi_")
|
||||
event_found = False
|
||||
for event in events:
|
||||
if event.name == "triton_for_fused_0":
|
||||
if event.name == "triton_poi_fused_0":
|
||||
event_found = True
|
||||
self.assertTrue(
|
||||
event.input_shapes
|
||||
|
||||
@ -82,6 +82,7 @@ class CppWrapperCpu(WrapperCodeGen):
|
||||
arg_types=None,
|
||||
grid_fn: str = "grid",
|
||||
triton_meta=None,
|
||||
grid_extra_kwargs="",
|
||||
):
|
||||
"""
|
||||
Generates kernel call code.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import functools
|
||||
import os
|
||||
from itertools import chain, count
|
||||
from typing import Any, List, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
|
||||
@ -159,16 +159,27 @@ class CppWrapperCuda(CppWrapperCpu):
|
||||
|
||||
return ", ".join(new_args)
|
||||
|
||||
def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
|
||||
def generate_default_grid(
|
||||
self,
|
||||
name: str,
|
||||
grid: List[Any],
|
||||
cuda: bool = True,
|
||||
grid_callable: Optional[Callable[..., Any]] = None,
|
||||
**grid_extra_kwags,
|
||||
):
|
||||
"""
|
||||
Generate grid configs for launching a CUDA kernel using the grid
|
||||
function from triton_heuristics.
|
||||
"""
|
||||
if not cuda:
|
||||
return grid
|
||||
assert isinstance(grid, list), f"expected {grid=} to be a list"
|
||||
assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
|
||||
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
|
||||
grid_fn = default_grid(*grid)
|
||||
grid_callable = grid_callable or default_grid
|
||||
if not grid_extra_kwags:
|
||||
grid_fn = grid_callable(*grid)
|
||||
else:
|
||||
grid_fn = grid_callable(*grid, **grid_extra_kwags)
|
||||
params = CudaKernelParamCache.get(name)
|
||||
assert (
|
||||
params is not None
|
||||
@ -191,6 +202,7 @@ class CppWrapperCuda(CppWrapperCpu):
|
||||
arg_types=None,
|
||||
grid_fn: str = "grid",
|
||||
triton_meta=None,
|
||||
grid_extra_kwargs="",
|
||||
):
|
||||
if not cuda:
|
||||
# Even in CppWrapperCuda, we may see cpp kernels
|
||||
|
||||
@ -71,8 +71,8 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
def flush(self):
|
||||
return self._triton_scheduling.flush()
|
||||
|
||||
def codegen_foreach(self, *args, **kwargs):
|
||||
return self._triton_scheduling.codegen_foreach(*args, **kwargs)
|
||||
def codegen_combo_kernel(self, *args, **kwargs):
|
||||
return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)
|
||||
|
||||
def benchmark_fused_nodes(self, nodes):
|
||||
return self._triton_scheduling.benchmark_fused_nodes(nodes)
|
||||
@ -81,3 +81,6 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
return self._triton_scheduling.generate_kernel_code_from_nodes(
|
||||
nodes, benchmark_kernel
|
||||
)
|
||||
|
||||
def benchmark_combo_kernel(self, node_list):
|
||||
return self._triton_scheduling.benchmark_combo_kernel(node_list)
|
||||
|
||||
@ -1296,6 +1296,7 @@ class TritonKernel(Kernel):
|
||||
reduction_hint=ReductionHint.DEFAULT,
|
||||
min_elem_per_thread=0,
|
||||
disable_persistent_reduction=False,
|
||||
optimize_mask=True,
|
||||
):
|
||||
if pid_cache is None:
|
||||
pid_cache = {}
|
||||
@ -1317,6 +1318,7 @@ class TritonKernel(Kernel):
|
||||
self.block_ptr_id = itertools.count()
|
||||
# buffer accesses in the kernel
|
||||
self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
|
||||
self.optimize_mask = optimize_mask
|
||||
|
||||
self.persistent_reduction: bool = (
|
||||
not disable_persistent_reduction
|
||||
@ -1751,9 +1753,11 @@ class TritonKernel(Kernel):
|
||||
if isinstance(index, sympy.Integer):
|
||||
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
|
||||
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
|
||||
return IndexingOptions(index_str, set(), "None", expand_str, has_rindex)
|
||||
if self.optimize_mask:
|
||||
return IndexingOptions(index_str, set(), "None", expand_str, has_rindex)
|
||||
mask_vars = dense_mask_vars
|
||||
|
||||
if need_dense and not have_dense:
|
||||
elif need_dense and not have_dense:
|
||||
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
|
||||
index_str = f"tl.broadcast_to({index_str}, {expand_str})"
|
||||
mask_vars = dense_mask_vars
|
||||
@ -1785,6 +1789,8 @@ class TritonKernel(Kernel):
|
||||
return trees
|
||||
|
||||
def filter_masks(self, mask_vars):
|
||||
if not self.optimize_mask:
|
||||
return
|
||||
for tree in self.range_trees:
|
||||
# Masks are superfluous if we only have one element
|
||||
if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type]
|
||||
@ -3759,34 +3765,53 @@ class TritonScheduling(BaseScheduling):
|
||||
def codegen_sync(self):
|
||||
V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
|
||||
|
||||
def codegen_foreach(self, foreach_node):
|
||||
from .triton_foreach import ForeachKernel
|
||||
def codegen_combo_kernel(self, combo_kernel_node):
|
||||
from .triton_combo_kernel import ComboKernel
|
||||
|
||||
for partitions_with_metadata in ForeachKernel.horizontal_partition(
|
||||
foreach_node.get_subkernel_nodes(), self
|
||||
):
|
||||
kernel = ForeachKernel()
|
||||
for nodes, tiled_groups, numel, rnumel in partitions_with_metadata:
|
||||
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
|
||||
(
|
||||
reduction_hint_val,
|
||||
mutations,
|
||||
index_dtype,
|
||||
) = self.get_kernel_args(node_schedule, numel, rnumel)
|
||||
subkernel_nodes = combo_kernel_node.get_subkernel_nodes()
|
||||
fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
|
||||
subkernel_map, node_schedule_map = {}, {}
|
||||
for pn, nodes in zip(subkernel_nodes, fused_node_lists):
|
||||
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
|
||||
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
|
||||
tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
|
||||
node_schedule_map[pn] = node_schedule, tiled_groups, numel, rnumel
|
||||
(
|
||||
reduction_hint_val,
|
||||
mutations,
|
||||
index_dtype,
|
||||
) = self.get_kernel_args(node_schedule, numel, rnumel)
|
||||
subkernel_map[pn] = ComboKernel.create_triton_kernel(
|
||||
*tiled_groups,
|
||||
reduction_hint=reduction_hint_val,
|
||||
mutations=mutations,
|
||||
index_dtype=index_dtype,
|
||||
)
|
||||
|
||||
subkernel = kernel.create_sub_kernel(
|
||||
*tiled_groups,
|
||||
reduction_hint=reduction_hint_val,
|
||||
mutations=mutations,
|
||||
index_dtype=index_dtype,
|
||||
)
|
||||
partitions = ComboKernel.horizontal_partition(
|
||||
nodes=combo_kernel_node.get_subkernel_nodes(),
|
||||
triton_scheduling=self,
|
||||
custom_algorithm=combo_kernel_node.use_custom_partition_algo,
|
||||
kernel_map=subkernel_map,
|
||||
node_info_map=node_schedule_map,
|
||||
)
|
||||
log.debug(
|
||||
"ComboKernels: %d nodes partitioned into %s groups",
|
||||
len(combo_kernel_node.get_subkernel_nodes()),
|
||||
[len(p) for p in partitions],
|
||||
)
|
||||
for node_group in partitions:
|
||||
fused_node_lists = [node.get_nodes() for node in node_group]
|
||||
kernel = ComboKernel()
|
||||
|
||||
for pn, nodes in zip(node_group, fused_node_lists):
|
||||
self.codegen_node_schedule_with_kernel(
|
||||
node_schedule,
|
||||
subkernel,
|
||||
node_schedule_map[pn][0],
|
||||
kernel.create_sub_kernel(subkernel_map[pn]),
|
||||
)
|
||||
|
||||
with V.set_kernel_handler(subkernel):
|
||||
subkernel = subkernel_map[pn]
|
||||
node_schedule = node_schedule_map[pn][0]
|
||||
with V.set_kernel_handler(subkernel): # type: ignore[call-arg]
|
||||
for node in node_schedule:
|
||||
if node not in (EnableReduction, DisableReduction):
|
||||
node.mark_run()
|
||||
@ -3794,8 +3819,9 @@ class TritonScheduling(BaseScheduling):
|
||||
V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
|
||||
|
||||
src_code = kernel.codegen_kernel()
|
||||
kernel_name = self.define_kernel(src_code, [foreach_node])
|
||||
self.codegen_comment([foreach_node])
|
||||
kernel_name = self.define_kernel(src_code, [combo_kernel_node])
|
||||
self.codegen_comment([combo_kernel_node])
|
||||
log.debug("ComboKernels: generated kernel %s.", kernel_name)
|
||||
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
|
||||
|
||||
self.scheduler.free_buffers()
|
||||
@ -4062,6 +4088,134 @@ class TritonScheduling(BaseScheduling):
|
||||
store_cache()
|
||||
return ms, mod.__file__
|
||||
|
||||
def benchmark_combo_kernel(self, node_list):
|
||||
from .triton_combo_kernel import ComboKernel
|
||||
|
||||
def cache_file_path():
|
||||
assert mod.__file__ is not None
|
||||
return os.path.splitext(mod.__file__)[0] + ".kernel_perf"
|
||||
|
||||
def load_cache():
|
||||
path = cache_file_path()
|
||||
if os.path.exists(path):
|
||||
with open(path) as fd:
|
||||
return tuple(float(e) for e in fd.read().split())
|
||||
return (None, None)
|
||||
|
||||
def store_cache():
|
||||
path = cache_file_path()
|
||||
with open(path, "w") as fd:
|
||||
fd.write(str(ms) + " " + str(ms_clone))
|
||||
|
||||
subkernel_nodes = node_list
|
||||
fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
|
||||
subkernel_map, node_schedule_map = {}, {}
|
||||
for pn, nodes in zip(subkernel_nodes, fused_node_lists):
|
||||
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
|
||||
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
|
||||
tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
|
||||
(
|
||||
reduction_hint_val,
|
||||
mutations,
|
||||
index_dtype,
|
||||
) = self.get_kernel_args(node_schedule, numel, rnumel)
|
||||
node_schedule_map[pn] = (node_schedule, tiled_groups, numel, rnumel)
|
||||
subkernel_map[pn] = ComboKernel.create_triton_kernel(
|
||||
*tiled_groups,
|
||||
reduction_hint=reduction_hint_val,
|
||||
mutations=mutations,
|
||||
index_dtype=index_dtype,
|
||||
)
|
||||
|
||||
partitions = ComboKernel.horizontal_partition(
|
||||
nodes=subkernel_nodes,
|
||||
triton_scheduling=self,
|
||||
kernel_map=subkernel_map,
|
||||
node_info_map=node_schedule_map,
|
||||
custom_algorithm=True,
|
||||
)
|
||||
log.debug(
|
||||
"ComboKernels: %d nodes partitioned into %s groups",
|
||||
len(subkernel_nodes),
|
||||
[len(p) for p in partitions],
|
||||
)
|
||||
total_ms, file_list = 0, []
|
||||
total_clone_ms = 0
|
||||
removed_buffers_orig = V.graph.removed_buffers
|
||||
V.graph.removed_buffers = set(removed_buffers_orig)
|
||||
inplaced_to_remove_orig = V.graph.inplaced_to_remove
|
||||
V.graph.inplaced_to_remove = set(inplaced_to_remove_orig)
|
||||
for node_group in partitions:
|
||||
fused_node_lists = [node.get_nodes() for node in node_group]
|
||||
kernel = ComboKernel()
|
||||
names = [n.get_names() for nodes in fused_node_lists for n in nodes]
|
||||
|
||||
for pn, nodes in zip(node_group, fused_node_lists):
|
||||
# empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
|
||||
for n in nodes:
|
||||
n.last_usage = set()
|
||||
|
||||
self.codegen_node_schedule_with_kernel(
|
||||
node_schedule_map[pn][0],
|
||||
kernel.create_sub_kernel(subkernel_map[pn]),
|
||||
)
|
||||
subkernel = subkernel_map[pn]
|
||||
V.graph.removed_buffers |= subkernel.removed_buffers
|
||||
V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
|
||||
with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):
|
||||
src_code = kernel.codegen_kernel()
|
||||
|
||||
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
|
||||
mod = PyCodeCache.load(src_code)
|
||||
|
||||
log.debug(
|
||||
"kernel src code for %s written to: %s",
|
||||
names,
|
||||
mod.__file__,
|
||||
)
|
||||
ms, ms_clone = load_cache()
|
||||
if ms is not None:
|
||||
total_ms += ms
|
||||
total_clone_ms += ms_clone
|
||||
file_list.append(mod.__file__)
|
||||
continue
|
||||
|
||||
args = mod.get_args()
|
||||
call = mod.call
|
||||
wrapped_jit_function = mod.triton_
|
||||
|
||||
# call once to trigger the compilation
|
||||
call(wrapped_jit_function.clone_args(*args)[0])
|
||||
|
||||
launchers = wrapped_jit_function.launchers
|
||||
assert len(launchers) == 1
|
||||
if launchers[0].n_spills > 0:
|
||||
# skip benchmarking the kernel if there are register spills
|
||||
ms = ms_clone = float("inf")
|
||||
else:
|
||||
# We have to clone the inplace updated arguments to avoid earlier calls
|
||||
# generating out of range indices for later calls.
|
||||
ms = do_bench_gpu(
|
||||
lambda: call(wrapped_jit_function.clone_args(*args)[0])
|
||||
)
|
||||
ms_clone = do_bench_gpu(
|
||||
lambda: wrapped_jit_function.clone_args(*args)[0]
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs",
|
||||
{n.get_name() for n in nodes},
|
||||
ms,
|
||||
ms_clone,
|
||||
)
|
||||
store_cache()
|
||||
total_ms += ms
|
||||
total_clone_ms += ms_clone
|
||||
file_list.append(mod.__file__)
|
||||
V.graph.removed_buffers = removed_buffers_orig
|
||||
V.graph.inplaced_to_remove = inplaced_to_remove_orig
|
||||
return total_ms, total_clone_ms, file_list
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CandidateTiling:
|
||||
|
||||
torch/_inductor/codegen/triton_combo_kernel.py (new file, 675 lines)
@@ -0,0 +1,675 @@
|
||||
import itertools
|
||||
import logging
|
||||
import textwrap
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from sympy import Integer
|
||||
|
||||
from .. import config, metrics
|
||||
from ..runtime.hints import DeviceProperties
|
||||
from ..runtime.runtime_utils import next_power_of_2
|
||||
from ..runtime.triton_heuristics import grid_combo_kernels
|
||||
from ..scheduler import FusedSchedulerNode
|
||||
from ..utils import Placeholder
|
||||
from ..virtualized import V
|
||||
from .common import DeferredLine, IndentedBuffer, Kernel, PythonPrinter, SizeArg
|
||||
from .triton import gen_common_triton_imports, TritonKernel
|
||||
from .triton_utils import config_of, signature_to_meta
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
pexpr = PythonPrinter().doprint
|
||||
LARGE_NUMELS = 512e5
|
||||
|
||||
|
||||
def _default_custom_combo_kernel_horizontal_partition(
|
||||
nodes, triton_scheduling, kernel_map, node_info_map
|
||||
):
|
||||
"""Horizontally partition the given list of nodes into a list of list of nodes where each sublist
|
||||
represents a partion. Nodes in different partitions are implemented in different combo kernels.
|
||||
Nodes in the same partition are likely to be implemented
|
||||
in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args.
|
||||
|
||||
Input arguments:
|
||||
nodes: a list of fused scheduler nodes to partition.
|
||||
triton_scheduling: TritonScheduling instance.
|
||||
kernel_map: a map from node to its kernel.
|
||||
node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel).
|
||||
Oputput:
|
||||
a list of list of nodes with each sublist representing a partition.
|
||||
|
||||
The default algorithm is to partition nodes based on the following rules:
|
||||
1) nodes with the same number of block dimensions are grouped togather.
|
||||
2) large pointwise nodes are separated from other nodes.
|
||||
3) large reduce nodes are separated from other nodes.
|
||||
"""
|
||||
|
||||
assert len(nodes) >= 1
|
||||
|
||||
# first partition nodes based on number of block dimensions
|
||||
tilings = [node_info_map[n][1] for n in nodes]
|
||||
|
||||
max_dims = max(len(t) for t in tilings)
|
||||
nodes_per_ndim = []
|
||||
for i in range(2, max_dims + 1):
|
||||
group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
|
||||
reduction = [
|
||||
n
|
||||
for n in group_per_dim
|
||||
if kernel_map[n].inside_reduction
|
||||
and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim)
|
||||
]
|
||||
not_reduction = [n for n in group_per_dim if n not in reduction]
|
||||
# rnumel > 2048 usually has long execution time
|
||||
long_reduction = [n for n in reduction if n.group[-1][-1] > 2048]
|
||||
short_reduction = [n for n in reduction if n not in long_reduction]
|
||||
if long_reduction:
|
||||
log.warning(
|
||||
"ComboKernels: %d long reduction nodes are separated",
|
||||
len(long_reduction),
|
||||
)
|
||||
large_pointwise = [
|
||||
n
|
||||
for n in not_reduction
|
||||
if not kernel_map[n].inside_reduction
|
||||
and len(kernel_map[n].numels) == 2
|
||||
and V.graph.sizevars.size_hint(kernel_map[n].numels[0]) > LARGE_NUMELS
|
||||
]
|
||||
if large_pointwise:
|
||||
# TODO benchmark the performance when large pointwise nodes combining with others
|
||||
log.warning(
|
||||
"ComboKernels: %d large pointwise nodes are separated",
|
||||
len(large_pointwise),
|
||||
)
|
||||
not_reduction = [n for n in not_reduction if n not in large_pointwise]
|
||||
for node in large_pointwise:
|
||||
nodes_per_ndim.append([node])
|
||||
|
||||
for g in (not_reduction, short_reduction, long_reduction):
|
||||
if g:
|
||||
nodes_per_ndim.append(g)
|
||||
|
||||
assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
|
||||
return nodes_per_ndim
|
||||
|
||||
|
||||
_custom_combo_kernel_horizontal_partition_algorithm = (
|
||||
_default_custom_combo_kernel_horizontal_partition
|
||||
)
|
||||
|
||||
|
||||
def set_custom_combo_kernel_horizontal_partition(algorithm):
|
||||
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
|
||||
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
|
||||
in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args.
|
||||
|
||||
The algorithm should take a list of nodes and return a list of list of nodes.
|
||||
|
||||
The default algorithm is to partition nodes based on number of block dimensions.
|
||||
"""
|
||||
global _custom_combo_kernel_horizontal_partition_algorithm
|
||||
_custom_combo_kernel_horizontal_partition_algorithm = algorithm
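The hook above only stores a callable; a minimal sketch of its contract, assuming the four-argument signature documented for the default algorithm (the partitioner name below is hypothetical):

```python
# Hypothetical custom partitioner: one combo kernel per fused node.
# It must accept (nodes, triton_scheduling, kernel_map, node_info_map)
# and return a list of lists of nodes.
def one_node_per_partition(nodes, triton_scheduling, kernel_map, node_info_map):
    return [[n] for n in nodes]

# set_custom_combo_kernel_horizontal_partition(one_node_per_partition)
```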
|
||||
|
||||
|
||||
@dataclass
|
||||
class PartitionState:
|
||||
partitions: List[List[FusedSchedulerNode]]
|
||||
cur_partition: List[FusedSchedulerNode]
|
||||
cur_count: int
|
||||
|
||||
def finalize(self):
|
||||
if self.cur_partition:
|
||||
self.partitions.append(self.cur_partition)
|
||||
|
||||
|
||||
class ComboKernel(Kernel):
|
||||
MAX_NUM_ARGS = 250 # number where I would no longer get triton errors
|
||||
|
||||
@staticmethod
|
||||
def _update_partition(partition_state, node_rw_count, node_info):
|
||||
if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS:
|
||||
partition_state.partitions.append(partition_state.cur_partition)
|
||||
partition_state.cur_partition = [node_info]
|
||||
partition_state.cur_count = node_rw_count
|
||||
else:
|
||||
partition_state.cur_count += node_rw_count
|
||||
partition_state.cur_partition.append(node_info)
|
||||
|
||||
@staticmethod
|
||||
def _base_horizontal_partition(subkernel_nodes, triton_scheduling, node_info_map):
|
||||
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
|
||||
for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
|
||||
(read/writes) and to have the same 2D or 1D blocking strategy."""
|
||||
# TODO support combination of kernels with different block dimensions
|
||||
assert len(subkernel_nodes) >= 1
|
||||
|
||||
ndim_to_partition_state: Dict[int, PartitionState] = defaultdict(
|
||||
lambda: PartitionState([], [], 0)
|
||||
)
|
||||
|
||||
for node in subkernel_nodes:
|
||||
node_schedule, tiled_groups, numel, rnumel = node_info_map[node]
|
||||
node_info = node
|
||||
|
||||
read_writes = node.read_writes
|
||||
read_write_count = len(read_writes.reads) + len(read_writes.writes)
|
||||
|
||||
ndim = len(tiled_groups)
|
||||
partition_state = ndim_to_partition_state[ndim]
|
||||
ComboKernel._update_partition(partition_state, read_write_count, node_info)
|
||||
|
||||
all_partitions = []
|
||||
for partition_state in ndim_to_partition_state.values():
|
||||
partition_state.finalize()
|
||||
all_partitions.extend(partition_state.partitions)
|
||||
|
||||
return all_partitions
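A self-contained toy of the greedy packing performed by _update_partition and finalized here, assuming only the read/write counts matter (the names below are illustrative, not from the diff):

```python
def pack_by_arg_count(rw_counts, max_args=250):
    # Append nodes to the current partition until adding the next one would
    # push the combined read/write count past max_args, then start a new one.
    partitions, current, count = [], [], 0
    for idx, rw in enumerate(rw_counts):
        if current and count + rw > max_args:
            partitions.append(current)
            current, count = [], 0
        current.append(idx)
        count += rw
    if current:
        partitions.append(current)
    return partitions

print(pack_by_arg_count([100, 100, 100, 40, 40]))  # [[0, 1], [2, 3, 4]]
```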
|
||||
|
||||
@staticmethod
|
||||
def horizontal_partition(
|
||||
nodes, triton_scheduling, kernel_map, node_info_map, custom_algorithm=False
|
||||
):
|
||||
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum)
|
||||
for each subkernel node where each sublist forms a ComboKernel. It horizontally partitions nodes into
|
||||
sublists in the following way:
|
||||
1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True
|
||||
2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is
|
||||
guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same
|
||||
2D or 1D blocking strategy.
|
||||
"""
|
||||
if custom_algorithm:
|
||||
raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm(
|
||||
nodes, triton_scheduling, kernel_map, node_info_map
|
||||
)
|
||||
else:
|
||||
raw_partitions = [nodes]
|
||||
|
||||
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
|
||||
for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
|
||||
(read/writes) and to have the same 2D or 1D blocking strategy."""
|
||||
all_partitions = []
|
||||
for raw_partition in raw_partitions:
|
||||
all_partitions.extend(
|
||||
ComboKernel._base_horizontal_partition(
|
||||
raw_partition, triton_scheduling, node_info_map
|
||||
)
|
||||
)
|
||||
return all_partitions
|
||||
|
||||
def __init__(self, use_custom_partition_algo=False):
|
||||
super().__init__()
|
||||
self.sub_kernels = []
|
||||
self.iter_vars_count = itertools.count()
|
||||
self.block_count = 0
|
||||
self.grids = []
|
||||
self.min_x_blocks_list = []
|
||||
self.use_custom_partition_algo = use_custom_partition_algo
|
||||
|
||||
def codegen_pid_range(self, code, num):
|
||||
num_kernels = len(self.sub_kernels)
|
||||
if self.block_count == 0:
|
||||
cond = "if"
|
||||
else:
|
||||
cond = "elif"
|
||||
code.splice(f"{cond} pid % {num_kernels} == {num}:")
|
||||
with code.indent():
|
||||
code.splice(f"pid_offset = pid // {num_kernels}")
|
||||
self.block_count += 1
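Illustration of the dispatch shape this method emits; the body below is a hand-written approximation for two sub-kernels, not actual generated output:

```python
# pid = tl.program_id(0)
# if pid % 2 == 0:
#     pid_offset = pid // 2
#     ...  # body of sub-kernel 0 (its tl.program_id(0) is remapped to pid_offset)
# elif pid % 2 == 1:
#     pid_offset = pid // 2
#     ...  # body of sub-kernel 1
# else:
#     pass
#
# Programs are interleaved round-robin across sub-kernels, which is why the
# combo grid's x dimension is scaled by the number of sub-kernels.
```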
|
||||
|
||||
def create_sub_kernel(self, triton_kernel):
|
||||
sub_kernel = triton_kernel
|
||||
metrics.generated_kernel_count -= 1
|
||||
sub_kernel.args = self.args
|
||||
sub_kernel.iter_vars_count = self.iter_vars_count
|
||||
sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
|
||||
self.sub_kernels.append(sub_kernel)
|
||||
return sub_kernel
|
||||
|
||||
@staticmethod
|
||||
def create_triton_kernel(*groups, index_dtype, mutations, reduction_hint):
|
||||
return TritonKernel(
|
||||
*groups,
|
||||
index_dtype=index_dtype,
|
||||
mutations=mutations,
|
||||
pid_cache={"tl.program_id(0)": "pid_offset"},
|
||||
reduction_hint=reduction_hint,
|
||||
optimize_mask=False,
|
||||
)
|
||||
|
||||
def codegen_static_numels_sub_kernel(self, code, sub_kernel, num):
|
||||
"""
|
||||
We get a small speedup from hard coding numels if they are static.
|
||||
|
||||
This code stomps on the passed-in values by writing an constant to the top of the kernel.
|
||||
|
||||
In a kernel like:
|
||||
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
|
||||
|
||||
We would add
|
||||
xnumel = 4096
|
||||
rnumel = 768
|
||||
|
||||
After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes
|
||||
a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
|
||||
knows that its a static numel, as that you just plop a constant into the kernel.
|
||||
"""
|
||||
grid = []
|
||||
uniquify_block_sizes = []
|
||||
for tree in sub_kernel.range_trees:
|
||||
simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
|
||||
code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
|
||||
|
||||
if tree.prefix != "r":
|
||||
grid.append(int(simplified_tree_numel))
|
||||
|
||||
if tree.prefix == "r" and sub_kernel.persistent_reduction:
|
||||
if isinstance(simplified_tree_numel, (Integer, int)):
|
||||
val = int(simplified_tree_numel)
|
||||
else:
|
||||
continue
|
||||
val = next_power_of_2(val)
|
||||
code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}")
|
||||
uniquify_block_sizes.append("RBLOCK")
|
||||
|
||||
if tree.prefix == "x" and sub_kernel.no_x_dim:
|
||||
code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
|
||||
uniquify_block_sizes.append("XBLOCK")
|
||||
self.grids.append(grid)
|
||||
return uniquify_block_sizes
|
||||
|
||||
def min_x_blocks_sub_kernel(self, sub_kernel, num):
|
||||
"""
|
||||
We get a small speedup from hard coding numels if they are static.
|
||||
|
||||
This code stomps on the passed-in values by writing an constant to the top of the kernel.
|
||||
|
||||
In a kernel like:
|
||||
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
|
||||
|
||||
We would add
|
||||
xnumel = 4096
|
||||
rnumel = 768
|
||||
|
||||
After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes
|
||||
a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
|
||||
knows that its a static numel, as that you just plop a constant into the kernel.
|
||||
"""
|
||||
min_x_blocks = 0
|
||||
for tree in sub_kernel.range_trees:
|
||||
simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
|
||||
if tree.prefix == "x" and sub_kernel.no_x_dim:
|
||||
min_x_blocks = int(simplified_tree_numel)
|
||||
self.min_x_blocks_list.append(min_x_blocks)
|
||||
|
||||
def select_heuristics(self, sub_kernel) -> Tuple[str, List[int]]:
|
||||
size_hints = [
|
||||
next_power_of_2(V.graph.sizevars.size_hint(numel))
|
||||
for numel in sub_kernel.numels
|
||||
]
|
||||
if sub_kernel.persistent_reduction:
|
||||
assert sub_kernel.inside_reduction
|
||||
heuristics = "persistent_reduction"
|
||||
elif sub_kernel.inside_reduction:
|
||||
heuristics = "reduction"
|
||||
else:
|
||||
size_hints.pop()
|
||||
heuristics = "pointwise"
|
||||
return heuristics, size_hints
|
||||
|
||||
def select_combo_heuristics(self, heuristics_list, size_hints_list):
|
||||
if "reduction" in heuristics_list:
|
||||
i, _ = max(
|
||||
enumerate(size_hints_list),
|
||||
key=lambda x: x[1][0] if heuristics_list[x[0]] == "reduction" else 0,
|
||||
)
|
||||
return heuristics_list[i], size_hints_list[i], self.sub_kernels[i]
|
||||
elif "pointwise" in heuristics_list:
|
||||
i, _ = max(
|
||||
enumerate(size_hints_list),
|
||||
key=lambda x: x[1][0] if heuristics_list[x[0]] == "pointwise" else 0,
|
||||
)
|
||||
# modify size_hint to avoid oom check fail (may be a false alarm)
|
||||
num_pointwise = len([e for e in heuristics_list if e == "pointwise"])
|
||||
num_reduction = len([e for e in heuristics_list if e == "reduction"])
|
||||
num_persistent_reduction = len(
|
||||
[e for e in heuristics_list if e == "persistent_reduction"]
|
||||
)
|
||||
assert (
|
||||
num_reduction == 0
|
||||
), "combining pointwise and reduction are not supported yet."
|
||||
heuristics = (
|
||||
"pointwise_with_reduction"
|
||||
if num_persistent_reduction > 0
|
||||
else "pointwise"
|
||||
)
|
||||
if len(heuristics_list) - num_pointwise >= 4:
|
||||
size_hints = size_hints_list[i]
|
||||
size_hints[0] = min(128, size_hints[0])
|
||||
return heuristics, size_hints_list[i], self.sub_kernels[i]
|
||||
else:
|
||||
return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]
|
||||
|
||||
def get_mutated_args_sub_kernels(self) -> List[str]:
|
||||
mutated_args = set()
|
||||
for sub_kernel in self.sub_kernels:
|
||||
for mutation in sub_kernel.mutations:
|
||||
if mutation in sub_kernel.args.input_buffers:
|
||||
mutated_args.add(sub_kernel.args.input_buffers[mutation])
|
||||
if (
|
||||
mutation in sub_kernel.args.inplace_buffers
|
||||
and mutation not in V.graph.removed_buffers
|
||||
and mutation not in sub_kernel.removed_buffers
|
||||
):
|
||||
mutated_args.add(
|
||||
sub_kernel.args.inplace_buffers[mutation].inner_name
|
||||
)
|
||||
if mutation in sub_kernel.args.output_buffers:
|
||||
mutated_args.add(sub_kernel.args.output_buffers[mutation])
|
||||
return sorted(mutated_args)
|
||||
|
||||
def jit_line(
|
||||
self, heuristics, size_hints, selected_kernel, pointwise_with_reduce=False
|
||||
):
|
||||
# TODO: is it correct to use the first sub kernel's heuristics?
|
||||
_, _, signature = self.args.python_argdefs()
|
||||
# TODO Is it ok to just use sub_kernel[0].index_dtype?
|
||||
index_dtype = self.sub_kernels[0].index_dtype
|
||||
for i, sub in enumerate(self.sub_kernels):
|
||||
self.min_x_blocks_sub_kernel(sub, i)
|
||||
triton_meta = {
|
||||
"signature": signature_to_meta(signature, size_dtype=index_dtype),
|
||||
"device": DeviceProperties.create(V.graph.scheduler.current_device),
|
||||
"constants": {},
|
||||
}
|
||||
triton_meta["configs"] = [config_of(signature)]
|
||||
mutated_args = self.get_mutated_args_sub_kernels()
|
||||
inductor_meta = {
|
||||
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
|
||||
"mutated_arg_names": mutated_args,
|
||||
**TritonKernel.inductor_meta_common(),
|
||||
}
|
||||
|
||||
sub_kernel = selected_kernel
|
||||
if sub_kernel.inside_reduction:
|
||||
reduction_hint = sub_kernel.reduction_hint
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.{heuristics}(
|
||||
size_hints={size_hints!r},
|
||||
reduction_hint={reduction_hint},
|
||||
filename=__file__,
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta}
|
||||
)
|
||||
@triton.jit
|
||||
"""
|
||||
else:
|
||||
tile_hint = ""
|
||||
if len(size_hints) == 2:
|
||||
# TODO only input, output and 2 args
|
||||
tile_hint = "tile_hint=TileHint.SQUARE,"
|
||||
else:
|
||||
tile_hint = "tile_hint=TileHint.DEFAULT,"
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.{heuristics}(
|
||||
size_hints={size_hints!r}, {tile_hint}
|
||||
filename=__file__,
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r}
|
||||
)
|
||||
@triton.jit
|
||||
"""
|
||||
|
||||
return heuristics_line
|
||||
|
||||
def add_blockd_to_args(self, argdefs):
|
||||
block_args = {}
|
||||
for num, sub_kernel in enumerate(self.sub_kernels):
|
||||
# TODO: we assume all sub_kernels have the same block size
|
||||
for tree in sub_kernel.range_trees:
|
||||
if tree.prefix == "r" and (
|
||||
not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
|
||||
):
|
||||
continue
|
||||
if tree.prefix == "x" and sub_kernel.no_x_dim:
|
||||
continue
|
||||
# argdefs.append(f"{tree.prefix.upper()}BLOCK_{num} : tl.constexpr")
|
||||
block_args[f"{tree.prefix.upper()}BLOCK : tl.constexpr"] = tree.prefix
|
||||
for arg in block_args:
|
||||
argdefs.append(arg)
|
||||
return argdefs
|
||||
|
||||
def codegen_kernel(self, name=None):
|
||||
# TODO: is it correct to use the first sub kernel's heuristics?
|
||||
heuristics_list, size_hints_list = [], []
|
||||
for subkernel in self.sub_kernels:
|
||||
h, s = self.select_heuristics(subkernel)
|
||||
heuristics_list.append(h)
|
||||
size_hints_list.append(s)
|
||||
heuristics, size_hints, selected_kernel = self.select_combo_heuristics(
|
||||
heuristics_list, size_hints_list
|
||||
)
|
||||
pointwise_with_reduction, heuristics = (
|
||||
(True, "pointwise")
|
||||
if heuristics == "pointwise_with_reduction"
|
||||
else (False, heuristics)
|
||||
)
|
||||
code = IndentedBuffer()
|
||||
|
||||
code.splice(gen_common_triton_imports())
|
||||
if config.benchmark_combo_kernel:
|
||||
code.splice(self.imports_for_benchmark_kernel())
|
||||
|
||||
argdefs, _, _ = self.args.python_argdefs()
|
||||
argdefs = self.add_blockd_to_args(argdefs)
|
||||
code.splice(
|
||||
self.jit_line(
|
||||
heuristics,
|
||||
size_hints,
|
||||
selected_kernel,
|
||||
pointwise_with_reduce=pointwise_with_reduction,
|
||||
)
|
||||
)
|
||||
code.writeline(
|
||||
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
|
||||
)
|
||||
|
||||
with code.indent():
|
||||
code.splice("pid = tl.program_id(0)")
|
||||
|
||||
for num, sub_kernel in enumerate(self.sub_kernels):
|
||||
self.codegen_pid_range(code, num)
|
||||
with code.indent():
|
||||
uniquify = self.codegen_static_numels_sub_kernel(
|
||||
code, sub_kernel, num
|
||||
)
|
||||
sub_kernel.codegen_body()
|
||||
uniquified_body = self.uniquify_block_sizes(
|
||||
sub_kernel.body, num, uniquify
|
||||
)
|
||||
code.splice(uniquified_body)
|
||||
|
||||
code.splice("else:")
|
||||
with code.indent():
|
||||
code.splice("pass")
|
||||
|
||||
if config.benchmark_combo_kernel:
|
||||
code.splice(self.codegen_kernel_benchmark(num_gb=0))
|
||||
|
||||
return code.getvalue()
|
||||
|
||||
def codegen_kernel_benchmark(self, num_gb, grid=None):
|
||||
result = IndentedBuffer()
|
||||
argdefs, call_args, signature = self.args.python_argdefs()
|
||||
|
||||
result.writelines(["", "", "def get_args():"])
|
||||
with result.indent():
|
||||
name_cnt = itertools.count()
|
||||
var_names = []
|
||||
for arg_name, arg_sig in zip(call_args, signature):
|
||||
var_name = f"arg_{next(name_cnt)}"
|
||||
buf = V.graph.get_buffer(arg_name)
|
||||
if buf:
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
|
||||
)
|
||||
elif arg_name in V.graph.constants:
|
||||
# note that random seed is put in V.graph.constants
|
||||
const_tensor = V.graph.constants[arg_name]
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
)
|
||||
elif isinstance(arg_sig, SizeArg):
|
||||
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
|
||||
|
||||
# Force the seed_offset to be 0 so calls to the same kernel
|
||||
# using different seed offset will have the same benchmark harness.
|
||||
# We can dedup kernel definitions in this case.
|
||||
if "seed_offset" in arg_sig.name:
|
||||
symval_hint = 0
|
||||
result.writeline(f"{var_name} = {symval_hint}")
|
||||
else:
|
||||
raise KeyError(
|
||||
f"Don't find the buffer or const tensor for {arg_name}"
|
||||
)
|
||||
var_names.append(var_name)
|
||||
result.writeline(f"return {', '.join(var_names)},")
|
||||
|
||||
result.writelines(["\n", "\n", "def call(args):"])
|
||||
if grid is None:
|
||||
grid = self.grid(self.grids)
|
||||
grid_str = ", ".join(pexpr(item) for item in grid)
|
||||
grid_extra_kwargs = f"num_kernels={len(self.sub_kernels)}, min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}"
|
||||
grid_str = f"{grid_str}, {grid_extra_kwargs}"
|
||||
grid_arg = f"grid=grid_combo_kernels({grid_str})"
|
||||
else:
|
||||
grid_arg = f"grid={grid}"
|
||||
index = V.graph.scheduler.current_device.index
|
||||
with result.indent():
|
||||
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
|
||||
with result.indent():
|
||||
result.writeline(
|
||||
V.graph.device_ops.set_device(index)
|
||||
) # no-op to ensure context
|
||||
stream_name = f"stream{index}"
|
||||
result.writeline(f"{stream_name} = get_raw_stream({index})")
|
||||
result.writeline(
|
||||
f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
|
||||
)
|
||||
|
||||
# benchmark all configs
|
||||
result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
|
||||
with result.indent():
|
||||
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
|
||||
with result.indent():
|
||||
result.writeline(
|
||||
V.graph.device_ops.set_device(index)
|
||||
) # no-op to ensure context
|
||||
result.writeline(
|
||||
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
|
||||
)
|
||||
|
||||
result.writelines(["\n", "\n", "if __name__ == '__main__':"])
|
||||
with result.indent():
|
||||
result.writeline("from triton.testing import do_bench")
|
||||
result.writeline("")
|
||||
|
||||
result.writeline("args = get_args()")
|
||||
result.writeline(
|
||||
"ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
|
||||
)
|
||||
result.writeline(f"num_gb = {num_gb}")
|
||||
result.writeline("gb_per_s = num_gb / (ms / 1e3)")
|
||||
result.writeline(
|
||||
'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def imports_for_benchmark_kernel(self):
|
||||
return textwrap.dedent(
|
||||
"""
|
||||
from torch._dynamo.testing import rand_strided
|
||||
{}
|
||||
import torch
|
||||
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
|
||||
""".format(
|
||||
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
|
||||
)
|
||||
)
|
||||
|
||||
def uniquify_block_sizes(
|
||||
self, code: IndentedBuffer, num_kernel, uniquify: List[str]
|
||||
) -> IndentedBuffer:
|
||||
if not uniquify:
|
||||
return code
|
||||
modified = IndentedBuffer(initial_indent=code._indent)
|
||||
for line in code._lines:
|
||||
if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]):
|
||||
modified_line = line
|
||||
for block in blocks:
|
||||
modified_line = modified_line.replace(
|
||||
block, f"{block}_{num_kernel}"
|
||||
)
|
||||
modified.writeline(modified_line)
|
||||
elif isinstance(line, DeferredLine) and (
|
||||
blocks := [e for e in uniquify if e in line.line]
|
||||
):
|
||||
modified_line = line.line
|
||||
for block in blocks:
|
||||
modified_line = modified_line.replace(
|
||||
block, f"{block}_{num_kernel}"
|
||||
)
|
||||
new_line = DeferredLine(line.name, modified_line)
|
||||
modified.writeline(new_line)
|
||||
else:
|
||||
modified.writeline(line)
|
||||
return modified
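A hypothetical before/after pair showing what the renaming does for num_kernel=1 and uniquify=["RBLOCK"] (the line itself is made up for illustration):

```python
# before: "rindex = roffset + tl.arange(0, RBLOCK)"
# after:  "rindex = roffset + tl.arange(0, RBLOCK_1)"
#
# Each sub-kernel keeps its own constexpr block sizes once its body is
# spliced into the combined kernel, matching the RBLOCK_{num}/XBLOCK_{num}
# constants written by codegen_static_numels_sub_kernel.
```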
|
||||
|
||||
def call_kernel(self, code, name: str):
|
||||
_, call_args, _ = self.args.python_argdefs()
|
||||
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
||||
for i in range(len(call_args)):
|
||||
if V.graph.is_unspec_arg(call_args[i]):
|
||||
call_args[i] = call_args[i] + ".item()"
|
||||
|
||||
wrapper = V.graph.wrapper_code
|
||||
grid = self.grid(self.grids)
|
||||
grid = wrapper.generate_default_grid(
|
||||
name,
|
||||
grid,
|
||||
grid_callable=grid_combo_kernels,
|
||||
num_kernels=len(self.sub_kernels),
|
||||
min_blocks=max(self.min_x_blocks_list) * len(self.sub_kernels),
|
||||
)
|
||||
wrapper.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
grid,
|
||||
V.graph.scheduler.current_device.index,
|
||||
cuda=True,
|
||||
triton=True,
|
||||
grid_fn="grid_combo_kernels",
|
||||
grid_extra_kwargs=f"num_kernels={len(self.sub_kernels)}, min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}",
|
||||
)
|
||||
|
||||
def grid(self, sub_kernel_numels):
|
||||
xnumel = [e[-1] if len(e) > 0 else None for e in sub_kernel_numels]
|
||||
ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
|
||||
znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]
|
||||
|
||||
# TODO: improve 1d/2d mixed cases
|
||||
xnumel = None if any(e is None for e in xnumel) else max(xnumel)
|
||||
ynumel = None if any(e is None for e in ynumel) else max(ynumel)
|
||||
znumel = None if any(e is None for e in znumel) else max(znumel)
|
||||
|
||||
numels = (
|
||||
(xnumel,)
|
||||
if not ynumel
|
||||
else (ynumel, xnumel)
|
||||
if not znumel
|
||||
else (znumel, ynumel, xnumel)
|
||||
)
|
||||
return numels
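A worked example of this grid computation, with made-up numels:

```python
# Three 1D sub-kernels whose recorded grids are [100], [400] and [250]:
#   xnumel = max(100, 400, 250) = 400, ynumel = znumel = None  ->  returns (400,)
# Mixing 1D and 2D sub-kernels leaves ynumel as None for the 1D entries,
# so the result collapses to 1D - the "improve 1d/2d mixed cases" TODO above.
```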
|
||||
@ -1,248 +0,0 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from sympy import Integer
|
||||
|
||||
from .. import metrics
|
||||
from ..runtime.hints import DeviceProperties
|
||||
from ..scheduler import SchedulerNode
|
||||
from ..utils import ceildiv, Placeholder
|
||||
from ..virtualized import V
|
||||
from .common import IndentedBuffer, Kernel
|
||||
from .triton import gen_common_triton_imports, TritonKernel
|
||||
from .triton_utils import config_of, signature_to_meta
|
||||
|
||||
|
||||
@dataclass
|
||||
class PartitionState:
|
||||
partitions: List[
|
||||
List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]]
|
||||
]
|
||||
cur_partition: List[
|
||||
Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]
|
||||
]
|
||||
cur_count: int
|
||||
|
||||
def finalize(self):
|
||||
if self.cur_partition:
|
||||
self.partitions.append(self.cur_partition)
|
||||
|
||||
|
||||
class ForeachKernel(Kernel):
|
||||
MAX_NUM_ARGS = 250 # number where I would no longer get triton errors
|
||||
|
||||
@staticmethod
|
||||
def _update_partition(partition_state, node_rw_count, node_info):
|
||||
if partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS:
|
||||
partition_state.partitions.append(partition_state.cur_partition)
|
||||
partition_state.cur_partition = [node_info]
|
||||
partition_state.cur_count = node_rw_count
|
||||
else:
|
||||
partition_state.cur_count += node_rw_count
|
||||
partition_state.cur_partition.append(node_info)
|
||||
|
||||
@staticmethod
|
||||
def horizontal_partition(subkernel_nodes, triton_scheduling):
|
||||
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
|
||||
for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
|
||||
(read/writes) and to have the same 2D or 1D blocking strategy."""
|
||||
assert len(subkernel_nodes) >= 1
|
||||
|
||||
partition_state_1d = PartitionState([], [], 0)
|
||||
yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict(
|
||||
lambda: PartitionState([], [], 0)
|
||||
)
|
||||
|
||||
for node in subkernel_nodes:
|
||||
fused_nodes = node.get_nodes()
|
||||
_, (numel, rnumel) = max(
|
||||
fused_nodes, key=lambda x: int(x.is_reduction())
|
||||
).group
|
||||
tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel)
|
||||
node_info = fused_nodes, tiled_groups, numel, rnumel
|
||||
|
||||
read_writes = node.read_writes
|
||||
read_write_count = len(read_writes.reads) + len(read_writes.writes)
|
||||
|
||||
if tiled_groups[1] == 1:
|
||||
ForeachKernel._update_partition(
|
||||
partition_state_1d, read_write_count, node_info
|
||||
)
|
||||
else:
|
||||
y_elem = tiled_groups[0]
|
||||
partition_state_2d = yelem_to_partition_state_2d[y_elem]
|
||||
ForeachKernel._update_partition(
|
||||
partition_state_2d, read_write_count, node_info
|
||||
)
|
||||
|
||||
partition_state_1d.finalize()
|
||||
all_partitions = partition_state_1d.partitions
|
||||
for partition_state_2d in yelem_to_partition_state_2d.values():
|
||||
partition_state_2d.finalize()
|
||||
all_partitions.extend(partition_state_2d.partitions)
|
||||
|
||||
return all_partitions
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.blocking_2d = False
|
||||
self.block_size_1d = 1024 # Try tuning this value
|
||||
self.block_size_2d = 32
|
||||
self.num_warps = 8
|
||||
self.sub_kernels = []
|
||||
self.iter_vars_count = itertools.count()
|
||||
self.x_block_count = 0
|
||||
self.y_block_count = 0
|
||||
|
||||
def get_block_size(self):
|
||||
if self.blocking_2d:
|
||||
return self.block_size_2d
|
||||
else:
|
||||
return self.block_size_1d
|
||||
|
||||
@staticmethod
|
||||
def codegen_pid_offsets(code, block_count, lower_bound, prefix):
|
||||
if block_count == 0:
|
||||
code.splice(f"{prefix}pid_offset = {prefix}pid")
|
||||
else:
|
||||
code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}")
|
||||
|
||||
def codegen_pid_range(self, code, x_elems):
|
||||
num_x_blocks = ceildiv(x_elems, self.get_block_size())
|
||||
upper_bound_x_pid = self.x_block_count + num_x_blocks
|
||||
lower_bound_x_pid = self.x_block_count
|
||||
|
||||
if self.x_block_count == 0:
|
||||
cond = "if"
|
||||
else:
|
||||
cond = "elif"
|
||||
|
||||
x_pid_bounds_check = (
|
||||
f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}"
|
||||
)
|
||||
code.splice(f"{cond} {x_pid_bounds_check}:")
|
||||
|
||||
with code.indent():
|
||||
ForeachKernel.codegen_pid_offsets(
|
||||
code, num_x_blocks, lower_bound_x_pid, "x"
|
||||
)
|
||||
self.x_block_count += num_x_blocks
|
||||
|
||||
def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):
|
||||
sub_kernel = TritonKernel(
|
||||
*groups,
|
||||
index_dtype=index_dtype,
|
||||
mutations=mutations,
|
||||
pid_cache={
|
||||
"tl.program_id(0)": "xpid_offset",
|
||||
"tl.program_id(1)": "ypid",
|
||||
},
|
||||
reduction_hint=reduction_hint,
|
||||
)
|
||||
if self.blocking_2d:
|
||||
assert len(groups) == 3
|
||||
|
||||
self.blocking_2d |= groups[1] != 1 and len(groups) == 3
|
||||
metrics.generated_kernel_count -= 1
|
||||
sub_kernel.args = self.args
|
||||
sub_kernel.iter_vars_count = self.iter_vars_count
|
||||
sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
|
||||
self.sub_kernels.append(sub_kernel)
|
||||
return sub_kernel
|
||||
|
||||
def jit_lines(self):
|
||||
can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
|
||||
size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
|
||||
_, _, signature = self.args.python_argdefs()
|
||||
triton_meta = {
|
||||
"signature": signature_to_meta(signature, size_dtype=size_dtype),
|
||||
"device": DeviceProperties.create(V.graph.scheduler.current_device),
|
||||
"constants": {},
|
||||
}
|
||||
triton_meta["configs"] = [config_of(signature)]
|
||||
inductor_meta = {
|
||||
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
|
||||
**TritonKernel.inductor_meta_common(),
|
||||
}
|
||||
return f"""
|
||||
@triton_heuristics.foreach(
|
||||
num_warps={self.num_warps},
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
@triton.jit
|
||||
"""
|
||||
|
||||
def grid(self):
|
||||
return (
|
||||
self.x_block_count,
|
||||
ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
|
||||
if self.blocking_2d
|
||||
else 1,
|
||||
1,
|
||||
)
|
||||
|
||||
def codegen_kernel(self, name=None):
|
||||
code = IndentedBuffer()
|
||||
|
||||
code.splice(gen_common_triton_imports())
|
||||
argdefs, _, _ = self.args.python_argdefs()
|
||||
code.splice(self.jit_lines())
|
||||
code.writeline(
|
||||
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
|
||||
)
|
||||
|
||||
with code.indent():
|
||||
code.splice("xpid = tl.program_id(0)")
|
||||
if self.blocking_2d:
|
||||
code.splice("ypid = tl.program_id(1)")
|
||||
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
|
||||
code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
|
||||
else:
|
||||
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
|
||||
|
||||
for sub_kernel in self.sub_kernels:
|
||||
assert len(sub_kernel.numels) <= 3
|
||||
# TODO mlazos: support dynamic shapes
|
||||
numel_ind = 0 if not self.blocking_2d else 1
|
||||
self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))
|
||||
with code.indent():
|
||||
if self.blocking_2d:
|
||||
code.splice(f"ynumel = {sub_kernel.numels[0]}")
|
||||
code.splice(f"xnumel = {sub_kernel.numels[1]}")
|
||||
else:
|
||||
code.splice(f"xnumel = {sub_kernel.numels[0]}")
|
||||
|
||||
sub_kernel.codegen_body()
|
||||
code.splice(sub_kernel.body)
|
||||
|
||||
code.splice("else:")
|
||||
with code.indent():
|
||||
code.splice("pass")
|
||||
|
||||
return code.getvalue()
|
||||
|
||||
def call_kernel(self, code, name: str):
|
||||
_, call_args, _ = self.args.python_argdefs()
|
||||
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
|
||||
for i in range(len(call_args)):
|
||||
if V.graph.is_unspec_arg(call_args[i]):
|
||||
call_args[i] = call_args[i] + ".item()"
|
||||
if V.graph.cpp_wrapper:
|
||||
V.graph.wrapper_code.generate_kernel_call(
|
||||
name,
|
||||
call_args,
|
||||
device_index=V.graph.scheduler.current_device.index,
|
||||
grid=self.grid(),
|
||||
)
|
||||
else:
|
||||
# TODO: refactor generate_kernel_call
|
||||
call_args_str = ", ".join(call_args)
|
||||
stream_name = code.write_get_raw_stream(
|
||||
V.graph.scheduler.current_device.index
|
||||
)
|
||||
code.writeline(
|
||||
f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})"
|
||||
)
|
||||
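For intuition about what `codegen_kernel` above produces: each sub-kernel's body is pasted under an `if`/`elif` branch on `xpid`, so a single Triton launch services every fused sub-kernel. The sketch below is only the rough shape of that generated source, held in a string; the kernel name, arguments, block size, numels, and the exact offset statement are invented for illustration and are not actual Inductor output.

illustrative_generated_source = '''
@triton_heuristics.foreach(num_warps=8, triton_meta={...}, inductor_meta={...})
@triton.jit
def triton_combo_kernel_0(in_ptr0, in_ptr1, out_ptr0, out_ptr1):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 4:
        xpid_offset = xpid - 0
        xnumel = 4096
        # ... body of sub-kernel 0 ...
    elif xpid >= 4 and xpid < 6:
        xpid_offset = xpid - 4
        xnumel = 2048
        # ... body of sub-kernel 1 ...
    else:
        pass
'''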
@@ -529,7 +529,7 @@ class WrapperCodeGen(CodeGen):
            """
                import triton
                import triton.language as tl
-                from {} import grid, split_scan_grid, start_graph, end_graph
+                from {} import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
                {}
            """.format(
                triton_heuristics.__name__,
@@ -1361,8 +1361,15 @@ class WrapperCodeGen(CodeGen):
            """
        )

-    def generate_default_grid(self, name: str, grid_args: List[Any]):
-        return grid_args
+    def generate_default_grid(
+        self,
+        name: str,
+        grid: List[Any],
+        cuda: bool = True,
+        grid_callable: Optional[Callable[..., Any]] = None,
+        **grid_extra_kwags,
+    ):
+        return grid

    def generate_kernel_call(
        self,
@@ -1375,6 +1382,7 @@ class WrapperCodeGen(CodeGen):
        arg_types=None,
        grid_fn: str = "grid",
        triton_meta=None,
        grid_extra_kwargs="",
    ):
        """
        Generates kernel call code.
@@ -1392,6 +1400,8 @@ class WrapperCodeGen(CodeGen):
            )
            if triton:
                grid_str = ", ".join(pexpr(item) for item in grid)
                if grid_extra_kwargs:
                    grid_str = f"{grid_str}, {grid_extra_kwargs}"
                grid_str = f"{grid_fn}({grid_str})"
                self.writeline(
                    f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"
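Together with the `grid_fn` and `grid_extra_kwargs` hooks above, the wrapper can now route a kernel launch through a custom grid function. The commented line below is only an illustration of what such a generated call site might look like; the kernel name, argument names, numels, and keyword values are invented for the example.

# triton_combo_kernel_0.run(
#     arg0_1, arg1_1, buf0, buf1,
#     grid=grid_combo_kernels(4096, 1, 1, num_kernels=2, min_blocks=8),
#     stream=stream0,
# )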
@@ -362,6 +362,12 @@ assert_indirect_indexing = True
# compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False

# enable the combo kernel that combines data-independent kernels (additional
# to foreach kernels) into a single one (Experimental)
combo_kernels = False
# benchmark combo kernels and only allow ones with perf gains
benchmark_combo_kernel = False

# constant folding on the joint graph
joint_graph_constant_folding = True
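Since these are plain module-level flags on `torch._inductor.config`, trying the feature out (with this patch applied and a CUDA device available) is a matter of flipping them before compiling. A minimal sketch; the function and tensor shapes here are only illustrative:

import torch
import torch._inductor.config as inductor_config

inductor_config.combo_kernels = True            # experimental opt-in
inductor_config.benchmark_combo_kernel = True   # only keep combos that benchmark faster

def independent_ops(a, b):
    # two data-independent pointwise ops that are candidates for one combo kernel
    return torch.sin(a), torch.sqrt(b)

compiled = torch.compile(independent_ops)
out = compiled(
    torch.rand(64, 64, device="cuda"),
    torch.rand(128, 32, device="cuda"),
)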
@@ -1741,3 +1741,15 @@ def split_scan_grid(xnumel, rnumel):
    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010

    return grid_fn


def grid_combo_kernels(*numels, num_kernels, min_blocks):
    """min_blocks is the minimal size of the grid x dimension"""
    kernel_grid_fn = grid(*numels)

    def grid_fn(meta):
        cuda_grid = list(kernel_grid_fn(meta))
        cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
        return tuple(cuda_grid)

    return grid_fn
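As a rough model of what `grid_combo_kernels` computes (not the library code path): a single 1D sub-kernel needs ceil(xnumel / XBLOCK) blocks along x, and the combined launch scales that by the number of fused sub-kernels with a floor of `min_blocks`. A self-contained sketch of that arithmetic, with the helper name and numbers chosen for illustration:

from math import ceil

def sketch_combo_grid_x(xnumel, xblock, num_kernels, min_blocks):
    per_kernel_blocks = ceil(xnumel / xblock)  # blocks one sub-kernel would need
    return max(num_kernels * per_kernel_blocks, min_blocks)

# three sub-kernels, each covering 1024 elements with XBLOCK=256 -> 12 blocks in x
assert sketch_combo_grid_x(1024, 256, num_kernels=3, min_blocks=1) == 12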
@@ -1025,8 +1025,10 @@ class FusedSchedulerNode(BaseSchedulerNode):


class ForeachKernelSchedulerNode(FusedSchedulerNode):
-    """Scheduler node which consists of a list of scheduler nodes that each operate on a
-    distinct tensor in a list of tensors."""
+    """
+    A scheduler node that consists of a set of scheduler nodes that have no data
+    dependencies among them and can be executed in parallel.
+    """

    def get_consumer_subnode_for(self, producer):
        if producer.get_name() in self.read_to_node:
@@ -1075,6 +1077,11 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
    @classmethod
    def fuse(cls, producer, consumer):
        assert producer.is_foreach() or consumer.is_foreach()
        use_custom_partition_algo = (
            producer.use_custom_partition_algo
            if producer.is_foreach()
            else consumer.use_custom_partition_algo
        )
        prev_node_1 = None
        prev_node_2 = None
        if producer.is_foreach() and consumer.is_foreach():
@@ -1108,13 +1115,24 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
                    fused_nodes.append(new_node)
                else:
                    fused_nodes.append(node)
        else:
            raise AssertionError(
                "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node"
            )

-        return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2)  # type: ignore[possibly-undefined]
+        return cls(
+            producer.scheduler,
+            fused_nodes,
+            use_custom_partition_algo=use_custom_partition_algo,
+            prev_node_1=prev_node_1,
+            prev_node_2=prev_node_2,
+        )

    def __init__(
        self,
        scheduler: "Scheduler",
-        nodes: List[SchedulerNode],
+        snodes: List[SchedulerNode],
+        use_custom_partition_algo: bool,
        prev_node_1=None,
        prev_node_2=None,
    ):
@@ -1122,9 +1140,9 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
        self.name_to_node = {}

        if prev_node_1 is None or prev_node_2 is None:
-            super().__init__(scheduler, nodes)
+            super().__init__(scheduler, snodes)

-            for node in nodes:
+            for node in snodes:
                for read in node.read_writes.reads:
                    self.read_to_node[read.name] = node

@@ -1132,7 +1150,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
                self.name_to_node[name] = node
        else:
            self.scheduler = scheduler
-            self.snodes = nodes
+            self.snodes = snodes
            self.node: ir.Buffer = None  # type: ignore[assignment]
            self.users: List[NodeUser] = []

@@ -1163,17 +1181,43 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
            for name in other_node.get_names():
                self.name_to_node[name] = other_node

-        self.group = (nodes[0].get_device(), "foreach")
+        self.use_custom_partition_algo = use_custom_partition_algo
+        self.group = (snodes[0].get_device(), "combo_kernel")
        self.origins: Set[torch.fx.Node] = set()

    @classmethod
    def combinable_nodes(cls, nodes: List[SchedulerNode]) -> List[SchedulerNode]:
        extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
        if extern:
            log.debug(
                "ComboKernels: %d external nodes are filtered %s",
                len(extern),
                [node.node.origins for node in extern],
            )
        filtered_nodes = [
            x
            for x in nodes
            if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode))
        ]
        foreach_nodes = [
            x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode)
        ]
        if foreach_nodes:
            log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes))
            filtered_nodes = [
                x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode)
            ]
        template_nodes = [x for x in filtered_nodes if x.is_template()]
        if template_nodes:
            log.debug(
                "ComboKernels: %d template nodes are filtered", len(template_nodes)
            )
            filtered_nodes = [x for x in filtered_nodes if x not in template_nodes]
        return filtered_nodes

    def mark_run(self):
        raise NotImplementedError

    def codegen(self):
        assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
        self.node.get_store_function()(self.node.make_loader()())

    def can_free(self):
        return NotImplementedError

@@ -1181,7 +1225,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
        return True

    def get_subkernel_nodes(self):
-        """Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists.
+        """Returns a list of nodes which comprise the combo kernel.
        These nodes may be vertically fused."""
        return list(self.snodes)

@@ -1339,6 +1383,8 @@ class Scheduler:
        # Refresh node_users and inverse_users to reflect fused nodes
        self.compute_node_users()
        self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
        if config.combo_kernels:
            self.create_combo_kernel_nodes(num_ck_nodes=None)
        self.compute_last_usage()
        V.debug.ir_post_fusion(self.nodes)
        V.debug.graph_diagram(self.nodes)
@@ -1405,7 +1451,7 @@ class Scheduler:
            removed_node_names.update(names)
            snodes = [self.name_to_node[name] for name in names]

-            fe_node = ForeachKernelSchedulerNode(self, snodes)  # type: ignore[arg-type]
+            fe_node = ForeachKernelSchedulerNode(self, snodes, use_custom_partition_algo=False)  # type: ignore[arg-type]

            fe_nodes.append(fe_node)

@@ -1694,6 +1740,51 @@ class Scheduler:
            visit(node)
        self.nodes = result

    def _get_unmet_dep_nodes(self, snode):
        unmet_deps = set()
        if isinstance(
            snode,
            (
                SchedulerNode,
                ExternKernelSchedulerNode,
                NopKernelSchedulerNode,
                FusedSchedulerNode,
            ),
        ):
            for dep in snode.unmet_dependencies:
                unmet_deps.add(dep.name)
        else:
            raise RuntimeError(
                f"get_unmet_dep_nodes is not implemented for {type(snode)}."
            )
        return list({self.name_to_fused_node[n] for n in unmet_deps})

    def _topological_sort_nodes(self):
        """
        Sort nodes by their topological order, return a list of node lists.
        """
        order = []
        nodes = {n: 0 for n in self.nodes}
        children: Dict[Any, Any] = {}
        for node in self.nodes:
            deps = self._get_unmet_dep_nodes(node)
            nodes[node] = len(deps)
            for dep in deps:
                c = children.get(dep, [])
                c.append(node)
                children[dep] = c

        zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
        while zero_deg_nodes:
            order.append(zero_deg_nodes)
            for n in zero_deg_nodes:
                for user in children.get(n, []):
                    nodes[user] -= 1
                nodes.pop(n)
            zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
        assert not nodes, "Topological sort failed!"
        return order

    def compute_ancestors(self):
        """
        Populate each node.ancestors
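`_topological_sort_nodes` above is a level-order Kahn's algorithm: every node in generation k depends only on nodes in earlier generations, which is what lets each generation be considered for a combo kernel. A self-contained sketch of the same idea on a toy dependency map (the helper name and the toy graph are illustrative, not scheduler APIs):

def levels_by_dependencies(deps):
    # deps maps node -> set of nodes it depends on
    indegree = {n: len(d) for n, d in deps.items()}
    children = {}
    for n, d in deps.items():
        for parent in d:
            children.setdefault(parent, []).append(n)

    order = []
    ready = [n for n, k in indegree.items() if k == 0]
    while ready:
        order.append(ready)
        next_ready = []
        for n in ready:
            for child in children.get(n, []):
                indegree[child] -= 1
                if indegree[child] == 0:
                    next_ready.append(child)
            indegree.pop(n)
        ready = next_ready
    assert not indegree, "cycle detected"
    return order

# "a" and "b" are independent, "c" needs both -> [["a", "b"], ["c"]]
assert levels_by_dependencies({"a": set(), "b": set(), "c": {"a", "b"}}) == [["a", "b"], ["c"]]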
@@ -1975,6 +2066,65 @@ class Scheduler:
        self.topological_sort_schedule()
        self.prune_redundant_deps()

    def create_combo_kernel_nodes(self, num_ck_nodes=None):
        """
        Group data-independent (parallel) nodes into combo kernel nodes.
        """
        fused_nodes = set(self.nodes)
        count = 0
        num_nodes_orig = len(self.nodes)
        log.debug("ComboKernels: Generating with num_ck_nodes = %s...", num_ck_nodes)
        for num, node_list in enumerate(self._group_nodes_for_combo_kernels()):
            node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list)
            if len(node_list) < 2:
                continue
            if num_ck_nodes is not None and count > num_ck_nodes:
                break
            if not self.speedup_by_combo_kernel(node_list):
                log.debug("ComboKernels: Not speeding up %d-th group", num)
                continue
            count += 1
            group_snode = ForeachKernelSchedulerNode(
                node_list[0].scheduler, node_list, use_custom_partition_algo=True
            )
            log.info(
                "ComboKernels: Combining %d nodes for %d-th group",
                len(node_list),
                num,
            )
            for node in node_list:
                fused_nodes.remove(node)
            fused_nodes.add(group_snode)
            self.name_to_fused_node.update(
                {n.get_name(): group_snode for n in group_snode.get_nodes()}
            )
        self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
        self.topological_sort_schedule()
        log.info(
            "Generated ComboKernel nodes: %d ComboKernels, %d -> %d nodes in total",
            count,
            num_nodes_orig,
            len(self.nodes),
        )
        self.prune_redundant_deps()

    def _group_nodes_for_combo_kernels(self):
        """
        Returns a list of lists of nodes that are to be grouped together.
        """
        sorted_nodes = self._topological_sort_nodes()
        grouped_nodes = []
        max_num_nodes = 8
        for nodes in sorted_nodes:
            grouped_nodes.extend(
                [
                    nodes[i : i + max_num_nodes]
                    for i in range(0, len(nodes), max_num_nodes)
                ]
            )

        return grouped_nodes

    def prune_redundant_deps(self):
        for node in self.nodes:
            node.prune_redundant_deps(self.name_to_fused_node)
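`_group_nodes_for_combo_kernels` simply slices each topological level into candidate groups of at most `max_num_nodes` (eight here). A small sketch of that chunking, with an illustrative helper name:

def chunk_candidates(level, max_num_nodes=8):
    # split one topological level into combo-kernel candidate groups
    return [level[i : i + max_num_nodes] for i in range(0, len(level), max_num_nodes)]

assert chunk_candidates(list(range(10)), max_num_nodes=4) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]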
@@ -2583,7 +2733,7 @@ class Scheduler:
            elif node.is_extern():
                self.codegen_extern_call(node)
            elif node.is_foreach():
-                self.get_backend(device).codegen_foreach(node)  # type: ignore[possibly-undefined]
+                self.get_backend(device).codegen_combo_kernel(node)  # type: ignore[possibly-undefined]
            elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
                self.get_backend(device).codegen_node(node)  # type: ignore[possibly-undefined]
            else:
|
||||
node = self.name_to_node[buf_name]
|
||||
return node.node.get_layout()
|
||||
|
||||
def benchmark_combo_kernel(self, node_list):
|
||||
"""
|
||||
Benchmark fused list of nodes and return the execution time
|
||||
in milliseconds on randomly generated inputs.
|
||||
"""
|
||||
device = node_list[0].get_device()
|
||||
V.graph.scheduler = self
|
||||
self.current_device = device
|
||||
backend = self.get_backend(device)
|
||||
return backend.benchmark_combo_kernel(node_list)
|
||||
|
||||
def speedup_by_combo_kernel(self, node_list):
|
||||
"""
|
||||
If config.benchmark_fusion is False, always return True.
|
||||
Otherwise, return True if fusion can brings speedup.
|
||||
"""
|
||||
if not config.benchmark_combo_kernel:
|
||||
return True
|
||||
|
||||
subkernel_nodes = node_list
|
||||
device = subkernel_nodes[0].get_device()
|
||||
|
||||
# don't support benchmark fusion for CPU right now.
|
||||
if device.type == "cpu":
|
||||
return True
|
||||
|
||||
from triton.compiler.errors import CompilationError
|
||||
|
||||
ms1, path1_list = 0.0, []
|
||||
for i, snode in enumerate(subkernel_nodes):
|
||||
node_list = snode.get_nodes()
|
||||
# We can not accurately benchmark kernel using atomic_add
|
||||
# due to how we generate random integer inputs.
|
||||
# Skip benchmarking them by allowing fusion.
|
||||
if any(
|
||||
hasattr(n.node, "data")
|
||||
and hasattr(n.node.data, "scatter_mode")
|
||||
and n.node.data.scatter_mode == "atomic_add"
|
||||
for n in node_list
|
||||
):
|
||||
fusion_log.debug(
|
||||
"ComboKernel: benchmarking may not accurate due to atomic_add"
|
||||
)
|
||||
|
||||
try:
|
||||
ms, path = self.benchmark_fused_nodes(node_list)
|
||||
if math.isinf(ms):
|
||||
fusion_log.debug(
|
||||
"ComboKernel benchmark: register spilling of %d-th subkernel",
|
||||
i,
|
||||
)
|
||||
return False
|
||||
except CompilationError as e:
|
||||
# workaround triton issue: https://github.com/openai/triton/issues/2151
|
||||
if "Loop-carried variable" in str(e):
|
||||
fusion_log.debug(
|
||||
"ComboKernel benchmark: return True because of loop-carried variable"
|
||||
)
|
||||
return True # allow fusion
|
||||
else:
|
||||
raise
|
||||
ms1 += ms
|
||||
path1_list.append(path)
|
||||
|
||||
try:
|
||||
ms2, ms2_clone, path2_list = self.benchmark_combo_kernel(subkernel_nodes)
|
||||
except CompilationError as e:
|
||||
# workaround triton issue: https://github.com/openai/triton/issues/2151
|
||||
if "Loop-carried variable" in str(e):
|
||||
fusion_log.debug(
|
||||
"ComboKernel benchmark: return True because of loop-carried variable"
|
||||
)
|
||||
return True # allow fusion
|
||||
else:
|
||||
raise
|
||||
|
||||
# small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking.
|
||||
small_kernel = ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2
|
||||
if fusion_log.isEnabledFor(logging.DEBUG):
|
||||
if ms1 > ms2 or small_kernel:
|
||||
fusion_log.debug(
|
||||
"can fuse (benchmark): fusing causes %sx speedup",
|
||||
green_text(f"{ms1 / ms2:.3f}"),
|
||||
)
|
||||
else:
|
||||
fusion_log.debug(
|
||||
"cannot fuse (benchmark): fusing causes %sx slowdown",
|
||||
red_text(f"{ms1 / ms2:.3f}"),
|
||||
)
|
||||
|
||||
return ms2 < ms1 or small_kernel
|
||||
|
||||
|
||||
class BaseScheduling:
|
||||
def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
|
||||
|
||||
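To make the `small_kernel` guard in `speedup_by_combo_kernel` above concrete, here is a worked example with invented timings: when most of the measured combo time is really input-cloning overhead, the comparison against the separate kernels is considered unreliable and fusion is allowed anyway.

# Illustrative numbers only (milliseconds): ms1 is the summed time of the separate
# kernels, ms2 the measured combo kernel, ms2_clone the cloning overhead inside it.
ms1, ms2, ms2_clone = 0.040, 0.045, 0.030

small_kernel = ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2
# 0.030 / 0.045 ≈ 0.67 > 0.6 and the gap 0.015 ms < 0.2 ms, so even though
# ms2 > ms1 the combo kernel is still accepted.
assert small_kernel and (ms2 < ms1 or small_kernel)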