mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Analyze memory expressions to see if they contain a coalescing symbol. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153730 Approved by: https://github.com/jansel ghstack dependencies: #153723
481 lines
16 KiB
Python
481 lines
16 KiB
Python
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import sys
|
|
from collections import Counter, defaultdict
|
|
from collections.abc import Iterable, Iterator
|
|
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
from torch._inductor import config
|
|
from torch._inductor.dependencies import index_vars_no_squeeze
|
|
from torch._inductor.utils import sympy_product, sympy_subs
|
|
from torch.utils._ordered_set import OrderedSet
|
|
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
|
|
|
from .virtualized import V
|
|
|
|
|
|
# Generic element types used by the pre-3.10 ``zip_equal`` fallback.
T = TypeVar("T")
U = TypeVar("U")

# A tuple of iteration-range sizes, one entry per dimension
# (e.g. the pointwise or reduction part of a node's sizes).
Split = tuple[sympy.Expr, ...]

# Logger for the "loop_tiling" logging artifact.
loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling")

# Scheduler node types are only needed for annotations here
# (presumably imported lazily to avoid a runtime import cycle — confirm).
if TYPE_CHECKING:
    from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode
|
|
|
|
|
|
def find_coalesced_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
    """
    Try to find the symbol which coalesces this index.

    Roughly, a variable coalesces ``index`` when incrementing it by one moves
    the index by exactly one element.

    Args:
        index: the memory index expression to analyze.
        var_ranges: iteration variables to consider, in priority order; only the
            keys are used.

    Returns:
        The first variable of ``var_ranges`` found to coalesce ``index``, or
        ``None`` if no such variable is found.
    """
    # Fast path: a variable that is a bare top-level summand has unit stride.
    top_level_terms = sympy.Add.make_args(index)
    for v in var_ranges:
        if v in top_level_terms:
            return v

    # Approximate analysis by evaluating at 1 and 0: every iteration variable is
    # set to 0 (other symbols to their size hints), then each iteration variable
    # in turn is bumped to 1 and we check whether the index moved by exactly 1.
    free_symbols = index.free_symbols
    variables: dict[sympy.Symbol, int] = {
        v: 0 if v in var_ranges else get_hint(v) for v in free_symbols
    }

    try:
        zero_index = sympy_subs(index, variables)
    except ZeroDivisionError:
        # Same handling as a failing probe below: we cannot analyze this index.
        loop_tiling_log.info("zero division error %s %s", index, variables)
        return None

    for v in var_ranges:
        # A variable that does not occur in the index cannot change its value.
        if v not in free_symbols:
            continue
        variables[v] = 1
        try:
            new_val = sympy_subs(index, variables)
        except ZeroDivisionError:
            loop_tiling_log.info("zero division error %s %s", index, variables)
            continue
        finally:
            # Always restore the baseline so that later probes evaluate exactly
            # one variable at 1, even when this probe raised.
            variables[v] = 0
        if new_val - zero_index == 1:
            return v

    return None
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class FusedNormalizedReadsWrites:
    """
    Normalized reads and writes for nodes in the same FusedSchedulerNode.

    All memory expressions are expressed in terms of one shared set of
    normalized iteration variables, so accesses coming from different nodes of
    the fusion use the same variables.
    """

    # Normalized pointwise iteration variables (extract_normalized_read_writes
    # passes a list here despite the OrderedSet annotation).
    index_vars: OrderedSet[sympy.Symbol]
    # Normalized reduction iteration variables (likewise passed as a list).
    reduce_vars: OrderedSet[sympy.Symbol]
    # Normalized read index expression -> names of buffers read with it.
    reads: dict[sympy.Expr, OrderedSet[str]]
    # Normalized write index expression -> names of buffers written with it.
    writes: dict[sympy.Expr, OrderedSet[str]]
    # Normalized variable -> its iteration range (values may be symbolic sizes
    # despite the ``int`` annotation — confirm against callers).
    var_ranges: dict[sympy.Symbol, int]
