Analyze coalesced mem (#153730)

Analyze memory expressions to see if they contain a coalescing symbol.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153730
Approved by: https://github.com/jansel
ghstack dependencies: #153723
This commit is contained in:
eellison
2025-06-02 16:44:56 -07:00
committed by PyTorch MergeBot
parent e9266f807a
commit 0adbde4d35
2 changed files with 200 additions and 7 deletions

View File

@ -2,16 +2,18 @@ import dataclasses
import functools
import itertools
import sys
from collections import defaultdict
from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
import sympy
import torch
from torch._inductor import config
from torch._inductor.dependencies import index_vars_no_squeeze
from torch._inductor.utils import sympy_product, sympy_subs
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .virtualized import V
@ -29,6 +31,40 @@ if TYPE_CHECKING:
from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode
def find_coalesced_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
    """
    Try to find the symbol which coalesces this index.

    A variable is reported when it either appears as a bare top-level additive
    term of ``index``, or when stepping it from 0 to 1 (with every other
    variable of ``var_ranges`` held at 0) changes the value of ``index`` by
    exactly 1.

    Args:
        index: the memory index expression to inspect.
        var_ranges: the candidate variables; only the keys are used here.

    Returns:
        The first variable, in ``var_ranges`` iteration order, that passes one
        of the two checks above, or ``None`` when none does or when the
        all-zero baseline cannot be evaluated.
    """
    top_level_terms = sympy.Add.make_args(index)
    for v in var_ranges:
        if v in top_level_terms:
            return v

    # Approximate analysis by evaluating at 1 and 0
    variables: dict[sympy.Symbol, int] = {}
    for v in index.free_symbols:
        if v in var_ranges:
            variables[v] = 0
        else:
            # Symbols that are not candidate variables are pinned to a hint
            variables[v] = get_hint(v)

    # The baseline is the same substitution the loop below guards against
    # ZeroDivisionError, so guard it too: without a baseline there is nothing
    # to compare against and no variable can be reported.
    try:
        zero_index = sympy_subs(index, variables)
    except ZeroDivisionError:
        loop_tiling_log.info("zero division error %s %s", index, variables)
        return None

    for v in var_ranges:
        variables[v] = 1
        try:
            new_val = sympy_subs(index, variables)
        except ZeroDivisionError:
            loop_tiling_log.info("zero division error %s %s", index, variables)
            continue
        finally:
            # Always restore v to 0, including on the error path; otherwise
            # the failed variable would stay at 1 and skew every later probe.
            variables[v] = 0
        if new_val - zero_index == 1:
            return v

    return None
@dataclasses.dataclass(frozen=True)
class FusedNormalizedReadsWrites:
"""
@ -360,3 +396,85 @@ def extract_normalized_read_writes(
)
loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
return fused_out
def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
    """
    Score addr according to its approximate size.

    The score is the product of the ``var_ranges`` entries of the free symbols
    of ``addr``. Symbols of type ``SymT.INDIRECT`` and symbols without an entry
    in ``var_ranges`` do not contribute. The product is passed through
    ``V.graph.sizevars.atomically_apply_size_hint`` with
    ``config.unbacked_symint_fallback`` as the fallback.
    """
    from .virtualized import V

    # TODO - deduplicate with candidate_tilings
    # TODO - reason about indirect vars
    contributing_sizes = [
        extent
        for sym in addr.free_symbols
        if (extent := var_ranges.get(sym)) is not None
        and not symbol_is_type(sym, SymT.INDIRECT)
    ]

    approximate_size = sympy_product(contributing_sizes)
    return V.graph.sizevars.atomically_apply_size_hint(
        approximate_size, fallback=config.unbacked_symint_fallback
    )
def get_hint(v: Union[sympy.Expr, int]) -> int:
    """
    Return an integer value for ``v``.

    Plain Python ints are returned unchanged. Anything else is handed to
    ``V.graph.sizevars.size_hint`` with ``config.unbacked_symint_fallback``
    as the fallback.
    """
    if not isinstance(v, int):
        return V.graph.sizevars.size_hint(
            v, fallback=config.unbacked_symint_fallback
        )
    return v
@dataclasses.dataclass(frozen=True)
class CoalesceVarAnalysis:
    """
    Result returned by analyze_memory_coalescing.
    """

    # For each variable found to coalesce at least one memory expression, the
    # accumulated score (size * summed dtype itemsize) of those expressions.
    coalesced_by_var: dict[sympy.Expr, int]

    # The normalized reads/writes that the scores above were computed from.
    norm_read_writes: FusedNormalizedReadsWrites
def analyze_memory_coalescing(
    fused_node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> CoalesceVarAnalysis:
    """
    Find variables that coalesce the reads and writes and score the total size.

    If uncoalesced memory expressions are found, look for additional tiling of
    variables which will coalesce memory accesses.

    For instance - for the following expression:

    (32*p0) // 2048

    Tiling p0 by 64 will make this expression coalesced.

    NOTE(review): this function currently only gathers the uncoalesced
    expressions; it does not yet search for such a tiling.

    Args:
        fused_node: the scheduler node whose normalized reads and writes are
            analyzed.

    Returns:
        A CoalesceVarAnalysis holding the per-variable scores and the
        normalized reads/writes they were derived from.
    """
    norm_read_writes = extract_normalized_read_writes(fused_node)

    reads = norm_read_writes.reads
    writes = norm_read_writes.writes
    var_ranges = norm_read_writes.var_ranges

    coalesced_by_var: dict[sympy.Symbol, int] = Counter()
    # NOTE(review): populated below but neither read nor returned yet;
    # presumably consumed by a follow-up change - confirm before removing.
    uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

    for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
        size = get_score(memory_expr, var_ranges)
        # TODO - handle indirect
        if size == 0:
            continue

        maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)

        # Weight the access by the element size of every buffer that can be
        # resolved; names with no resolvable buffer contribute nothing.
        byte_multiplier = sum(
            buf.dtype.itemsize
            for buf_name in buf_names
            if (buf := V.graph.try_get_buffer(buf_name))
        )
        weighted_size = size * byte_multiplier

        # find_coalesced_var returns Optional: test for None explicitly rather
        # than relying on the truthiness of a sympy expression.
        if maybe_coalesced_var is not None:
            coalesced_by_var[maybe_coalesced_var] += weighted_size
        else:
            uncoalesced_addrs[memory_expr] += weighted_size

    return CoalesceVarAnalysis(
        coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
    )