[Tiling rewrite pt1] Normalize reads and writes to common iter space (#153723)

In order to take the globally best tiling, we need to normalize all the node read and writes to a common iteration space. This first pr finds a common split among nodes in a fused scheduler node, and then normalizes reads and writes to the common split.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153723
Approved by: https://github.com/jansel
This commit is contained in:
eellison
2025-06-02 16:44:56 -07:00
committed by PyTorch MergeBot
parent 635b73e697
commit 00dfd3891e
8 changed files with 659 additions and 13 deletions

View File

@ -957,6 +957,7 @@ exclusions = {
"cudagraph_static_inputs",
"benchmarking",
"loop_ordering",
"loop_tiling",
"autotuning",
"graph_region_expansion",
}

View File

@ -18,11 +18,16 @@ from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import ModularIndexing
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
# set so that metrics appear
@ -41,9 +46,11 @@ class MockScheduler:
def get_backend(cls, *args):
return TritonScheduling(cls)
def can_buffer_be_removed_through_fusion(self, *args, **kwargs):
return False
@inductor_config.patch(loop_ordering_after_fusion=True)
class ImplDetailTest(TestCase):
class MockSchedulerTest(TestCase):
_exit_stack = None
@classmethod
@ -61,6 +68,9 @@ class ImplDetailTest(TestCase):
super().tearDownClass()
cls._exit_stack.close()
@inductor_config.patch(loop_ordering_after_fusion=True)
class ImplDetailTest(MockSchedulerTest):
@staticmethod
def _get_snode_body_sym_prefix(snode):
body = snode._body
@ -509,6 +519,231 @@ class LoopOrderingTest(TestCase):
print(f"{ms=:.3f}")
@inductor_config.patch(
{
"benchmark_kernel": True,
"loop_ordering_after_fusion": True,
"triton.unique_kernel_names": True,
}
)
@instantiate_parametrized_tests
class MemoryCoalescingTest(MockSchedulerTest):
"""Tests for memory coalescing analysis with specific tensor sizes."""
device = GPU_TYPE
_exit_stack = None
def setUp(self):
super().setUp()
metrics.reset()
def _create_buffer(self, name, sizes):
"""Create a buffer with specified sizes"""
strides = ir.FlexibleLayout.contiguous_strides(sizes)
box = ir.TensorBox.create(
ir.Buffer(
name=name,
layout=ir.FixedLayout(
torch.device(self.device),
dtype=torch.float32,
size=sizes,
stride=strides,
),
)
)
box_loader = box.make_loader()
def inner_fn(index):
return box_loader(index) * 2
buf = ir.Pointwise.create(
device=box.get_device(),
dtype=box.get_dtype(),
inner_fn=inner_fn,
ranges=box.get_size(),
)
buf.realize()
computed_buf = buf.data.data
computed_buf.decide_layout()
return computed_buf
def _create_scheduler_node(self, buf):
s = SchedulerNode(V.graph.scheduler, buf)
s.min_order = 0
s.max_order = 100
return s
@parametrize(
"inps",
(
((128, 384, 196), (768, 64, 196), (128, 6, 64, 196)),
((64,), (16, 4), (16, 4)),
((5, 6), (3, 10), (30,)),
((5, 6, 20), (3, 10, 20), (30, 20)),
),
)
def test_inferred_splits(self, inps):
"""
Test memory coalescing analysis with the specified tensor sizes.
Using direct SchedulerNode creation with sizes (128, 384, 196) and (768, 64, 196).
"""
s1, s2, expected_size = inps
# Create buffers with the specified sizes
buf1 = self._create_buffer("buffer1", s1)
buf2 = self._create_buffer("buffer2", s2)
# Create scheduler nodes
snode1 = self._create_scheduler_node(buf1)
snode2 = self._create_scheduler_node(buf2)
# Create a fused node
fused_node = torch._inductor.scheduler.FusedSchedulerNode.fuse(snode1, snode2)
from torch._inductor import tiling_utils
fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(fused_node)
var_ranges = fused_norm_read_writes.var_ranges
self.assertEqual(list(var_ranges.values()), list(expected_size))
def test_remapped_reads(self):
from torch._inductor import tiling_utils
def fn(nodes):
assert len(nodes) == 1
fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(
nodes[0]
)
self.assertTrue(len(fused_norm_read_writes.var_ranges) == 2)
# both reads remapped correctly
FileCheck().check("4*n0 + n1").run(
repr(fused_norm_read_writes.reads.keys())
)
FileCheck().check("n0 + 4*n1").run(
repr(fused_norm_read_writes.reads.keys())
)
return nodes
with torch._inductor.config.patch(_post_fusion_custom_pass=fn):
@torch.compile()
def foo(x, y):
return x + y
foo(torch.rand([4, 4], device="cuda"), torch.rand([4, 4], device="cuda").T)
def test_remapped_reads_split(self):
from torch._inductor import tiling_utils
def fn(nodes):
self.assertTrue(len(nodes) == 1)
fused_norm_read_writes = tiling_utils.extract_normalized_read_writes(
nodes[0]
)
inp_node_reads = nodes[0].get_nodes()[1]._body.get_read_exprs()
node_ranges = nodes[0].get_nodes()[1]._body.var_ranges
self.assertTrue(len(node_ranges) == 1)
self.assertTrue(next(iter(node_ranges.values())) == 36)
var = next(iter(node_ranges.keys()))
r = FloorDiv(var, 6) + 6 * ModularIndexing(var, 1, 6)
self.assertTrue(r in inp_node_reads)
# mapped reads
self.assertTrue(list(fused_norm_read_writes.var_ranges.values()) == [6, 6])
n0, n1 = list(fused_norm_read_writes.var_ranges.keys())
# translation of above is n0 + 6 * n1
self.assertTrue((n0 + 6 * n1) in fused_norm_read_writes.reads.keys())
return nodes
with torch._inductor.config.patch(_post_fusion_custom_pass=fn):
@torch.compile()
def foo(x, y):
return (
x + y
).contiguous().flatten() + torch.ops._inductor_test.realize(
(y.T + 1).flatten()
)
foo(torch.rand([6, 6], device="cuda"), torch.rand([6, 6], device="cuda").T)
def test_reduction_pointwise(self):
# test one pw var, one red var
from torch._inductor import tiling_utils
def fn(nodes):
self.assertTrue(len(nodes) == 1)
fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0])
i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars
self.assertTrue(len(i_vars) == 1)
self.assertTrue(len(r_vars) == 1)
# single write to index var
self.assertTrue(
fused_rw.index_vars[0] == next(iter(fused_rw.writes.keys()))
)
# the write to the fused intermediary node should be removed
self.assertTrue(len(fused_rw.writes) == 1)
# single read
self.assertTrue(len(fused_rw.reads) == 1)
# that is applied to two bufs
self.assertTrue(len(next(iter(fused_rw.reads.values()))) == 2)
# and the read should be in terms of the index + reduce var,
# even though node is pointwise
self.assertTrue(256 * i_vars[0] + r_vars[0] in fused_rw.reads)
return nodes
with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():
@torch.compile()
def foo(x, y):
out = torch.ops._inductor_test.realize(x + y)
return out.sum(dim=1)
foo(
torch.rand(256, 256, device="cuda"), torch.rand(256, 256, device="cuda")
)
def test_reduction_no_pointwise(self):
# test one pw var, one red var
from torch._inductor import tiling_utils
def fn(nodes):
self.assertTrue(len(nodes) == 1)
fused_rw = tiling_utils.extract_normalized_read_writes(nodes[0])
i_vars, r_vars = fused_rw.index_vars, fused_rw.reduce_vars
self.assertTrue(len(i_vars) == 0)
self.assertTrue(len(r_vars) == 1)
return nodes
with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():
@torch.compile()
def foo(x):
return x.sum()
foo(torch.rand(1024, device="cuda"))
if __name__ == "__main__":
if HAS_GPU:
run_tests()