|
|
|
|
|
|
def get_pw_red_splits(
    n: "SchedulerNode", pointwise_numel: sympy.Expr, red_numel: sympy.Expr
) -> tuple[tuple[list[sympy.Symbol], list[int]], tuple[list[sympy.Symbol], list[int]]]:
    """
    Split the iteration space of ``n`` into a pointwise and a reduction part.

    Reductions, and nodes whose first size group already has ``pointwise_numel``
    elements, are returned as-is. Otherwise the first size group must cover
    ``pointwise_numel * red_numel`` elements, and the trailing dimensions whose
    product equals ``red_numel`` become the reduction part.

    Returns:
        ``((pointwise_vars, pointwise_sizes), (reduction_vars, reduction_sizes))``

    Raises:
        RuntimeError: if no trailing run of dimensions multiplies to ``red_numel``.
    """
    sizes = n._body.sizes
    if n.is_reduction() or sympy_product(sizes[0]) == pointwise_numel:
        return (
            (n._body.iter_vars, sizes[0]),
            (n._body.reduce_vars, sizes[1]),
        )  # type: ignore[return-value]

    assert sympy_product(sizes[0]) == pointwise_numel * red_numel  # type: ignore[operator]

    # Walk from the innermost dimension outwards until the suffix sizes[0][i:]
    # covers exactly red_numel elements.
    i = len(sizes[0]) - 1
    prod = 1
    while i >= 0:
        prod *= sizes[0][i]
        if prod == red_numel:
            break
        # Advance to the next-outer dimension; without this the loop would keep
        # multiplying the same dimension and never terminate.
        i -= 1

    if i >= 0:
        pw_splits = sizes[0][0:i]
        iter_vars = n._body.iter_vars[0:i]

        red_splits = sizes[0][i:]
        red_vars = n._body.iter_vars[i:]
        return (iter_vars, pw_splits), (red_vars, red_splits)  # type: ignore[return-value]

    # TODO - handle, not sure if possible
    raise RuntimeError(
        f"Unhandled node: size: {n._body.sizes}, pw: {pointwise_numel}, red: {red_numel}"
    )
