mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Analyze coalesced mem (#153730)
Analyze memory expressions to see if they contain a coalescing symbol. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153730 Approved by: https://github.com/jansel ghstack dependencies: #153723
This commit is contained in:
committed by
PyTorch MergeBot
parent
e9266f807a
commit
0adbde4d35
@ -2,16 +2,18 @@ import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.dependencies import index_vars_no_squeeze
|
||||
from torch._inductor.utils import sympy_product, sympy_subs
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
|
||||
from .virtualized import V
|
||||
|
||||
@ -29,6 +31,40 @@ if TYPE_CHECKING:
|
||||
from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode
|
||||
|
||||
|
||||
def find_coalesced_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
    """
    Try to find the symbol which coalesces this index.

    A variable from ``var_ranges`` is reported when either:

    * it appears as a bare top-level additive term of ``index``, or
    * moving it from 0 to 1, with every other ``var_ranges`` variable held at 0
      and every remaining free symbol replaced by its ``get_hint`` value,
      changes the value of ``index`` by exactly 1.

    Args:
        index: the memory index expression to inspect.
        var_ranges: the iteration variables to consider; only the keys are used.

    Returns:
        The first variable that satisfies one of the conditions above, or
        ``None`` if there is none.
    """
    top_level_terms = sympy.Add.make_args(index)
    for v in var_ranges:
        if v in top_level_terms:
            return v

    # Approximate analysis by evaluating at 1 and 0
    variables: dict[sympy.Symbol, int] = {
        sym: 0 if sym in var_ranges else get_hint(sym) for sym in index.free_symbols
    }
    zero_index = sympy_subs(index, variables)

    for v in var_ranges:
        variables[v] = 1
        try:
            new_val = sympy_subs(index, variables)
        except ZeroDivisionError:
            loop_tiling_log.info("zero division error %s %s", index, variables)
            new_val = None
        finally:
            # Always put the probed variable back to 0, including on the
            # error path, so the next probe only differs from the baseline
            # in a single variable.
            variables[v] = 0

        if new_val is not None and new_val - zero_index == 1:
            return v

    return None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FusedNormalizedReadsWrites:
|
||||
"""
|
||||
@ -360,3 +396,85 @@ def extract_normalized_read_writes(
|
||||
)
|
||||
loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
|
||||
return fused_out
|
||||
|
||||
|
||||
def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
    """
    Score addr according to its approximate size.

    The score is the size hint of the product of the ranges of every free
    symbol of ``addr`` that has a (non-None) entry in ``var_ranges`` and is
    not an indirect symbol. Unbacked sizes fall back to
    ``config.unbacked_symint_fallback``.
    """

    # TODO - deduplicate with candidate_tilings
    var_sizes = []
    for v in addr.free_symbols:
        # TODO - reason about indirect vars
        if symbol_is_type(v, SymT.INDIRECT):
            continue
        v_size = var_ranges.get(v)
        if v_size is not None:
            var_sizes.append(v_size)

    # `V` is the module-level import; no function-local re-import is needed.
    return V.graph.sizevars.atomically_apply_size_hint(
        sympy_product(var_sizes), fallback=config.unbacked_symint_fallback
    )
|
||||
|
||||
|
||||
def get_hint(v: Union[sympy.Expr, int]) -> int:
    """
    Return an integer hint for ``v``.

    Plain Python ints are returned unchanged; anything else is resolved
    through ``V.graph.sizevars.size_hint`` using
    ``config.unbacked_symint_fallback`` as the fallback.
    """
    if not isinstance(v, int):
        return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback)
    return v
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CoalesceVarAnalysis:
    """
    Immutable result of a memory coalescing analysis.
    """

    # Maps a variable to the accumulated score of the memory expressions
    # it coalesces.
    coalesced_by_var: dict[sympy.Expr, int]

    # The normalized reads/writes the analysis was computed from.
    norm_read_writes: FusedNormalizedReadsWrites
|
||||
|
||||
|
||||
def analyze_memory_coalescing(
    fused_node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> CoalesceVarAnalysis:
    """
    Find variables that coalesce the reads and writes and score the total size.

    Each normalized read/write expression is scored with ``get_score``,
    weighted by the summed dtype itemsize of the buffers it touches, and
    credited to the variable that ``find_coalesced_var`` reports for it.

    Expressions with no coalescing variable are currently only tallied.

    TODO: If uncoalesced memory expressions are found, look for additional
    tiling of variables which will coalesce memory accesses.

    For instance - for the following expression:

    (32*p0) // 2048

    Tiling p0 by 64 will make this expression coalesced.
    """

    norm_read_writes = extract_normalized_read_writes(fused_node)

    reads = norm_read_writes.reads
    writes = norm_read_writes.writes
    var_ranges = norm_read_writes.var_ranges

    coalesced_by_var: dict[sympy.Symbol, int] = Counter()
    # NOTE: tallied for the follow-up tiling analysis; not yet read or returned.
    uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

    for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
        size = get_score(memory_expr, var_ranges)
        # TODO - handle indirect
        if size == 0:
            continue

        maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)

        # Names that do not resolve to a buffer contribute nothing.
        byte_multiplier = sum(
            buf.dtype.itemsize
            for buf_name in buf_names
            if (buf := V.graph.try_get_buffer(buf_name))
        )

        if maybe_coalesced_var is not None:
            coalesced_by_var[maybe_coalesced_var] += size * byte_multiplier
        else:
            uncoalesced_addrs[memory_expr] += size * byte_multiplier

    return CoalesceVarAnalysis(
        coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
    )
|
||||
|
Reference in New Issue
Block a user