import dataclasses
import itertools
from collections import Counter, defaultdict
from typing import Callable, Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union

import sympy

import torch
from torch._inductor import config
from torch._inductor.dependencies import index_vars_no_squeeze
from torch._inductor.utils import sympy_product, sympy_subs
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .virtualized import V


T = TypeVar("T")
U = TypeVar("U")


Split = tuple[sympy.Expr, ...]
VarsAndRanges = tuple[list[sympy.Symbol], list[sympy.Expr]]


loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling")


if TYPE_CHECKING:
    from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode


def solve_for_zero(expr: sympy.Expr) -> Optional[sympy.Expr]:
    """
    Given an expr with a single free symbol, solve for a constant value of that
    symbol that makes the expression 0.
    """
    if expr.is_constant():
        return None
    elif isinstance(expr, FloorDiv):
        return None

    assert len(expr.free_symbols) == 1
    free_symbol = next(iter(expr.free_symbols))
    if isinstance(expr, ModularIndexing):
        # ModularIndexing(x, y, z) == (x // y) % z: try solving x == z
        out = try_solve(sympy.Eq(expr.args[0], expr.args[2]), free_symbol)
    else:
        out = try_solve(sympy.Eq(expr, 0), free_symbol)
    if not out or not out[1].is_constant():
        return None
    return out[1]
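

# A minimal doctest-style sketch of solve_for_zero (assumes an integer symbol;
# ModularIndexing(x, 1, 64) == x % 64, which is 0 at x == 64):
#
#   >>> import sympy
#   >>> from torch.utils._sympy.functions import ModularIndexing
#   >>> x = sympy.Symbol("x", integer=True, positive=True)
#   >>> solve_for_zero(ModularIndexing(x, 1, 64))
#   64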


def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]:
    """
    Given an expr with a single free symbol, try to find a tiling that would
    make the expression coalesced with respect to that symbol.

    Tiling an expression `x` by `y` means that the expression will now be indexed
    by both the original (x) and by (x * y). So we are looking for a
    multiplicative factor `y` that will make expr((x + 1) * y) - expr(x * y) == 1.

    To simplify things for sympy, we'll instead solve expr(x * y) == 1, checking
    the expression at x == 1 and x == 0.
    """

    if len(expr.free_symbols) == 0:
        return None

    free_symbol = next(iter(expr.free_symbols))

    def _solve_simple_expr(expr: sympy.Expr) -> Optional[sympy.Expr]:
        assert not expr.has(ModularIndexing) and not expr.has(FloorDiv)
        if len(expr.free_symbols) != 1:
            return None

        out = try_solve(sympy.Eq(expr, 1), free_symbol)
        if not out or not out[1].is_constant():
            return None
        return out[1]

    # Sympy solving is very limited with ModularIndexing and FloorDiv,
    # but good otherwise.
    if not expr.has(ModularIndexing) and not expr.has(FloorDiv):
        return _solve_simple_expr(expr)

    required_values = []
    eq_1_expressions = []

    # Very piecemeal solution when ModularIndexing or FloorDiv is involved.
    # Look for terms we'll try to make 0, and then other terms we'll try to make 1.
    # Expand as needed.
    for arg in sympy.Add.make_args(expr):
        # Try to make mul terms 0
        if isinstance(arg, sympy.Mul):
            seen = False
            # TODO - only need one of these to be solvable to zero
            for mul_arg in arg.args:
                out = solve_for_zero(mul_arg)
                if out is None:
                    continue

                assert out.is_constant()
                seen = True
                required_values.append(out)

            if not seen:
                return None
        else:
            eq_1_expressions.append(arg)

    if not eq_1_expressions:
        return None

    eq_1_expr = sum(eq_1_expressions)

    def indexing_div_rep(
        x: sympy.Expr,
        y: sympy.Expr,
        z: Optional[sympy.Expr] = None,
    ) -> sympy.Expr:
        return x / y

    # For the purposes of tiling/coalesced access, approximate ModularIndexing
    # and FloorDiv with true division, then double check later.
    # pyrefly: ignore # missing-attribute
    eq_1_expr_simplified = eq_1_expr.replace(ModularIndexing, indexing_div_rep).replace(
        FloorDiv, indexing_div_rep
    )

    out = _solve_simple_expr(eq_1_expr_simplified)
    # since we approximated FloorDiv/ModularIndexing, double check here
    if not out or sympy_subs(eq_1_expr, {free_symbol: out}) != 1:
        return None

    required_values.append(out)

    if len(OrderedSet(required_values)) == 1:
        return required_values[0]

    return None
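

# A minimal doctest-style sketch of solve_for_tiling (assumes an integer
# symbol): for (32 * x) // 2048, tiling x by 64 yields stride-1 access, since
# (32 * 64) / 2048 == 1.
#
#   >>> import sympy
#   >>> from torch.utils._sympy.functions import FloorDiv
#   >>> x = sympy.Symbol("x", integer=True, positive=True)
#   >>> solve_for_tiling(FloorDiv(32 * x, 2048))
#   64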


def find_coalesced_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
    """
    Try to find the symbol which coalesces this index
    """
    top_level_terms = sympy.Add.make_args(index)
    for v in var_ranges:
        if v in top_level_terms:
            return v

    # Approximate analysis by evaluating at 1 and 0
    variables: dict[sympy.Symbol, int] = {}
    for v in index.free_symbols:
        if v in var_ranges:
            variables[v] = 0
        else:
            variables[v] = get_hint(v)

    zero_index = sympy_subs(index, variables)
    for v in var_ranges.keys():
        variables[v] = 1
        try:
            new_val = sympy_subs(index, variables)
        except ZeroDivisionError:
            loop_tiling_log.info("zero division error %s %s", index, variables)
            # reset before moving on so later iterations see a clean state
            variables[v] = 0
            continue
        if new_val - zero_index == 1:
            variables[v] = 2
            # in some more complex expressions, 0->1 will be coalesced,
            # but not 1->2
            if (sympy_subs(index, variables) - new_val) == 1:
                return v
        variables[v] = 0

    return None
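

# A minimal doctest-style sketch of find_coalesced_var (assumes integer
# symbols): x appears as a stride-1 top-level term, so it coalesces the index.
#
#   >>> import sympy
#   >>> x, y = sympy.symbols("x y", integer=True, positive=True)
#   >>> find_coalesced_var(x + 128 * y, {x: 128, y: 64})
#   x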


@dataclasses.dataclass(frozen=True)
class FusedNormalizedReadsWrites:
    """
    Normalized reads and writes for nodes in the same FusedSchedulerNode.
    """

    index_vars: OrderedSet[sympy.Symbol]
    reduce_vars: OrderedSet[sympy.Symbol]
    reads: dict[sympy.Expr, OrderedSet[str]]
    writes: dict[sympy.Expr, OrderedSet[str]]
    var_ranges: dict[sympy.Symbol, int]


@overload
def get_pw_red_splits(
    n: "SchedulerNode",
    pointwise_numel: sympy.Expr,
    red_numel: sympy.Expr,
    none_if_not_divisible: Literal[True],
) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: ...


@overload
def get_pw_red_splits(
    n: "SchedulerNode",
    pointwise_numel: sympy.Expr,
    red_numel: sympy.Expr,
    none_if_not_divisible: Literal[False] = False,
) -> tuple[VarsAndRanges, VarsAndRanges]: ...


def get_pw_red_splits(
    n: "SchedulerNode",
    pointwise_numel: sympy.Expr,
    red_numel: sympy.Expr,
    none_if_not_divisible: bool = False,
) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]:
    if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel:
        return (
            (n._body.iter_vars, n._body.sizes[0]),
            (n._body.reduce_vars, n._body.sizes[1]),
        ) # type: ignore[return-value]

    assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel # type: ignore[operator]
    i = len(n._body.sizes[0]) - 1
    prod = 1
    while i >= 0:
        prod *= n._body.sizes[0][i]
        if prod == red_numel:
            break
        i -= 1

    if i >= 0:
        pw_splits = n._body.sizes[0][0:i]
        iter_vars = n._body.iter_vars[0:i]

        red_splits = n._body.sizes[0][i:]
        red_vars = n._body.iter_vars[i:]
        return (iter_vars, pw_splits), (red_vars, red_splits) # type: ignore[return-value]

    if none_if_not_divisible:
        return None
    else:
        return (
            (n._body.iter_vars, n._body.sizes[0]),
            (n._body.reduce_vars, n._body.sizes[1]),
        ) # type: ignore[return-value]
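

# Worked sketch (hypothetical sizes, since a real SchedulerNode is required):
# for a node with n._body.sizes[0] == (8, 6, 4) and red_numel == 24, the loop
# above walks from the right: prod == 4, then prod == 24 == red_numel, so
# i == 1 and the result splits into pointwise (8,) and reduction (6, 4) ranges,
# with the iteration vars partitioned the same way.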


class NodeSplitGetter:
    """
    Finds a Pointwise, Reduction Split that is compatible with all nodes in a
    SchedulerNode.
    """

    def __init__(
        self,
        node: Union["FusedSchedulerNode", "SchedulerNode"],
    ):
        self.node = node
        self.pointwise_numel: sympy.Expr = node.group[1][0]
        self.red_numel: sympy.Expr = node.group[1][1]

        self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet)

        self.reduction_split: Split = ()
        self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet()

        fused_group = node.group[1]
        for n in reversed(node.get_nodes()):
            if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
                continue

            # if we can't split the pw ranges into a (pw, red) split,
            # don't add it as a split option, but do make sure we check that
            # this size is splittable
            maybe_splits = get_pw_red_splits(
                n, self.pointwise_numel, self.red_numel, none_if_not_divisible=True
            )
            if maybe_splits is None:
                self.all_node_sizes.add(n._body.sizes)
                continue

            (_, n_pw_splits), (_, n_red_splits) = maybe_splits

            # fill in reduction size
            n_pw_splits, n_red_splits = (
                torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                    fused_group, (n_pw_splits, n_red_splits), self.red_numel
                )
            )

            self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits))

            # initially, we are just going to do a single reduction split since
            # reduction tiling is off by default. even if we miss a reduction split,
            # we can recover it in the split var analysis.
            # TODO: an earlier version of this code tried to iteratively try the
            # maximum number of split vars, by iterating over both pointwise and
            # reduction. but not worth the complexity yet.

            if n_red_splits != ():
                self.reduction_split = (sympy_product(n_red_splits),)

            n_size = (tuple(n_pw_splits), tuple(n_red_splits))
            self.all_node_sizes.add(n_size)

        self.seen_pw_splits: OrderedSet[Split] = OrderedSet()

    def get_node_splits(self) -> tuple[Split, Split]:
        """
        Get a compatible pointwise, reduction split of the node
        """

        if len(self.all_node_sizes) == 1:
            return next(iter(self.all_node_sizes))

        max_pw_split = max(self.pw_split_options.keys())
        for pw_split_len in range(max_pw_split, 0, -1):
            for pw_split in self.pw_split_options[pw_split_len]:
                if out := self.try_split(pw_split, self.reduction_split):
                    return out

            # combine dims for next round
            for pw_split in self.pw_split_options[pw_split_len]:
                for i in range(len(pw_split) - 1):
                    new_split = tuple(
                        pw_split[0:i]
                        + (sympy_product(pw_split[i : i + 2]),)
                        + pw_split[i + 2 :]
                    )
                    self.pw_split_options[len(new_split)].add(new_split)

        # if for whatever reason we couldn't split above, return default split
        return ((self.pointwise_numel,), (self.red_numel,))

    def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]:
        """
        See if this split is compatible, potentially returning a longer split
        than the input.
        """

        from torch._inductor.codegen.simd import CantSplit, SIMDKernel

        if pw in self.seen_pw_splits:
            return None
        self.seen_pw_splits.add(pw)

        for n_pw, n_red in self.all_node_sizes:
            try:
                groups = pw + red
                lengths = (n_pw, n_red)
                splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths)
            except CantSplit:
                return None

            assert len(getters) == 2
            pw_group_splits = splits[: len(pw)]
            # if we had to divide a variable into two to do this split,
            # then let's try the larger, induced split.
            # e.g. splitting (12, 2) into (2, 12) will split the first var into:
            # (2, 6) and produce an overall split of (2, 6, 2)
            flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits))
            if flattened_pw_splits != pw:
                if out := self.try_split(flattened_pw_splits, red):
                    return out

        return pw, red
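

# Worked sketch of the dim-combining fallback in get_node_splits (hypothetical
# splits): if (2, 6, 2) fails to split, the next round tries each adjacent
# merge, (12, 2) and (2, 12), before falling back to the flat
# ((pointwise_numel,), (red_numel,)) split.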


def apply_var_mapping(
    iter_vars: list[sympy.Symbol],
    red_vars: list[sympy.Symbol],
    norm_pw_vars: list[sympy.Symbol],
    norm_red_vars: list[sympy.Symbol],
    new_ranges: list[list[sympy.Expr]],
    return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]],
) -> dict[sympy.Symbol, sympy.Expr]:
    """Maps original variables to expressions using normalized variables."""

    # The output of _split_iteration_ranges is (new_ranges, return_getters_groups).
    # new_ranges is a flattened list of ranges corresponding to the new pw and red
    # vars. For example, taking in pw vars of range (6, 6) to normalized range [36],
    # new_ranges would be [[6, 6]].
    # There is a return_getter callable for each input iter_var and red_var.
    # If you flatten out all of the ranges and create a variable for each index,
    # then applying the flattened vars to the callables in return_getters_groups
    # gives you the mapping from input vars -> flattened vars.
    # From there, we can compute the output, normalized variables.
    # For instance, [6, 6] corresponding to flat vars v0, v1 gives 6 * v0 + v1.

    # Create flattened iteration variables
    num_vars = sum(len(s) for s in new_ranges)
    flat_vars = sympy.symbols(f"v_0:{num_vars}")

    if len(iter_vars) == 0 and len(red_vars) == 0:
        return {}

    assert len(new_ranges) == len(norm_pw_vars + norm_red_vars)
    apply_groups = []
    for group in return_getters_groups:
        apply_groups.append([g(flat_vars) for g in group])

    iter_vars_to_flat_vars = {}
    for i, (group, var_group) in enumerate(
        zip(apply_groups, (iter_vars, red_vars), strict=True)
    ):
        # if the node has sizes (p0, 1) and the fused node is (p0, r0)
        # the reduction var gets filled in for split_iteration_range
        if len(group) != len(var_group):
            assert i == 1
            assert len(var_group) == 0
            continue

        iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)})

    count = 0
    flat_vars_to_new_vars = {}
    for new_range, new_var in zip(
        new_ranges, norm_pw_vars + norm_red_vars, strict=True
    ):
        range_vars = []
        for i in range(len(new_range)):
            range_vars.append(flat_vars[count])
            count += 1

        prod = 1
        for i in range(len(new_range) - 1, -1, -1):
            flat_vars_to_new_vars[range_vars[i]] = new_var * prod
            prod = new_range[i] * prod

    return {
        k: sympy_subs(v, flat_vars_to_new_vars)
        for k, v in iter_vars_to_flat_vars.items()
    }
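

# A minimal sketch of the stride mapping above (hypothetical ranges; assumes
# outermost-first ordering): flat vars (v0, v1) covering a normalized var `n`
# split as [6, 6] map to v0 -> 6*n and v1 -> n, mirroring the strides in
# n == 6*v0 + v1.
#
#   >>> import sympy
#   >>> n, v0, v1 = sympy.symbols("n v0 v1")
#   >>> mapping, prod = {}, 1
#   >>> for var, size in reversed(list(zip([v0, v1], [6, 6]))):
#   ...     mapping[var] = n * prod
#   ...     prod *= size
#   >>> (mapping[v0], mapping[v1])
#   (6*n, n)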


def extract_normalized_read_writes(
    node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> Optional[FusedNormalizedReadsWrites]:
    """Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node."""
    reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
    writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

    all_output_names = node.get_buffer_names()
    op_names = node.get_operation_names()
    outputs: OrderedSet[str] = OrderedSet()
    removed_buffers: OrderedSet[str] = OrderedSet()
    for buf_name in all_output_names:
        if V.graph.scheduler.can_buffer_be_removed_through_fusion(buf_name, op_names):
            removed_buffers.add(buf_name)
        else:
            outputs.add(buf_name)

    inputs = OrderedSet(
        dep.name for dep in node.read_writes.reads if dep.name not in removed_buffers
    )

    pointwise_numel: sympy.Expr = node.group[1][0]
    red_numel: sympy.Expr = node.group[1][1]

    # TODO - a few dynamic shapes issues to resolve
    if any(
        (isinstance(var, sympy.Expr) and not var.is_constant())
        for var in (pointwise_numel, red_numel)
    ):
        return None

    pw_splits, red_splits = NodeSplitGetter(node).get_node_splits()

    # let's use a different prefix (`n`) to distinguish the normalized vars
    (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze(
        pw_splits, red_splits, prefix="n"
    )

    for n in list(node.get_nodes()):
        if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
            continue

        body = n._body

        # TODO - not handled well. indirect loads will not be coalesced;
        # need to account for that in the analysis.
        if body.indirect_vars:
            return None

        n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
        n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

        # TODO - will the names for all the inputs/outputs accurately
        # reflect mutation, or do I need to remap with mutation_real_name
        for inp in inputs:
            for expr in body.get_all_read_expr(inp):
                n_reads[expr].add(inp)

        for out in outputs:
            for expr in body.get_all_write_expr(out):
                n_writes[expr].add(out)

        if not n_reads and not n_writes:
            continue

        (iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits(
            n, pointwise_numel, red_numel
        )

        groups = pw_splits + red_splits
        lengths = (n_pw_splits, n_red_splits)
        lengths = (
            torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                groups, lengths, red_numel
            )
        )
        new_ranges, return_getters_groups = (
            torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges(
                groups, lengths
            )
        )
        var_map = apply_var_mapping(
            iter_vars,
            red_vars,
            norm_pw_vars,
            norm_red_vars,
            new_ranges,
            return_getters_groups,
        )

        # We create Identity sympy.Functions to prevent expansion to int64;
        # unwrap them for the tiling analysis.
        def remove_identity(expr: sympy.Expr) -> sympy.Expr:
            return expr.replace(Identity, lambda x: x)

        n_reads_new = {
            sympy_subs(remove_identity(read), var_map): v for read, v in n_reads.items()
        }
        n_writes_new = {
            sympy_subs(remove_identity(write), var_map): v
            for write, v in n_writes.items()
        }

        for expr, buf_names in n_reads_new.items():
            reads[expr] |= buf_names

        for expr, buf_names in n_writes_new.items():
            writes[expr] |= buf_names

    reads = {
        V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items()
    }
    writes = {
        V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items()
    }

    fused_out = FusedNormalizedReadsWrites(
        norm_pw_vars, # type: ignore[arg-type]
        norm_red_vars, # type: ignore[arg-type]
        reads,
        writes,
        ranges,
    )
    loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
    return fused_out
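

# Sketch of the normalization (hypothetical sizes): for a fusion over pointwise
# numel 36 where one node iterates over sizes (6, 6) and another over (36,),
# NodeSplitGetter settles on the finer (6, 6) split, both nodes' read/write
# expressions are rewritten onto the shared `n`-prefixed vars, and the returned
# FusedNormalizedReadsWrites lets their accesses be compared directly.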


def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
    """
    Score addr according to its approximate size
    """

    # TODO - deduplicate with candidate_tilings
    var_sizes = []
    for v in addr.free_symbols:
        v_size = var_ranges.get(v)
        # TODO - reason about indirect vars
        if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None:
            var_sizes.append(v_size)

    return V.graph.sizevars.atomically_apply_size_hint(
        sympy_product(var_sizes), fallback=config.unbacked_symint_fallback
    )


def get_hint(v: Union[sympy.Expr, int]) -> int:
    if isinstance(v, int):
        return v
    else:
        return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback)


@dataclasses.dataclass(frozen=True)
class VarTiling:
    """
    Tiling of `var` by `tiling_factor` that yields additional coalesced memory
    accesses, measured by `score`
    """

    var: sympy.Symbol
    tiling_factor: int
    score: int


@dataclasses.dataclass(frozen=True)
class CoalesceVarAnalysis:
    # Var -> Memory Score - not strictly the amount of memory,
    # because writes are weighted 2x
    # TODO: separate into a dataclass that holds mem, dtype, is_write
    coalesced_by_var: dict[sympy.Expr, int]

    norm_read_writes: FusedNormalizedReadsWrites

    suggested_split: Optional[VarTiling] = None


def analyze_memory_coalescing(
    fused_node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> Optional[CoalesceVarAnalysis]:
    """
    Find variables that coalesce the reads and writes and score the total size.

    If uncoalesced memory expressions are found, look for additional tilings of
    variables which will coalesce memory accesses.

    For instance, for the following expression:

    (32*p0) // 2048

    Tiling p0 by 64 will make this expression coalesced.
    """

    norm_read_writes = extract_normalized_read_writes(fused_node)

    if norm_read_writes is None:
        return None

    reads = norm_read_writes.reads
    writes = norm_read_writes.writes
    var_ranges = norm_read_writes.var_ranges

    coalesced_by_var: dict[sympy.Symbol, int] = Counter()
    uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

    for is_read, (memory_expr, buf_names) in itertools.chain(
        ((True, item) for item in reads.items()),
        ((False, item) for item in writes.items()),
    ):
        # skip memory deps with indirect vars - todo: better handling
        indirect_expr = bool(
            memory_expr.free_symbols - norm_read_writes.var_ranges.keys()
        )

        if indirect_expr:
            continue

        size = get_score(memory_expr, var_ranges)
        if size == 0:
            continue

        maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)

        byte_multiplier = 0
        for buf_name in buf_names:
            if buf := V.graph.try_get_buffer(buf_name):
                byte_multiplier += buf.dtype.itemsize

        # coalesced writes are more important
        byte_multiplier *= 1 if is_read else 2

        if maybe_coalesced_var:
            coalesced_by_var[maybe_coalesced_var] += size * byte_multiplier
        else:
            uncoalesced_addrs[memory_expr] += size * byte_multiplier

    if not uncoalesced_addrs:
        return CoalesceVarAnalysis(
            coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
        )

    # map from var -> tiling -> total_score
    tiling_scores: dict[sympy.Expr, dict[int, int]] = defaultdict(Counter)

    for uncoalesced_expr, addr_score in uncoalesced_addrs.items():
        expr_subs = dict.fromkeys(uncoalesced_expr.free_symbols, 0)
        for v in uncoalesced_expr.free_symbols:
            # skip non iter/reduce var variables
            if v not in var_ranges:
                continue
            # skip zero-score addrs
            if addr_score == 0:
                continue
            del expr_subs[v]
            single_var_expr = sympy_subs(uncoalesced_expr, expr_subs)
            expr_subs[v] = 0
            tiling_factor = solve_for_tiling(single_var_expr)
            if (
                tiling_factor is None
                or not tiling_factor.is_constant()
                or not tiling_factor.is_integer
            ):
                continue

            tiling_factor = int(tiling_factor)
            if not V.graph.sizevars.statically_known_lt(tiling_factor, var_ranges[v]):
                continue

            # TODO - if a var is in the middle, such as [n0, n1, n2],
            # n1 can be split beyond its range

            MIN_TILING_BLOCK = 8
            if not all(
                V.graph.sizevars.statically_known_lt(MIN_TILING_BLOCK, block)
                for block in (tiling_factor, var_ranges[v] // tiling_factor)
            ):
                continue

            tiling_scores[v][tiling_factor] += addr_score

    if len(tiling_scores) == 0:
        return CoalesceVarAnalysis(
            coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
        )

    best_tiling: Optional[tuple[sympy.Expr, int]] = None
    best_tiling_score = 0

    for var, tiling_counter in tiling_scores.items():
        for tile, tile_score in tiling_counter.items():
            if tile_score > best_tiling_score:
                best_tiling = (var, tile)
                best_tiling_score = tile_score

    if best_tiling is None:
        return CoalesceVarAnalysis(
            coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
        )

    # TODO - for strictly pointwise fusions,
    # we can consider just swizzling the var if the var we are going to tile
    # does not coalesce a significant portion of global reads
    # TODO - could also prefer index var splits to reduction, better tested
    return CoalesceVarAnalysis(
        coalesced_by_var=coalesced_by_var,
        norm_read_writes=norm_read_writes,
        suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score),
    )
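

# Worked sketch of the suggested-split path (hypothetical numbers): for an
# uncoalesced access (32 * p0) // 2048 with var_ranges == {p0: 8192},
# solve_for_tiling returns 64. 64 < 8192, and both 64 and 8192 // 64 == 128
# exceed MIN_TILING_BLOCK, so the analysis records tiling_scores[p0][64] and,
# if that remains the best score, suggests VarTiling(p0, 64, score).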