View File

@ -705,6 +705,26 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
return new_ranges, return_getters_groups
@classmethod
def prepare_split_iteration_lengths(
cls,
groups: Iterable[sympy.Expr],
lengths: Sequence[Sequence[sympy.Expr]],
reduction_numel: sympy.Expr = sympy.S.One,
) -> Sequence[Sequence[sympy.Expr]]:
"Fill in the reduction numel of lengths if missing"
sizevars = V.graph.sizevars
if len(lengths[1]) == 0 and (
not sizevars.statically_known_equals(reduction_numel, sympy.S.One)
and sizevars.statically_known_equals(
sympy_product(groups),
sympy_product(lengths[0]) * reduction_numel,
)
):
return (lengths[0], [reduction_numel])
return lengths
@classmethod
def is_compatible(
cls,
@ -712,15 +732,7 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
lengths: Sequence[Sequence[sympy.Expr]],
reduction_numel: sympy.Expr = sympy.S.One,
) -> bool:
# Fill in the reduction numel, in case the node is missing it.
sizevars = V.graph.sizevars
if len(lengths[1]) == 0 and (
sizevars.statically_known_equals(
sympy_product(groups),
sympy_product(lengths[0]) * reduction_numel,
)
):
lengths = (lengths[0], [reduction_numel])
lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel)
try:
cls._split_iteration_ranges(groups, lengths)

