import dataclasses
import functools
import itertools
import sys
from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union

import sympy

import torch
from torch._inductor import config
from torch._inductor.dependencies import index_vars_no_squeeze
from torch._inductor.utils import sympy_product, sympy_subs
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .virtualized import V


T = TypeVar("T")
U = TypeVar("U")


Split = tuple[sympy.Expr, ...]

loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling")


if TYPE_CHECKING:
    from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode


def find_coalesced_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
    """
    Try to find the symbol which coalesces this index
    """
    top_level_terms = sympy.Add.make_args(index)
    for v in var_ranges:
        if v in top_level_terms:
            return v

    # Approximate analysis by evaluating at 1 and 0
    variables: dict[sympy.Symbol, int] = {}
    for v in index.free_symbols:
        if v in var_ranges:
            variables[v] = 0
        else:
            variables[v] = get_hint(v)

    zero_index = sympy_subs(index, variables)
    for v in var_ranges.keys():
        variables[v] = 1
        try:
            new_val = sympy_subs(index, variables)
        except ZeroDivisionError:
            loop_tiling_log.info("zero division error %s %s", index, variables)
            continue
        if new_val - zero_index == 1:
            return v
        variables[v] = 0

    return None
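
# Illustrative sketch (hypothetical expressions and ranges, comments only): the
# analysis above behaves roughly like
#
#     x0, x1 = sympy.symbols("x0 x1", integer=True)
#     ranges = {x0: 256, x1: 128}
#     find_coalesced_var(x0 + 256 * x1, ranges)      # x0 is a top-level term -> x0
#     find_coalesced_var(2 * x0 + 256 * x1, ranges)  # bumping either var by one
#                                                    # moves the index by more
#                                                    # than one -> None
#
# i.e. a variable "coalesces" an index if incrementing it by one moves the
# address by exactly one element.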


@dataclasses.dataclass(frozen=True)
class FusedNormalizedReadsWrites:
    """
    Normalized reads and writes for nodes in the same FusedSchedulerNode.
    """

    index_vars: OrderedSet[sympy.Symbol]
    reduce_vars: OrderedSet[sympy.Symbol]
    reads: dict[sympy.Expr, OrderedSet[str]]
    writes: dict[sympy.Expr, OrderedSet[str]]
    var_ranges: dict[sympy.Symbol, int]


def get_pw_red_splits(
    n: "SchedulerNode", pointwise_numel: sympy.Expr, red_numel: sympy.Expr
) -> tuple[tuple[list[sympy.Symbol], list[int]], tuple[list[sympy.Symbol], list[int]]]:
    if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel:
        return (
            (n._body.iter_vars, n._body.sizes[0]),
            (n._body.reduce_vars, n._body.sizes[1]),
        )  # type: ignore[return-value]

    assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel  # type: ignore[operator]
    i = len(n._body.sizes[0]) - 1
    prod = 1
    while i >= 0:
        prod *= n._body.sizes[0][i]
        if prod == red_numel:
            break
        i -= 1

    if i >= 0:
        pw_splits = n._body.sizes[0][0:i]
        iter_vars = n._body.iter_vars[0:i]

        red_splits = n._body.sizes[0][i:]
        red_vars = n._body.iter_vars[i:]
        return (iter_vars, pw_splits), (red_vars, red_splits)  # type: ignore[return-value]

    # TODO - handle, not sure if possible
    raise RuntimeError(
        f"Unhandled node: size: {n._body.sizes}, pw: {pointwise_numel}, red: {red_numel}"
    )
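
# Rough sketch of the splitting above (hypothetical sizes, comments only): for a
# node whose flattened iteration sizes are ([4, 8, 16], []) with
# pointwise_numel == 32 and red_numel == 16, the trailing product reaches 16 at
# i == 2, so the node is split into pointwise vars/sizes ([v0, v1], [4, 8]) and
# reduction vars/sizes ([v2], [16]).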


class NodeSplitGetter:
    """
    Finds a Pointwise, Reduction Split that is compatible with all nodes in a SchedulerNode.
    """

    def __init__(
        self,
        node: Union["FusedSchedulerNode", "SchedulerNode"],
    ):
        self.node = node
        self.pointwise_numel: sympy.Expr = node.group[1][0]
        self.red_numel: sympy.Expr = node.group[1][1]

        self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet)

        self.reduction_split: Split = ()
        self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet()

        fused_group = node.group[1]
        for n in reversed(node.get_nodes()):
            if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
                continue

            (_, n_pw_splits), (_, n_red_splits) = get_pw_red_splits(
                n, self.pointwise_numel, self.red_numel
            )

            # fill in reduction size
            n_pw_splits, n_red_splits = (
                torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                    fused_group, (n_pw_splits, n_red_splits), self.red_numel
                )
            )

            self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits))

            # initially, we are just going to do a single reduction split since
            # reduction tiling is off by default. even if we miss a reduction split,
            # we can recover it in the split var analysis.
            # TODO: an earlier version of this code tried to iteratively try the maximum number
            # of split vars, by iterating over both pointwise and reduction. but not worth
            # the complexity yet.

            if n_red_splits != ():
                self.reduction_split = (sympy_product(n_red_splits),)

            n_size = (tuple(n_pw_splits), tuple(n_red_splits))
            self.all_node_sizes.add(n_size)

        self.seen_pw_splits: OrderedSet[Split] = OrderedSet()

    def get_node_splits(self) -> tuple[Split, Split]:
        """
        Get a compatible pointwise, reduction split of the node
        """

        if len(self.all_node_sizes) == 1:
            return next(iter(self.all_node_sizes))

        max_pw_split = max(self.pw_split_options.keys())
        for pw_split_len in range(max_pw_split, 0, -1):
            for pw_split in self.pw_split_options[pw_split_len]:
                if out := self.try_split(pw_split, self.reduction_split):
                    return out

            # combine dims for next round
            for pw_split in self.pw_split_options[pw_split_len]:
                for i in range(len(pw_split) - 1):
                    new_split = tuple(
                        pw_split[0:i]
                        + (sympy_product(pw_split[i : i + 2]),)
                        + pw_split[i + 2 :]
                    )
                    self.pw_split_options[len(new_split)].add(new_split)

        # if for whatever reason we couldn't split above, return the default split
        return ((self.pointwise_numel,), (self.red_numel,))
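
    # Illustrative walk-through (hypothetical sizes, comments only): if one node
    # contributes a pointwise split of (2, 3, 4) and another contributes (6, 4),
    # the first pass tries the longest candidates at length 3. If try_split
    # raises CantSplit for some node, adjacent dims are merged to produce (6, 4)
    # and (2, 12), which are tried at the next (shorter) length, and so on down
    # to the fully flattened (24,) before falling back to the default split.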

    def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]:
        """
        See if this split is compatible, potentially returning a longer split
        than the input.
        """

        from torch._inductor.codegen.simd import CantSplit, SIMDKernel

        if pw in self.seen_pw_splits:
            return None
        self.seen_pw_splits.add(pw)

        for n_pw, n_red in self.all_node_sizes:
            try:
                groups = pw + red
                lengths = (n_pw, n_red)
                splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths)
            except CantSplit:
                return None

            assert len(getters) == 2
            pw_group_splits = splits[: len(pw)]
            # if we had to divide a variable into two to do this split,
            # then lets try the larger, induced split.
            # e.g. splitting (12, 2) into (2, 12) will split the first var into:
            # (2, 6) and produce an overall split of (2, 6, 2)
            flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits))
            if flattened_pw_splits != pw:
                if out := self.try_split(flattened_pw_splits, red):
                    return out

        return pw, red


if sys.version_info >= (3, 10):
    # On Python 3.10+ we can use zip(strict=True)
    zip_equal = functools.partial(zip, strict=True)
else:
    # Fallback for older versions
    def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]:
        """
        Zip two iterables, raising ValueError if their lengths differ.
        """
        if len(it1) != len(it2):
            raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}")
        return zip(it1, it2)


def apply_var_mapping(
    iter_vars: list[sympy.Symbol],
    red_vars: list[sympy.Symbol],
    norm_pw_vars: list[sympy.Symbol],
    norm_red_vars: list[sympy.Symbol],
    new_ranges: list[list[sympy.Expr]],
    return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]],
) -> dict[sympy.Symbol, sympy.Expr]:
    """Maps original variables to expressions using normalized variables."""

    # the output of _split_iteration_ranges is (new_ranges, return_getters_groups).
    # new_ranges is a flattened list of ranges corresponding to the new pw and red vars.
    # for example, taking in pw vars of range (6, 6) to normalized range [36],
    # new_ranges would be [[6, 6]].
    # There is a return_getter callable for each input iter_var and red_var.
    # if you flatten out all of the ranges, and create a variable for each index,
    # then applying the flattened vars to the callables in return_getters_groups
    # gives you the mapping from input vars -> flattened vars.
    # From there, we can compute the output, normalized variables.
    # For instance, [6, 6] corresponding to flat vars v0, v1 will be
    # v0 + 6 * v1

    # Create flattened iteration variables
    num_vars = sum(len(s) for s in new_ranges)
    flat_vars = sympy.symbols(f"v_0:{num_vars}")
    count = 0

    if len(iter_vars) == 0 and len(red_vars) == 0:
        return {}

    assert len(new_ranges) == len(norm_pw_vars + norm_red_vars)
    apply_groups = []
    for group in return_getters_groups:
        apply_groups.append([g(flat_vars) for g in group])

    iter_vars_to_flat_vars = {}
    for i, (group, var_group) in enumerate(
        zip_equal(apply_groups, (iter_vars, red_vars))
    ):
        # if the node has sizes (p0, 1) and the fused node is (p0, r0)
        # the reduction var gets filled in for split_iteration_range
        if len(group) != len(var_group):
            assert i == 1
            assert len(var_group) == 0
            continue

        iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)})

    count = 0
    flat_vars_to_new_vars = {}
    for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars):
        range_vars = []
        for i in range(len(new_range)):
            range_vars.append(flat_vars[count])
            count += 1

        prod = 1
        for i in range(len(new_range) - 1, -1, -1):
            flat_vars_to_new_vars[range_vars[i]] = new_var * prod
            prod = new_range[i] * prod

    return {
        k: sympy_subs(v, flat_vars_to_new_vars)
        for k, v in iter_vars_to_flat_vars.items()
    }
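
# Rough worked example of the mapping above (hypothetical values, comments only):
# suppose a node iterates with iter_vars [p0, p1] over sizes (6, 6) and the fused
# node's normalized pointwise split is a single var n0 of range 36. Then
# new_ranges == [[6, 6]], flat vars v_0, v_1 are created, and the getters map
# p0 -> v_0 and p1 -> v_1 (identity in this simple case). The second pass then
# assigns each flat var its stride under n0 (v_0 -> 6 * n0, v_1 -> n0 with the
# loop as written), so the function returns {p0: 6 * n0, p1: n0}, an expression
# for each original var purely in terms of the normalized variables.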


def extract_normalized_read_writes(
    node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> FusedNormalizedReadsWrites:
    """Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node."""
    reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
    writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

    all_output_names = node.get_buffer_names()
    op_names = node.get_operation_names()
    outputs = OrderedSet(
        buf
        for buf in all_output_names
        if not V.graph.scheduler.can_buffer_be_removed_through_fusion(buf, op_names)
    )
    inputs = OrderedSet(dep.name for dep in node.read_writes.reads)

    pw_splits, red_splits = NodeSplitGetter(node).get_node_splits()

    # let's use a different prefix (`n`) to distinguish
    (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze(
        pw_splits, red_splits, prefix="n"
    )
    pointwise_numel: sympy.Expr = node.group[1][0]
    red_numel: sympy.Expr = node.group[1][1]

    for n in list(node.get_nodes()):
        if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
            continue

        body = n._body
        n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
        n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

        for inp in inputs:
            for expr in body.get_all_read_expr(inp):
                n_reads[expr].add(inp)
        for out in outputs:
            for expr in body.get_all_write_expr(out):
                n_writes[expr].add(out)

        if not n_reads and not n_writes:
            continue

        (iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits(
            n, pointwise_numel, red_numel
        )

        groups = pw_splits + red_splits
        lengths = (n_pw_splits, (n_red_splits))
        lengths = (
            torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                groups, lengths, red_numel
            )
        )
        new_ranges, return_getters_groups = (
            torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges(
                groups, lengths
            )
        )
        var_map = apply_var_mapping(
            iter_vars,
            red_vars,
            norm_pw_vars,
            norm_red_vars,
            new_ranges,
            return_getters_groups,
        )

        n_reads_new = {sympy_subs(read, var_map): v for read, v in n_reads.items()}
        n_writes_new = {sympy_subs(write, var_map): v for write, v in n_writes.items()}

        for expr, buf_names in n_reads_new.items():
            reads[expr] |= buf_names

        for expr, buf_names in n_writes_new.items():
            writes[expr] |= buf_names

    reads = {
        V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items()
    }
    writes = {
        V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items()
    }

    fused_out = FusedNormalizedReadsWrites(
        norm_pw_vars,  # type: ignore[arg-type]
        norm_red_vars,  # type: ignore[arg-type]
        reads,
        writes,
        ranges,
    )
    loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
    return fused_out
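
# For intuition (hypothetical output, comments only): for a fused node whose
# kernel iterates over a normalized pointwise var n0 and reduction var n1, the
# result might look roughly like
#
#     FusedNormalizedReadsWrites(
#         index_vars={n0}, reduce_vars={n1},
#         reads={n1 + 2048 * n0: {"buf0"}, n0: {"buf1"}},
#         writes={n0: {"buf2"}},
#         var_ranges={n0: 64, n1: 2048},
#     )
#
# i.e. every node's read/write index is rewritten in terms of the shared
# normalized variables so accesses can be compared across the whole fusion.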


def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
    """
    Score addr according to its approximate size
    """

    # TODO - deduplicate with candidate_tilings
    var_sizes = []
    for v in addr.free_symbols:
        v_size = var_ranges.get(v, None)
        # TODO - reason about indirect vars
        if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None:
            var_sizes.append(v_size)
    from .virtualized import V

    return V.graph.sizevars.atomically_apply_size_hint(
        sympy_product(var_sizes), fallback=config.unbacked_symint_fallback
    )
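
# Example of the intent (hypothetical ranges, comments only): for an address like
# n1 + 2048 * n0 with var_ranges {n0: 64, n1: 2048}, the score is the product of
# the participating ranges, 64 * 2048 == 131072, i.e. roughly how many distinct
# elements the expression can touch.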


def get_hint(v: Union[sympy.Expr, int]) -> int:
    if isinstance(v, int):
        return v
    else:
        return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback)


@dataclasses.dataclass(frozen=True)
class CoalesceVarAnalysis:
    coalesced_by_var: dict[sympy.Expr, int]

    norm_read_writes: FusedNormalizedReadsWrites


def analyze_memory_coalescing(
    fused_node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> CoalesceVarAnalysis:
    """
    Find variables that coalesce the reads and writes and score the total size.

    If uncoalesced memory expressions are found, look for additional tiling of variables
    which will coalesce memory accesses.

    For instance - for the following expression:

    (32*p0) // 2048

    Tiling p0 by 64 will make this expression coalesced.
    """

    norm_read_writes = extract_normalized_read_writes(fused_node)

    reads = norm_read_writes.reads
    writes = norm_read_writes.writes
    var_ranges = norm_read_writes.var_ranges

    coalesced_by_var: dict[sympy.Symbol, int] = Counter()
    uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

    for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
        size = get_score(memory_expr, var_ranges)
        # TODO - handle indirect
        if size == 0:
            continue

        # todo - dtype size
        maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)

        byte_multipler = 0
        for buf_name in buf_names:
            if buf := V.graph.try_get_buffer(buf_name):
                byte_multipler += buf.dtype.itemsize

        if maybe_coalesced_var:
            coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
        else:
            uncoalesced_addrs[memory_expr] += size * byte_multipler

    return CoalesceVarAnalysis(
        coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
    )
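
# Typical call site (sketch, not part of this module): tiling code is expected to
# run this analysis on a fused node and compare the per-variable byte totals, e.g.
#
#     analysis = analyze_memory_coalescing(fused_node)
#     best_var = max(
#         analysis.coalesced_by_var, key=analysis.coalesced_by_var.get, default=None
#     )
#
# where `fused_node` is a SchedulerNode/FusedSchedulerNode and any tiling decision
# based on `best_var` lives outside this file.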