Turn on new tiling by default (#154768)

Turning this on in fbcode will come in a follow-up. This PR also updates `max_tiles` to default to `None`. The existing tiling logic does not handle `max_tiles=3` well, but the new tiling logic does, so unless `max_tiles` has been explicitly set, we default to 3 in the new logic and 2 elsewhere.
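
To make the resolution rule concrete, here is a minimal sketch mirroring the `get_max_tiles` helper added in this PR (the helper itself appears in the diff below; the standalone setup here is illustrative):

```python
import torch._inductor.config as inductor_config

def get_max_tiles(default: int = 2) -> int:
    # None means "unset": the new coalescing-aware tiling passes default=3,
    # the legacy tiling paths pass default=2.
    max_tiles = inductor_config.triton.max_tiles
    return max_tiles if max_tiles is not None else default

inductor_config.triton.max_tiles = None  # the new default
assert get_max_tiles(default=3) == 3     # new tiling logic
assert get_max_tiles(default=2) == 2     # everywhere else

inductor_config.triton.max_tiles = 3     # an explicit setting wins everywhere
assert get_max_tiles(default=2) == 3
```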

TorchBench (TB) runners have been very unstable recently (do we need to bump the batch size?), but in a [recent torchbench](https://hud.pytorch.org/benchmark/torchbench/inductor_with_cudagraphs?dashboard=torchinductor&startTime=Tue,%2027%20May%202025%2015:38:26%20GMT&stopTime=Tue,%2003%20Jun%202025%2015:38:26%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/eellison/803/head&lCommit=8480c220db4eb3c9e2b58d85a698d0a7113a6e37&rBranch=main&rCommit=0cd18ba1ca35d87916723d445c06664615dcae12) inference run, for example, we had 15 models with lower execution time (i.e. green) and 2 models with higher (i.e. red).

I am doing another run and will update here.

Dynamic shapes are not yet turned on because the splitting logic still needs a number of fixes. For example, `FloorDiv` of an expression by itself fails to simplify to 1:
```
(Pdb) p expr
((s25*s85)//32)
(Pdb) p FloorDiv(expr, expr)
((s25*s85)//(32*(((s25*s85)//32))))
```
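
A minimal reproduction sketch of the transcript above (assuming PyTorch's sympy `FloorDiv` is importable from `torch.utils._sympy.functions`; the import path is an assumption about the current layout):

```python
import sympy
from torch.utils._sympy.functions import FloorDiv

s25, s85 = sympy.symbols("s25 s85", integer=True, positive=True)
expr = FloorDiv(s25 * s85, 32)

# Mathematically FloorDiv(expr, expr) == 1 for positive expr, but the inner
# floordiv is folded into the denominator and never cancels:
print(FloorDiv(expr, expr))  # ((s25*s85)//(32*((s25*s85)//32)))
```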

A second gap: an unbacked symbolic shape is not recognized as a multiple of itself.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154768
Approved by: https://github.com/jansel
eellison
2025-06-05 06:52:28 -07:00
committed by PyTorch MergeBot
parent a85ad55525
commit 7dcc77e422
5 changed files with 51 additions and 22 deletions


@@ -526,7 +526,7 @@ class LoopOrderingTest(TestCase):
         "triton.unique_kernel_names": True,
         "loop_ordering_after_fusion": True,
         "triton.max_tiles": 3,
-        "test_configs.global_tiling_analysis": True,
+        "triton.coalesce_tiling_analysis": True,
     }
 )
 @instantiate_parametrized_tests
@@ -798,13 +798,14 @@ class MemoryCoalescingTest(MockSchedulerTest):
         # coalesce twice as many bytes as first dimension
         # if not downcasted
         # if downcasted, should be equal, bc larger dtype size
+        # we also weight writes x 2
         cont_reads = coalesce_analysis.coalesced_by_var[i_vars[1]]
         t_reads = coalesce_analysis.coalesced_by_var[i_vars[0]]
         if not downcast_transposed_v:
-            self.assertEqual(cont_reads, t_reads * 2)
+            self.assertEqual(cont_reads, t_reads * 3)
         else:
-            self.assertEqual(cont_reads, t_reads)
+            self.assertEqual(cont_reads, t_reads * 1.5)
         return nodes
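
Worked numbers for the assertions above (hedged: this assumes the contiguous dimension's score splits evenly between its read and its write): unweighted, cont = read + write = 2·t; with writes counted double, cont = read + 2·write = 3·t. In the downcast case read = write = t/2, so cont moves from 1·t to 1.5·t.
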
@@ -908,8 +909,7 @@ layouts = ("cont", "NHWC", "T")
     {
         "triton.unique_kernel_names": True,
         "loop_ordering_after_fusion": True,
-        "test_configs.global_tiling_analysis": True,
-        "triton.max_tiles": 3,
+        "triton.coalesce_tiling_analysis": True,
     }
 )
 @instantiate_parametrized_tests


@@ -14133,6 +14133,8 @@ if RUN_GPU:
             # it does not move the tensor constructor to cuda and keeps it on CPU.
             self.assertFalse("empty_strided_cuda(()" in code)
+        # only uncoalesced without this :)
+        @config.patch("triton.coalesce_tiling_analysis", False)
         @config.patch("triton.use_block_ptr", False)
         def test_evict_last_non_coalesced_loads(self):
             @torch.compile
@@ -14183,6 +14185,7 @@ if RUN_GPU:
         )
         @config.patch("triton.use_block_ptr", True)
+        @config.patch("triton.coalesce_tiling_analysis", False)
         def test_evict_last_non_coalesced_loads_block_ptr(self):
             @torch.compile
             def f(a, b):


@@ -86,6 +86,11 @@ pexpr = PythonPrinter().doprint
 all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])
+
+def get_max_tiles(default: int = 2) -> int:
+    max_tiles = torch._inductor.config.triton.max_tiles
+    return max_tiles if max_tiles is not None else default
+
 @dataclasses.dataclass
 class IterationRanges:
     """
@@ -1354,7 +1359,7 @@ class SIMDScheduling(BaseScheduling):
         nodes: list[scheduler.SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
-        if torch._inductor.config.test_configs.global_tiling_analysis:
+        if torch._inductor.config.triton.coalesce_tiling_analysis:
             coalesce_analysis = analyze_memory_coalescing(node)
         else:
             coalesce_analysis = None
@@ -1993,7 +1998,7 @@ class SIMDScheduling(BaseScheduling):
         # Flatten leading dimensions, assigning labels to each dim.
         for node_tiling in node_tilings:
-            num_leading_dims = max(0, len(node_tiling) - config.triton.max_tiles)
+            num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2))
             first_trailing_dim = num_leading_dims + 1
             collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim])
             collapsed_splits = (collapsed_leading_dim,) + tuple(
@@ -2165,7 +2170,7 @@ class SIMDScheduling(BaseScheduling):
                 )
             )
-        if torch._inductor.config.triton.max_tiles == 3 and reduction_numel == 1:
+        if get_max_tiles(default=3) == 3 and reduction_numel == 1:
             for vars_to_use in itertools.combinations(overlapping_iter_vars, 2):
                 score_split.append(
                     (
@@ -2187,13 +2192,16 @@ class SIMDScheduling(BaseScheduling):
         # add a slight penalty for longer tilings that dont increase score much,
         # and are poor sizes
-        additional_tiling_penalty = 1.025
+        bad_size_additional_tiling_penalty = 1.025
+        good_size_tiling_penalty = 1.005

         def score_mod(t):
             score_factor = 1.0
             for tile_size in t[0].tiling.values():
                 if not CandidateTiling.is_good_size(tile_size):
-                    score_factor = score_factor / additional_tiling_penalty
+                    score_factor = score_factor / bad_size_additional_tiling_penalty
+                else:
+                    score_factor = score_factor / good_size_tiling_penalty

             return -t[0].score * score_factor
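
In effect (numbers from the diff): every tile size now divides the candidate's score factor, by 1.005 when `CandidateTiling.is_good_size` holds and by 1.025 otherwise. Longer tilings therefore pay a small per-dimension cost (~0.5%), and poorly sized dimensions pay roughly 2% more, breaking ties toward shorter, well-sized tilings.
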
@@ -2204,7 +2212,7 @@ class SIMDScheduling(BaseScheduling):
         ):
             # we always include default reduction numel == 1, dont include
             tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0)
-            if tiling_len > torch._inductor.config.triton.max_tiles:
+            if tiling_len > get_max_tiles(default=3):
                 perf_hint_log.info(
                     "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
                     "set to %s. Consider increasing",
@@ -2289,16 +2297,17 @@ class SIMDScheduling(BaseScheduling):
         # # TODO: enable by default
         if (
-            torch._inductor.config.test_configs.global_tiling_analysis
+            torch._inductor.config.triton.coalesce_tiling_analysis
             and coalesce_analysis
             and not config.triton.prefer_nd_tiling
         ):
             return cls.compute_tiling_strategy(
                 node_schedule, numel, reduction_numel, coalesce_analysis
             )

-        if (
-            not is_pointwise and not config.triton.tile_reductions
-        ) or config.triton.max_tiles <= 1:
+        if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles(
+            default=2
+        ) <= 1:
             # Emit a perf hint in case we miss an opportunity to tile a reduction.
             if perf_hint_log.level <= logging.WARNING:
                 for node in EnableReduction.filter(node_schedule):
@@ -2333,7 +2342,7 @@ class SIMDScheduling(BaseScheduling):
             for candidate_tiling, score in candidate_tiles.most_common()
         ]
-        if config.triton.max_tiles >= 3 and is_pointwise:
+        if get_max_tiles(default=2) >= 3 and is_pointwise:
             # Consider adding a third dimension of tiling, but only
             # when a1 is a multiple of b1; otherwise, you have a lot
             # of stragglers which is annoying to generate code for.


@@ -1115,12 +1115,23 @@ class triton:
     # Always load full blocks (rather than broadcasting inside the block)
     dense_indexing = False
+    # TODO - enable by default
+    coalesce_tiling_analysis: bool = (
+        os.environ.get(
+            "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0"
+        )
+        == "1"
+    )
+
     # limit tiling dimensions
     # - max_tiles=1 disables tiling
-    # - max_tiles=2 is the default
+    # - max_tiles=2
     # - max_tiles=3 is experimental and may have bugs
     # higher values are unsupported
-    max_tiles = 2
+    # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise.
+    # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes.
+    max_tiles: Optional[int] = None

     # Prefer higher dimensional tilings. This simplifies indexing expressions, making
     # it easier to identify block pointers.
@@ -1681,9 +1692,6 @@ class test_configs:
     graphsafe_rng_func_ignores_fallback_random = False
-    # TODO - temporary config before enabled by default
-    global_tiling_analysis: bool = False
-
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
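
A usage note on the new flag: because its default is computed from the environment at import time (see the `os.environ.get` call above), it can be toggled without code changes. A hedged sketch (the env var name is taken verbatim from the diff):

```python
import os

# Must be set before torch._inductor.config is first imported, since the
# default is read from the environment at import time.
os.environ["TORCHINDUCTOR_COALESCE_TILING_ANALYSIS"] = "0"

import torch._inductor.config as inductor_config

# Alternatively, flip the config object directly after import:
inductor_config.triton.coalesce_tiling_analysis = False
```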


@@ -621,6 +621,9 @@ class VarTiling:
 @dataclasses.dataclass(frozen=True)
 class CoalesceVarAnalysis:
     # Var -> Memory Score - not strictly the amount of memory
+    # because we multiply writes x2
+    # TODO: separate into dataclass that holds mem, dtype, is_write
     coalesced_by_var: dict[sympy.Expr, int]
     norm_read_writes: FusedNormalizedReadsWrites
@@ -656,7 +659,10 @@ def analyze_memory_coalescing(
     coalesced_by_var: dict[sympy.Symbol, int] = Counter()
     uncoalesced_addrs: dict[sympy.Expr, int] = Counter()
-    for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
+    for is_read, (memory_expr, buf_names) in itertools.chain(
+        ((True, item) for item in reads.items()),
+        ((False, item) for item in writes.items()),
+    ):
         # skip memory deps with indirect vars - todo: better handling
         indirect_expr = bool(
             memory_expr.free_symbols - norm_read_writes.var_ranges.keys()
@@ -676,6 +682,9 @@ def analyze_memory_coalescing(
             if buf := V.graph.try_get_buffer(buf_name):
                 byte_multipler += buf.dtype.itemsize
+        # coalesced writes more important
+        byte_multipler *= 1 if is_read else 2
         if maybe_coalesced_var:
             coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
         else:
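
To make the new weighting concrete, a toy calculation (illustrative only, not Inductor code; the `mem_score` helper is hypothetical):

```python
# A coalesced write now contributes twice the bytes of an equivalent read
# to a variable's coalescing score.
def mem_score(num_elems: int, itemsize: int, is_read: bool) -> int:
    byte_multipler = itemsize            # spelling kept from the diff
    byte_multipler *= 1 if is_read else 2
    return num_elems * byte_multipler

# 1024 fp32 elements: the read scores 4096 "bytes", the same write 8192.
assert mem_score(1024, 4, is_read=True) == 4096
assert mem_score(1024, 4, is_read=False) == 8192
```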