View File

@ -271,6 +271,16 @@ _pre_fusion_custom_pass: Optional[
]
] = None
# Registers a custom pass to be run right after fusion in Inductor scheduler.
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
_post_fusion_custom_pass: Optional[
Callable[
[list["torch._inductor.scheduler.BaseSchedulerNode"]],
list["torch._inductor.scheduler.BaseSchedulerNode"],
]
] = None
# Deprecated
split_cat_fx_passes = True

View File

@ -312,6 +312,14 @@ class LoopBody:
for entry in self.memory_usage[MemoryUsageType.LOAD]
]
def get_all_read_expr(self, buffer_name):
# reversed to match old behavior
out = []
for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]):
if entry.buffer_name == buffer_name:
out.append(self.indexing_exprs[entry.index_name])
return out
def get_write_exprs(self):
return [
self.indexing_exprs[entry.index_name]
@ -321,6 +329,16 @@ class LoopBody:
)
]
def get_all_write_expr(self, buffer_name):
out = []
for entry in itertools.chain(
self.memory_usage[MemoryUsageType.STORE],
self.memory_usage[MemoryUsageType.STORE_REDUCTION],
):
if entry.buffer_name == buffer_name:
out.append(self.indexing_exprs[entry.index_name])
return out
def debug_str(self):
lines = [f"var_ranges = {dict(self.var_ranges)}"]
lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])

View File

@ -2101,6 +2101,8 @@ class Scheduler:
if config._pre_fusion_custom_pass is not None:
self.nodes = config._pre_fusion_custom_pass(self.nodes)
self.nodes = self.fuse_nodes(self.nodes)
if config._post_fusion_custom_pass is not None:
self.nodes = config._post_fusion_custom_pass(self.nodes)
self.merge_loops()
self.finalize_multi_template_buffers()
if config.combo_kernels:

View File