|
|
|
|
|
|
class NodeSplitGetter:
    """
    Finds a Pointwise, Reduction Split that is compatible with all nodes in a SchedulerNode.
    """

    def __init__(
        self,
        node: Union["FusedSchedulerNode", "SchedulerNode"],
    ):
        self.node = node
        self.pointwise_numel: sympy.Expr = node.group[1][0]
        self.red_numel: sympy.Expr = node.group[1][1]

        # Candidate pointwise splits, keyed by their number of dimensions.
        self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet)

        self.reduction_split: Split = ()
        # Distinct (pointwise sizes, reduction sizes) pairs of the member nodes.
        self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet()

        fused_group = node.group[1]
        for n in reversed(node.get_nodes()):
            if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
                continue

            (_, n_pw_splits), (_, n_red_splits) = get_pw_red_splits(
                n, self.pointwise_numel, self.red_numel
            )

            # fill in reduction size
            n_pw_splits, n_red_splits = (
                torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                    fused_group, (n_pw_splits, n_red_splits), self.red_numel
                )
            )

            self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits))

            # initially, we are just going to do a single reduction split since
            # reduction tiling is off by default. even if we miss a reduction split,
            # we can recover it in the split var analysis.
            # TODO: an earlier version of this code tried to iteratively try the maximum number
            # of split vars, by iterating over both pointwise and reduction. but not worth
            # the complexity yet.
            # NOTE(review): if n_red_splits is a list, comparing it with the tuple
            # `()` is always True, so an empty list yields reduction_split == (1,).
            # Confirm whether `if n_red_splits:` was intended before changing it.
            if n_red_splits != ():
                self.reduction_split = (sympy_product(n_red_splits),)

            n_size = (tuple(n_pw_splits), tuple(n_red_splits))
            self.all_node_sizes.add(n_size)

        # Pointwise splits already attempted by try_split.
        self.seen_pw_splits: OrderedSet[Split] = OrderedSet()

    def get_node_splits(self) -> tuple[Split, Split]:
        """
        Get a compatible pointwise, reduction split of the node.

        If all member nodes share one size, that size is returned. Otherwise the
        candidate pointwise splits are tried longest first; after each round,
        every candidate of the current length is shortened by merging each pair
        of adjacent dimensions, which produces candidates for the next round.
        Falls back to ``((pointwise_numel,), (red_numel,))`` when nothing fits.
        """
        if len(self.all_node_sizes) == 1:
            return next(iter(self.all_node_sizes))

        # default=0 covers a node without any SchedulerNode (no candidates):
        # instead of max() raising ValueError we reach the fallback split below.
        max_pw_split = max(self.pw_split_options.keys(), default=0)
        for pw_split_len in range(max_pw_split, 0, -1):
            for pw_split in self.pw_split_options[pw_split_len]:
                if out := self.try_split(pw_split, self.reduction_split):
                    return out

            # combine dims for next round
            for pw_split in self.pw_split_options[pw_split_len]:
                for i in range(len(pw_split) - 1):
                    new_split = tuple(
                        pw_split[0:i]
                        + (sympy_product(pw_split[i : i + 2]),)
                        + pw_split[i + 2 :]
                    )
                    self.pw_split_options[len(new_split)].add(new_split)

        # if for whatever reason we couldn't split above, return default split
        return ((self.pointwise_numel,), (self.red_numel,))

    def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]:
        """
        See if this split is compatible, and potentially return a longer split
        than the input.

        Returns ``None`` if ``pw`` was already tried, or if some member node's
        sizes cannot be split to match ``pw + red``. Otherwise returns either
        ``(pw, red)`` or a longer compatible split induced by one of the nodes.
        """
        from torch._inductor.codegen.simd import CantSplit, SIMDKernel

        # Each candidate pointwise split is evaluated at most once.
        if pw in self.seen_pw_splits:
            return None
        self.seen_pw_splits.add(pw)

        for n_pw, n_red in self.all_node_sizes:
            try:
                groups = pw + red
                lengths = (n_pw, n_red)
                splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths)
            except CantSplit:
                return None

            assert len(getters) == 2
            pw_group_splits = splits[: len(pw)]
            # if we had to divide a variable into two to do this split,
            # then lets try the larger, induced split.
            # e.g. splitting (12, 2) into (2, 12) will split the first var into:
            # (2, 6) and produce an overall split of (2, 6, 2)
            flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits))
            if flattened_pw_splits != pw:
                if out := self.try_split(flattened_pw_splits, red):
                    return out

        return pw, red
|
|
|
|
|
|
if sys.version_info >= (3, 10):
|
|
# On Python 3.10+ we can use zip(strict=True)
|
|
zip_equal = functools.partial(zip, strict=True)
|
|
else:
|
|
# Fallback for older versions
|
|
def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]:
|
|
"""
|
|
Zip two iterables, raising ValueError if their lengths differ.
|
|
"""
|
|
if len(it1) != len(it2):
|
|
raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}")
|
|
return zip(it1, it2)
|
|
|
|
|
|
def apply_var_mapping(
    iter_vars: list[sympy.Symbol],
    red_vars: list[sympy.Symbol],
    norm_pw_vars: list[sympy.Symbol],
    norm_red_vars: list[sympy.Symbol],
    new_ranges: list[list[sympy.Expr]],
    return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]],
) -> dict[sympy.Symbol, sympy.Expr]:
    """Maps original variables to expressions using normalized variables.

    ``new_ranges`` and ``return_getters_groups`` are the output of
    ``SIMDKernel._split_iteration_ranges``. ``new_ranges`` has one list of
    sub-ranges per normalized variable (``norm_pw_vars + norm_red_vars``); e.g.
    pointwise vars of range (6, 6) normalized to [36] give ``new_ranges == [[6, 6]]``.
    ``return_getters_groups`` has one callable per original var in ``iter_vars``
    and ``red_vars``.

    If every sub-range gets its own "flat" variable, applying the getters to the
    flat variables maps the original vars to flat vars. Each flat var is then
    replaced by its normalized var scaled by the flat var's stride within its
    group: for [6, 6] with flat vars v0, v1 this is v0 -> 6 * n0, v1 -> n0.

    Returns:
        Original var -> expression in the normalized vars. An empty dict when
        there are no original vars.
    """
    if len(iter_vars) == 0 and len(red_vars) == 0:
        return {}

    assert len(new_ranges) == len(norm_pw_vars + norm_red_vars)

    # Create flattened iteration variables, one per sub-range.
    num_vars = sum(len(s) for s in new_ranges)
    flat_vars = sympy.symbols(f"v_0:{num_vars}")

    # Per input group, the expression (in flat vars) of each original var.
    apply_groups = [[g(flat_vars) for g in group] for group in return_getters_groups]

    iter_vars_to_flat_vars = {}
    for i, (group, var_group) in enumerate(
        zip_equal(apply_groups, (iter_vars, red_vars))
    ):
        # if the node has sizes (p0, 1) and the fused node is (p0, r0)
        # the reduction var gets filled in for split_iteration_range
        if len(group) != len(var_group):
            assert i == 1
            assert len(var_group) == 0
            continue

        iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)})

    # Flat var -> normalized var scaled by the flat var's stride in its group;
    # the last sub-range of each group is the innermost (stride 1).
    count = 0
    flat_vars_to_new_vars = {}
    for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars):
        range_vars = flat_vars[count : count + len(new_range)]
        count += len(new_range)

        prod = 1
        for i in range(len(new_range) - 1, -1, -1):
            flat_vars_to_new_vars[range_vars[i]] = new_var * prod
            prod = new_range[i] * prod

    return {
        k: sympy_subs(v, flat_vars_to_new_vars)
        for k, v in iter_vars_to_flat_vars.items()
    }
