Turn on new tiling by default (#154768)
Turning this on in fbcode will come later. This also updates `max_tiles` to have a default value of None. The existing tiling logic doesn't really handle max_tiles=3 well, but the new tiling logic does, so we default to 3 in the new logic and 2 elsewhere unless max_tiles has been explicitly set.

TB runners have been very unstable recently (do we need to bump the batch size?), but e.g. for a [recent torchbench](https://hud.pytorch.org/benchmark/torchbench/inductor_with_cudagraphs?dashboard=torchinductor&startTime=Tue,%2027%20May%202025%2015:38:26%20GMT&stopTime=Tue,%2003%20Jun%202025%2015:38:26%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(a100)&lBranch=gh/eellison/803/head&lCommit=8480c220db4eb3c9e2b58d85a698d0a7113a6e37&rBranch=main&rCommit=0cd18ba1ca35d87916723d445c06664615dcae12) inference run we had 15 models with a lower execution time (i.e. green) and 2 models with a higher one (i.e. red). I am doing another run and will update here.

Dynamic shapes are not yet turned on because a number of fixes to splitting are still needed. See:

```
(Pdb) p expr
((s25*s85)//32)
(Pdb) p FloorDiv(expr, expr)
((s25*s85)//(32*(((s25*s85)//32))))
```

and also: an unbacked shape is not recognized as a multiple of itself.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154768
Approved by: https://github.com/jansel
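To make the splitting issue above concrete, here is a minimal sketch, assuming `FloorDiv` is imported from `torch.utils._sympy.functions` (the symbolic floor-division helper inductor uses); it shows that dividing a floor-div expression by itself does not reduce to 1:

```python
# Minimal sketch of the non-simplification described in the commit message.
# The import path is an assumption about where the FloorDiv helper lives.
import sympy
from torch.utils._sympy.functions import FloorDiv

s25, s85 = sympy.symbols("s25 s85", integer=True, positive=True)
expr = FloorDiv(s25 * s85, 32)  # ((s25*s85)//32)

# Instead of simplifying to 1, the quotient keeps a nested floor division:
# ((s25*s85)//(32*((s25*s85)//32)))
print(FloorDiv(expr, expr))
```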
Committed by: PyTorch MergeBot
Parent: a85ad55525
Commit: 7dcc77e422
```diff
@@ -526,7 +526,7 @@ class LoopOrderingTest(TestCase):
         "triton.unique_kernel_names": True,
         "loop_ordering_after_fusion": True,
         "triton.max_tiles": 3,
-        "test_configs.global_tiling_analysis": True,
+        "triton.coalesce_tiling_analysis": True,
     }
 )
 @instantiate_parametrized_tests
```
```diff
@@ -798,13 +798,14 @@ class MemoryCoalescingTest(MockSchedulerTest):
         # coalesce twice as many bytes as first dimension
         # if not downcasted
         # if downcasted, should be equal, bc larger dtype size
+        # we also weight writes x 2
         cont_reads = coalesce_analysis.coalesced_by_var[i_vars[1]]
         t_reads = coalesce_analysis.coalesced_by_var[i_vars[0]]

         if not downcast_transposed_v:
-            self.assertEqual(cont_reads, t_reads * 2)
+            self.assertEqual(cont_reads, t_reads * 3)
         else:
-            self.assertEqual(cont_reads, t_reads)
+            self.assertEqual(cont_reads, t_reads * 1.5)

         return nodes
```
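One way the new ratios above fall out, assuming (per the comments in the test) that the contiguous index variable coalesces both a read and the write while the transposed variable coalesces only its read: with writes now weighted 2x, the same-dtype case scores 1 + 2 = 3 versus 1, and when the transposed input uses a dtype twice as large the scores become 2 + 2*2 = 6 versus 4, i.e. 1.5x. A small sketch of that arithmetic with made-up byte counts:

```python
# Hypothetical byte counts illustrating the assertEqual ratios above; these are
# not taken from the actual test tensors, only from the weighting scheme.
def score(read_bytes: int, write_bytes: int = 0) -> int:
    # writes are weighted 2x in the coalescing analysis
    return read_bytes + 2 * write_bytes

# Same dtype everywhere (not downcast): the contiguous var sees read + write.
assert score(4, write_bytes=4) == score(4) * 3

# Transposed input in a dtype twice as large (the downcast case).
assert score(2, write_bytes=2) == score(4) * 1.5
```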
```diff
@@ -908,8 +909,7 @@ layouts = ("cont", "NHWC", "T")
     {
         "triton.unique_kernel_names": True,
         "loop_ordering_after_fusion": True,
-        "test_configs.global_tiling_analysis": True,
-        "triton.max_tiles": 3,
+        "triton.coalesce_tiling_analysis": True,
     }
 )
 @instantiate_parametrized_tests
```
```diff
@@ -14133,6 +14133,8 @@ if RUN_GPU:
             # it does not move the tensor constructor to cuda and keeps it on CPU.
             self.assertFalse("empty_strided_cuda(()" in code)

+        # only uncoalesced without this :)
+        @config.patch("triton.coalesce_tiling_analysis", False)
         @config.patch("triton.use_block_ptr", False)
         def test_evict_last_non_coalesced_loads(self):
             @torch.compile
```
```diff
@@ -14183,6 +14185,7 @@ if RUN_GPU:
             )

         @config.patch("triton.use_block_ptr", True)
+        @config.patch("triton.coalesce_tiling_analysis", False)
         def test_evict_last_non_coalesced_loads_block_ptr(self):
             @torch.compile
             def f(a, b):
```
```diff
@@ -86,6 +86,11 @@ pexpr = PythonPrinter().doprint
 all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])


+def get_max_tiles(default: int = 2) -> int:
+    max_tiles = torch._inductor.config.triton.max_tiles
+    return max_tiles if max_tiles is not None else default
+
+
 @dataclasses.dataclass
 class IterationRanges:
     """
```
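A quick sketch of how the new helper behaves across call sites, using a stand-in for the config lookup: the coalesce-aware paths below pass `default=3`, the legacy paths pass `default=2`, and an explicitly set `max_tiles` wins in both:

```python
# Stand-in version of get_max_tiles; the real helper reads
# torch._inductor.config.triton.max_tiles instead of taking it as an argument.
from typing import Optional

def get_max_tiles_sketch(configured: Optional[int], default: int = 2) -> int:
    return configured if configured is not None else default

assert get_max_tiles_sketch(None, default=3) == 3  # new tiling analysis paths
assert get_max_tiles_sketch(None, default=2) == 2  # legacy paths
assert get_max_tiles_sketch(4, default=2) == 4     # explicit setting always wins
```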
```diff
@@ -1354,7 +1359,7 @@ class SIMDScheduling(BaseScheduling):

         nodes: list[scheduler.SchedulerNode] = node.get_nodes()  # type: ignore[assignment]

-        if torch._inductor.config.test_configs.global_tiling_analysis:
+        if torch._inductor.config.triton.coalesce_tiling_analysis:
             coalesce_analysis = analyze_memory_coalescing(node)
         else:
             coalesce_analysis = None
```
```diff
@@ -1993,7 +1998,7 @@ class SIMDScheduling(BaseScheduling):

         # Flatten leading dimensions, assigning labels to each dim.
         for node_tiling in node_tilings:
-            num_leading_dims = max(0, len(node_tiling) - config.triton.max_tiles)
+            num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2))
             first_trailing_dim = num_leading_dims + 1
             collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim])
             collapsed_splits = (collapsed_leading_dim,) + tuple(
```
```diff
@@ -2165,7 +2170,7 @@ class SIMDScheduling(BaseScheduling):
                 )
             )

-        if torch._inductor.config.triton.max_tiles == 3 and reduction_numel == 1:
+        if get_max_tiles(default=3) == 3 and reduction_numel == 1:
             for vars_to_use in itertools.combinations(overlapping_iter_vars, 2):
                 score_split.append(
                     (
```
```diff
@@ -2187,13 +2192,16 @@ class SIMDScheduling(BaseScheduling):

         # add a slight penalty for longer tilings that dont increase score much,
         # and are poor sizes
-        additional_tiling_penalty = 1.025
+        bad_size_additional_tiling_penalty = 1.025
+        good_size_tiling_penalty = 1.005

         def score_mod(t):
             score_factor = 1.0
             for tile_size in t[0].tiling.values():
                 if not CandidateTiling.is_good_size(tile_size):
-                    score_factor = score_factor / additional_tiling_penalty
+                    score_factor = score_factor / bad_size_additional_tiling_penalty
+                else:
+                    score_factor = score_factor / good_size_tiling_penalty

             return -t[0].score * score_factor
```
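To put numbers on the penalty (illustrative only, not from a real run): every tile dimension divides the candidate's score factor, so a tiling whose extra dimension has an awkward size must beat a cleanly sized shorter tiling by roughly 2.5% in raw coalescing score, while a well-sized extra dimension only needs about 0.5%:

```python
# Illustrative use of the penalty constants introduced above on made-up scores.
bad_size_additional_tiling_penalty = 1.025
good_size_tiling_penalty = 1.005

def penalized(raw_score: float, good_dims: int, bad_dims: int) -> float:
    factor = 1.0 / (good_size_tiling_penalty ** good_dims)
    factor /= bad_size_additional_tiling_penalty ** bad_dims
    return raw_score * factor

print(penalized(100.0, good_dims=2, bad_dims=0))  # ~99.0 (clean 2D tiling)
print(penalized(100.0, good_dims=2, bad_dims=1))  # ~96.6 (3D with one bad size)
```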
```diff
@@ -2204,7 +2212,7 @@ class SIMDScheduling(BaseScheduling):
         ):
             # we always include default reduction numel == 1, dont include
             tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0)
-            if tiling_len > torch._inductor.config.triton.max_tiles:
+            if tiling_len > get_max_tiles(default=3):
                 perf_hint_log.info(
                     "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
                     "set to %s. Consider increasing",
```
```diff
@@ -2289,16 +2297,17 @@ class SIMDScheduling(BaseScheduling):

         # # TODO: enable by default
         if (
-            torch._inductor.config.test_configs.global_tiling_analysis
+            torch._inductor.config.triton.coalesce_tiling_analysis
             and coalesce_analysis
             and not config.triton.prefer_nd_tiling
         ):
             return cls.compute_tiling_strategy(
                 node_schedule, numel, reduction_numel, coalesce_analysis
             )

-        if (
-            not is_pointwise and not config.triton.tile_reductions
-        ) or config.triton.max_tiles <= 1:
+        if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles(
+            default=2
+        ) <= 1:
             # Emit a perf hint in case we miss an opportunity to tile a reduction.
             if perf_hint_log.level <= logging.WARNING:
                 for node in EnableReduction.filter(node_schedule):
```
```diff
@@ -2333,7 +2342,7 @@ class SIMDScheduling(BaseScheduling):
             for candidate_tiling, score in candidate_tiles.most_common()
         ]

-        if config.triton.max_tiles >= 3 and is_pointwise:
+        if get_max_tiles(default=2) >= 3 and is_pointwise:
             # Consider adding a third dimension of tiling, but only
             # when a1 is a multiple of b1; otherwise, you have a lot
             # of stragglers which is annoying to generate code for.
```
```diff
@@ -1115,12 +1115,23 @@ class triton:
     # Always load full blocks (rather than broadcasting inside the block)
     dense_indexing = False

+    # TODO - enable by default
+    coalesce_tiling_analysis: bool = (
+        os.environ.get(
+            "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0"
+        )
+        == "1"
+    )
+
     # limit tiling dimensions
     # - max_tiles=1 disables tiling
-    # - max_tiles=2 is the default
+    # - max_tiles=2
     # - max_tiles=3 is experimental and may have bugs
     # higher values are unsupported
-    max_tiles = 2
+
+    # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise.
+    # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes.
+    max_tiles: Optional[int] = None

     # Prefer higher dimensional tilings. This simplifies indexing expressions, making
     # it easier to identify block pointers.
```
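If the new defaults need to be overridden, the knobs are exactly the two config entries added above; a short example of setting them programmatically (the environment-variable route uses `TORCHINDUCTOR_COALESCE_TILING_ANALYSIS` as shown in the diff):

```python
# Opting out of the new behavior or pinning an explicit tile budget.
import torch._inductor.config as inductor_config

# Disable the coalesce-aware tiling analysis (same effect as setting the
# TORCHINDUCTOR_COALESCE_TILING_ANALYSIS environment variable to "0").
inductor_config.triton.coalesce_tiling_analysis = False

# Pin max_tiles explicitly; when left as None, call sites fall back to
# get_max_tiles() defaults (3 in the coalesce-aware analysis, 2 elsewhere).
inductor_config.triton.max_tiles = 2
```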
```diff
@@ -1681,9 +1692,6 @@ class test_configs:

     graphsafe_rng_func_ignores_fallback_random = False

-    # TODO - temporary config before enabled by default
-    global_tiling_analysis: bool = False
-

 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
```
```diff
@@ -621,6 +621,9 @@ class VarTiling:

 @dataclasses.dataclass(frozen=True)
 class CoalesceVarAnalysis:
+    # Var -> Memory Score - not strictly the amount of memory
+    # because we multiply writes x2
+    # TODO: separate into dataclass that olds mem, dtype, is_write
     coalesced_by_var: dict[sympy.Expr, int]

     norm_read_writes: FusedNormalizedReadsWrites
```
```diff
@@ -656,7 +659,10 @@ def analyze_memory_coalescing(
     coalesced_by_var: dict[sympy.Symbol, int] = Counter()
     uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

-    for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
+    for is_read, (memory_expr, buf_names) in itertools.chain(
+        ((True, item) for item in reads.items()),
+        ((False, item) for item in writes.items()),
+    ):
         # skip memory deps with indirect vars - todo: better handling
         indirect_expr = bool(
             memory_expr.free_symbols - norm_read_writes.var_ranges.keys()
```
```diff
@@ -676,6 +682,9 @@
            if buf := V.graph.try_get_buffer(buf_name):
                byte_multipler += buf.dtype.itemsize

+        # coalesced writes more important
+        byte_multipler *= 1 if is_read else 2
+
         if maybe_coalesced_var:
             coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
         else:
```
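The loop rewritten in the hunk at -656 above tags each memory expression with whether it came from a read or a write so that the 2x write weight can be applied here; a self-contained sketch of that tagging idiom with toy data:

```python
# Toy reproduction of the read/write tagging pattern used in the loop above.
import itertools
from collections import Counter

reads = {"index_expr_a": ["buf0"], "index_expr_b": ["buf1"]}
writes = {"index_expr_a": ["buf2"]}

scores: Counter[str] = Counter()
for is_read, (memory_expr, buf_names) in itertools.chain(
    ((True, item) for item in reads.items()),
    ((False, item) for item in writes.items()),
):
    # mirror byte_multipler *= 1 if is_read else 2: writes count double
    scores[memory_expr] += len(buf_names) * (1 if is_read else 2)

print(scores)  # Counter({'index_expr_a': 3, 'index_expr_b': 1})
```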