@ -0,0 +1,362 @@
import dataclasses
import functools
import itertools
import sys
from collections import defaultdict
from collections.abc import Iterable, Iterator
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
import sympy
import torch
from torch._inductor.dependencies import index_vars_no_squeeze
from torch._inductor.utils import sympy_product, sympy_subs
from torch.utils._ordered_set import OrderedSet
from .virtualized import V
T = TypeVar("T")
U = TypeVar("U")
Split = tuple[sympy.Expr, ...]
loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling")
if TYPE_CHECKING:
from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode
@dataclasses.dataclass(frozen=True)
class FusedNormalizedReadsWrites:
"""
Normalized reads and writes for nodes in the same FusedSchedulerNode.
"""
index_vars: OrderedSet[sympy.Symbol]
reduce_vars: OrderedSet[sympy.Symbol]
reads: dict[sympy.Expr, OrderedSet[str]]
writes: dict[sympy.Expr, OrderedSet[str]]
var_ranges: dict[sympy.Symbol, int]
def get_pw_red_splits(
n: "SchedulerNode", pointwise_numel: sympy.Expr, red_numel: sympy.Expr
) -> tuple[tuple[list[sympy.Symbol], list[int]], tuple[list[sympy.Symbol], list[int]]]:
if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel:
return (
(n._body.iter_vars, n._body.sizes[0]),
(n._body.reduce_vars, n._body.sizes[1]),
) # type: ignore[return-value]
assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel # type: ignore[operator]
i = len(n._body.sizes[0]) - 1
prod = 1
while i >= 0:
prod *= n._body.sizes[0][i]
if prod == red_numel:
break
if i >= 0:
pw_splits = n._body.sizes[0][0:i]
iter_vars = n._body.iter_vars[0:i]
red_splits = n._body.sizes[0][i:]
red_vars = n._body.iter_vars[i:]
return (iter_vars, pw_splits), (red_vars, red_splits) # type: ignore[return-value]
# TODO - handle, not sure if possible
raise RuntimeError(
f"Unhandled node: size: {n._body.sizes}, pw: {pointwise_numel}, red: {red_numel}"
)
class NodeSplitGetter:
"""
Finds a Pointwise, Reduction Split that compatible with all nodes in a SchedulerNode.
"""
def __init__(
self,
node: Union["FusedSchedulerNode", "SchedulerNode"],
):
self.node = node
self.pointwise_numel: sympy.Expr = node.group[1][0]
self.red_numel: sympy.Expr = node.group[1][1]
self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet)
self.reduction_split: Split = ()
self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet()
fused_group = node.group[1]
for n in reversed(node.get_nodes()):
if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
continue
(_, n_pw_splits), (_, n_red_splits) = get_pw_red_splits(
n, self.pointwise_numel, self.red_numel
)
# fill in reduction size
n_pw_splits, n_red_splits = (
torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
fused_group, (n_pw_splits, n_red_splits), self.red_numel
)
)
self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits))
# initially, we are just going to do a single reduction split since
# reduction tiling is off by default. even if we miss a reduction split,
# we can recover it in the split var analysis.
# TODO: an earlier version fo this code tried to iteratively try the maximum number
# of split vars, by iterating over both pointwise and reduction. but not worth
# the complexity yet.
if n_red_splits != ():
self.reduction_split = (sympy_product(n_red_splits),)
n_size = (tuple(n_pw_splits), tuple(n_red_splits))
self.all_node_sizes.add(n_size)
self.seen_pw_splits: OrderedSet[Split] = OrderedSet()
def get_node_splits(self) -> tuple[Split, Split]:
"""
Get a compatible pointwise, reduction split of the node
"""
if len(self.all_node_sizes) == 1:
return next(iter(self.all_node_sizes))
max_pw_split = max(self.pw_split_options.keys())
for pw_split_len in range(max_pw_split, 0, -1):
for pw_split in self.pw_split_options[pw_split_len]:
if out := self.try_split(pw_split, self.reduction_split):
return out
# combine dims for next round
for pw_split in self.pw_split_options[pw_split_len]:
for i in range(len(pw_split) - 1):
new_split = tuple(
pw_split[0:i]
+ (sympy_product(pw_split[i : i + 2]),)
+ pw_split[i + 2 :]
)
self.pw_split_options[len(new_split)].add(new_split)
# if for whatever reason we couldnt split above, return default split
return ((self.pointwise_numel,), (self.red_numel,))
def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]:
"""
See if this split is compatible, and potentially returning a longer split
than the input.
"""
from torch._inductor.codegen.simd import CantSplit, SIMDKernel
if pw in self.seen_pw_splits:
return None
self.seen_pw_splits.add(pw)
for n_pw, n_red in self.all_node_sizes:
try:
groups = pw + red
lengths = (n_pw, n_red)
splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths)
except CantSplit:
return None
assert len(getters) == 2
pw_group_splits = splits[: len(pw)]
# if we had to divide a variable into two to do this split,
# then lets try the larger, induced split.
# e.g. splitting (12, 2) into (2, 12) will split the first var into:
# (2, 6) and produce an overall split of (2, 6, 2)
flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits))
if flattened_pw_splits != pw:
if out := self.try_split(flattened_pw_splits, red):
return out
return pw, red
if sys.version_info >= (3, 10):
# On Python 3.10+ we can use zip(strict=True)
zip_equal = functools.partial(zip, strict=True)
else:
# Fallback for older versions
def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]:
"""
Zip two iterables, raising ValueError if their lengths differ.
"""
if len(it1) != len(it2):
raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}")
return zip(it1, it2)
def apply_var_mapping(
iter_vars: list[sympy.Symbol],
red_vars: list[sympy.Symbol],
norm_pw_vars: list[sympy.Symbol],
norm_red_vars: list[sympy.Symbol],
new_ranges: list[list[sympy.Expr]],
return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]],
) -> dict[sympy.Symbol, sympy.Expr]:
"""Maps original variables to expressions using normalized variables."""
# the output of split_iteration_range is a new_ranges, return_getters_groups
# new_ranges is a flattened list of ranges corresponding to the new pw and red vars
# for example, taking in pw vars of range (6, 6) to normalized range [36],
# new_ranges would be [[6, 6]]
# There is a return_getter callable for each input iter_var and red_vars.
# if you flatten out all of the ranges, and create a variable for each index,
# then applying the flattening vars to the callables in return_getters_groups
# gives you the mapping from input vars -> flattened vars.
# From there, we can compute the output, normalized variables.
# For instance [6, 6] corresponding to flat vars v0, v1 will be
# v0 + 6 * v1
# Create flattened iteration variables
num_vars = sum(len(s) for s in new_ranges)
flat_vars = sympy.symbols(f"v_0:{num_vars}")
count = 0
if len(iter_vars) == 0 and len(red_vars) == 0:
return {}
assert len(new_ranges) == len(norm_pw_vars + norm_red_vars)
apply_groups = []
for group in return_getters_groups:
apply_groups.append([g(flat_vars) for g in group])
iter_vars_to_flat_vars = {}
for i, (group, var_group) in enumerate(
zip_equal(apply_groups, ((iter_vars, red_vars)))
):
# if the node has sizes (p0, 1) and the fused node is (p0, r0)
# the reduction var gets filled in for split_iteration_range
if len(group) != len(var_group):
assert i == 1
assert len(var_group) == 0
continue
iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)})
count = 0
flat_vars_to_new_vars = {}
for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars):
range_vars = []
for i in range(len(new_range)):
range_vars.append(flat_vars[count])
count += 1
prod = 1
for i in range(len(new_range) - 1, -1, -1):
flat_vars_to_new_vars[range_vars[i]] = new_var * prod
prod = new_range[i] * prod
return {
k: sympy_subs(v, flat_vars_to_new_vars)
for k, v in iter_vars_to_flat_vars.items()
}
def extract_normalized_read_writes(
node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> FusedNormalizedReadsWrites:
"""Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node."""
reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
all_output_names = node.get_buffer_names()
op_names = node.get_operation_names()
outputs = OrderedSet(
buf
for buf in all_output_names
if not V.graph.scheduler.can_buffer_be_removed_through_fusion(buf, op_names)
)
inputs = OrderedSet(dep.name for dep in node.read_writes.reads)
pw_splits, red_splits = NodeSplitGetter(node).get_node_splits()
# lets use different prefix (`n`) to distinguish
(norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze(
pw_splits, red_splits, prefix="n"
)
node = node
pointwise_numel: sympy.Expr = node.group[1][0]
red_numel: sympy.Expr = node.group[1][1]
for n in list(node.get_nodes()):
if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
continue
body = n._body
n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
for inp in inputs:
for expr in body.get_all_read_expr(inp):
n_reads[expr].add(inp)
for out in outputs:
for expr in body.get_all_write_expr(out):
n_writes[expr].add(out)
if not n_reads and not n_writes:
continue
(iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits(
n, pointwise_numel, red_numel
)
groups = pw_splits + red_splits
lengths = (n_pw_splits, (n_red_splits))
lengths = (
torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
groups, lengths, red_numel
)
)
new_ranges, return_getters_groups = (
torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges(
groups, lengths
)
)
var_map = apply_var_mapping(
iter_vars,
red_vars,
norm_pw_vars,
norm_red_vars,
new_ranges,
return_getters_groups,
)
n_reads_new = {sympy_subs(read, var_map): v for read, v in n_reads.items()}
n_writes_new = {sympy_subs(write, var_map): v for write, v in n_writes.items()}
for expr, buf_names in n_reads_new.items():
reads[expr] |= buf_names
for expr, buf_names in n_writes_new.items():
writes[expr] |= buf_names
reads = {
V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items()
}
writes = {
V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items()
}
fused_out = FusedNormalizedReadsWrites(
norm_pw_vars, # type: ignore[arg-type]
norm_red_vars, # type: ignore[arg-type]
reads,
writes,
ranges,
)
loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
return fused_out

View File

@ -186,6 +186,12 @@ register_artifact(
"Logs related to loop ordering",
off_by_default=True,
)
register_artifact(
"loop_tiling",
"Logs related to loop ordering",
off_by_default=True,
)
register_artifact(
"overlap",
"Detailed Inductor compute/comm overlap decisions",