|
|
|
|
|
|
def _rekey_union(
    exprs: dict[sympy.Expr, OrderedSet[str]],
    rekey: Callable[[sympy.Expr], sympy.Expr],
    out: dict[sympy.Expr, OrderedSet[str]],
) -> None:
    """Add each ``expr -> buf_names`` entry of ``exprs`` to the defaultdict ``out``
    under ``rekey(expr)``, unioning names when several expressions share a key."""
    for expr, buf_names in exprs.items():
        out[rekey(expr)] |= buf_names


def extract_normalized_read_writes(
    node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> FusedNormalizedReadsWrites:
    """Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node.

    Each member SchedulerNode's read/write index expressions are rewritten in
    terms of normalized variables (prefix ``n``) derived from a split that is
    compatible with all member nodes, then simplified with the normalized ranges.
    Reads cover every buffer the fused node reads; writes cover only outputs
    that cannot be removed through fusion.
    """
    reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
    writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

    all_output_names = node.get_buffer_names()
    op_names = node.get_operation_names()
    outputs = OrderedSet(
        buf
        for buf in all_output_names
        if not V.graph.scheduler.can_buffer_be_removed_through_fusion(buf, op_names)
    )
    inputs = OrderedSet(dep.name for dep in node.read_writes.reads)

    pw_splits, red_splits = NodeSplitGetter(node).get_node_splits()

    # lets use different prefix (`n`) to distinguish
    (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze(
        pw_splits, red_splits, prefix="n"
    )
    pointwise_numel: sympy.Expr = node.group[1][0]
    red_numel: sympy.Expr = node.group[1][1]

    for n in node.get_nodes():
        if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
            continue

        body = n._body
        n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
        n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

        for inp in inputs:
            for expr in body.get_all_read_expr(inp):
                n_reads[expr].add(inp)
        for out in outputs:
            for expr in body.get_all_write_expr(out):
                n_writes[expr].add(out)

        if not n_reads and not n_writes:
            continue

        (iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits(
            n, pointwise_numel, red_numel
        )

        groups = pw_splits + red_splits
        lengths = (n_pw_splits, n_red_splits)
        lengths = (
            torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                groups, lengths, red_numel
            )
        )
        new_ranges, return_getters_groups = (
            torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges(
                groups, lengths
            )
        )
        var_map = apply_var_mapping(
            iter_vars,
            red_vars,
            norm_pw_vars,
            norm_red_vars,
            new_ranges,
            return_getters_groups,
        )

        # Union (not overwrite): distinct raw expressions may become identical
        # after substitution, and none of their buffer names may be dropped.
        _rekey_union(n_reads, lambda e: sympy_subs(e, var_map), reads)
        _rekey_union(n_writes, lambda e: sympy_subs(e, var_map), writes)

    def simplify(expr: sympy.Expr) -> sympy.Expr:
        return V.graph.sizevars.simplify_with_ranges(expr, ranges)

    # Simplification can likewise merge keys, so union again.
    simplified_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
    simplified_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
    _rekey_union(reads, simplify, simplified_reads)
    _rekey_union(writes, simplify, simplified_writes)

    fused_out = FusedNormalizedReadsWrites(
        norm_pw_vars,  # type: ignore[arg-type]
        norm_red_vars,  # type: ignore[arg-type]
        dict(simplified_reads),
        dict(simplified_writes),
        ranges,
    )
    loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
    return fused_out
|
|
|
|
|
|
def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
    """
    Score addr according to its approximate size.

    The score is the product of the ranges of the free symbols of ``addr`` that
    have a (non-None) entry in ``var_ranges`` and are not indirect-indexing
    symbols. It is passed through ``V.graph.sizevars.atomically_apply_size_hint``
    with ``config.unbacked_symint_fallback`` as the fallback.
    """
    # TODO - deduplicate with candidate_tilings
    # TODO - reason about indirect vars (currently they are simply ignored)
    known_sizes = [
        rng
        for sym in addr.free_symbols
        if not symbol_is_type(sym, SymT.INDIRECT)
        and (rng := var_ranges.get(sym, None)) is not None
    ]
    return V.graph.sizevars.atomically_apply_size_hint(
        sympy_product(known_sizes), fallback=config.unbacked_symint_fallback
    )
|
|
|
|
|
|
def get_hint(v: Union[sympy.Expr, int]) -> int:
    """Return ``v`` as-is when it is a Python int, otherwise its size hint.

    Symbolic values go through ``V.graph.sizevars.size_hint`` using
    ``config.unbacked_symint_fallback`` as the fallback.
    """
    if not isinstance(v, int):
        return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback)
    return v
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class CoalesceVarAnalysis:
    """Result of ``analyze_memory_coalescing`` for one (fused) scheduler node."""

    # Normalized variable -> accumulated score of the memory accesses it
    # coalesces (approximate access size times summed buffer itemsizes).
    coalesced_by_var: dict[sympy.Expr, int]

    # The normalized reads/writes the analysis was computed from.
    norm_read_writes: FusedNormalizedReadsWrites
|
|
|
|
|
|
def analyze_memory_coalescing(
    fused_node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> CoalesceVarAnalysis:
    """
    Find variables that coalesce the reads and writes and score the total size.

    Every normalized read/write expression is scored by its approximate size
    (``get_score``) times the summed itemsize of the buffers accessed with it,
    and the score is credited to the variable that coalesces the expression
    (``find_coalesced_var``).

    Expressions without a coalescing variable are collected but not yet used.
    TODO: for uncoalesced memory expressions, look for an additional tiling of
    variables which will coalesce memory accesses. For instance, for

        (32*p0) // 2048

    tiling p0 by 64 will make this expression coalesced.
    """
    norm_read_writes = extract_normalized_read_writes(fused_node)
    var_ranges = norm_read_writes.var_ranges

    coalesced_by_var: dict[sympy.Symbol, int] = Counter()
    # Collected for the future tiling search described above; not returned yet.
    uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

    for memory_expr, buf_names in itertools.chain(
        norm_read_writes.reads.items(), norm_read_writes.writes.items()
    ):
        size = get_score(memory_expr, var_ranges)
        # TODO - handle indirect
        if size == 0:
            continue

        maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)

        # Weight by bytes: sum the itemsize of every buffer accessed with this
        # expression; names that do not resolve to a buffer contribute nothing.
        byte_multiplier = 0
        for buf_name in buf_names:
            if buf := V.graph.try_get_buffer(buf_name):
                byte_multiplier += buf.dtype.itemsize

        if maybe_coalesced_var is not None:
            coalesced_by_var[maybe_coalesced_var] += size * byte_multiplier
        else:
            uncoalesced_addrs[memory_expr] += size * byte_multiplier

    return CoalesceVarAnalysis(
        coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
    )
|