import dataclasses
import functools
import itertools
import sys
from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator
from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union

import sympy

import torch
from torch._inductor import config
from torch._inductor.dependencies import index_vars_no_squeeze
from torch._inductor.utils import sympy_product, sympy_subs
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import symbol_is_type, SymT

from .virtualized import V


T = TypeVar("T")
U = TypeVar("U")


Split = tuple[sympy.Expr, ...]

loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling")


if TYPE_CHECKING:
    from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode


def find_coalesced_var(
    index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
) -> Optional[sympy.Expr]:
    """
    Try to find the symbol which coalesces this index

    A variable is returned if it appears as a bare top-level additive term of
    ``index``. Otherwise each variable in ``var_ranges`` is probed in turn:
    the index is evaluated with every range variable at 0 (and every other
    free symbol replaced by its size hint), then re-evaluated with only the
    probed variable at 1. The first variable that moves the index by exactly
    1 is returned; ``None`` if there is no such variable.
    """
    top_level_terms = sympy.Add.make_args(index)
    for v in var_ranges:
        if v in top_level_terms:
            return v

    # Approximate analysis by evaluating at 1 and 0
    variables: dict[sympy.Symbol, int] = {}
    for v in index.free_symbols:
        if v in var_ranges:
            variables[v] = 0
        else:
            variables[v] = get_hint(v)

    zero_index = sympy_subs(index, variables)
    for v in var_ranges.keys():
        variables[v] = 1
        try:
            new_val = sympy_subs(index, variables)
        except ZeroDivisionError:
            loop_tiling_log.info("zero division error %s %s", index, variables)
            # Restore the baseline before moving on; otherwise this variable
            # would stay at 1 and skew the probe of every later variable.
            variables[v] = 0
            continue
        if new_val - zero_index == 1:
            return v
        variables[v] = 0

    return None


@dataclasses.dataclass(frozen=True)
class FusedNormalizedReadsWrites:
    """
    Normalized reads and writes for nodes in the same FusedSchedulerNode.
    """

    # normalized pointwise variables
    index_vars: OrderedSet[sympy.Symbol]
    # normalized reduction variables
    reduce_vars: OrderedSet[sympy.Symbol]
    # memory expression -> names of the buffers accessed with that expression
    reads: dict[sympy.Expr, OrderedSet[str]]
    writes: dict[sympy.Expr, OrderedSet[str]]
    # range of each normalized variable
    var_ranges: dict[sympy.Symbol, int]


def get_pw_red_splits(
    n: "SchedulerNode", pointwise_numel: sympy.Expr, red_numel: sympy.Expr
) -> tuple[tuple[list[sympy.Symbol], list[int]], tuple[list[sympy.Symbol], list[int]]]:
    """
    Return ``((pw_vars, pw_sizes), (red_vars, red_sizes))`` for node ``n``.

    If ``n`` is a reduction, or the product of its first size group already
    equals ``pointwise_numel``, the body's own (iter, reduce) vars and sizes
    are returned unchanged. Otherwise the first size group must multiply out
    to ``pointwise_numel * red_numel``; it is then split at the trailing run
    of dimensions whose product equals ``red_numel``.

    Raises:
        RuntimeError: if no trailing run of dimensions multiplies to
            ``red_numel``.
    """
    if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel:
        return (
            (n._body.iter_vars, n._body.sizes[0]),
            (n._body.reduce_vars, n._body.sizes[1]),
        )  # type: ignore[return-value]

    assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel  # type: ignore[operator]
    i = len(n._body.sizes[0]) - 1
    prod = 1
    # Walk backwards accumulating the product of trailing dims until it
    # matches the reduction numel.
    while i >= 0:
        prod *= n._body.sizes[0][i]
        if prod == red_numel:
            break
        # Step to the previous dim; without this the loop never terminates
        # when the last dim alone does not equal red_numel.
        i -= 1

    if i >= 0:
        pw_splits = n._body.sizes[0][0:i]
        iter_vars = n._body.iter_vars[0:i]

        red_splits = n._body.sizes[0][i:]
        red_vars = n._body.iter_vars[i:]
        return (iter_vars, pw_splits), (red_vars, red_splits)  # type: ignore[return-value]

    # TODO - handle, not sure if possible
    raise RuntimeError(
        f"Unhandled node: size: {n._body.sizes}, pw: {pointwise_numel}, red: {red_numel}"
    )


class NodeSplitGetter:
    """
    Finds a Pointwise, Reduction Split that is compatible with all nodes in a SchedulerNode.
    """

    def __init__(
        self,
        node: Union["FusedSchedulerNode", "SchedulerNode"],
    ):
        """Collect the pointwise/reduction size splits of every SchedulerNode in ``node``."""
        self.node = node
        self.pointwise_numel: sympy.Expr = node.group[1][0]
        self.red_numel: sympy.Expr = node.group[1][1]

        # candidate pointwise splits, keyed by their number of dimensions
        self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet)

        self.reduction_split: Split = ()
        self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet()

        fused_group = node.group[1]
        for n in reversed(node.get_nodes()):
            if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
                continue

            (_, n_pw_splits), (_, n_red_splits) = get_pw_red_splits(
                n, self.pointwise_numel, self.red_numel
            )

            # fill in reduction size
            n_pw_splits, n_red_splits = (
                torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                    fused_group, (n_pw_splits, n_red_splits), self.red_numel
                )
            )

            self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits))

            # initially, we are just going to do a single reduction split since
            # reduction tiling is off by default. even if we miss a reduction split,
            # we can recover it in the split var analysis.
            # TODO: an earlier version of this code tried to iteratively try the maximum number
            # of split vars, by iterating over both pointwise and reduction. but not worth
            # the complexity yet.

            if n_red_splits != ():
                self.reduction_split = (sympy_product(n_red_splits),)

            n_size = (tuple(n_pw_splits), tuple(n_red_splits))
            self.all_node_sizes.add(n_size)

        # pointwise splits already attempted by try_split
        self.seen_pw_splits: OrderedSet[Split] = OrderedSet()

    def get_node_splits(self) -> tuple[Split, Split]:
        """
        Get a compatible pointwise, reduction split of the node

        Candidates are tried from the most pointwise dimensions to the fewest;
        after each round, adjacent dimensions of that round's candidates are
        merged to produce the candidates for shorter lengths.
        """

        if len(self.all_node_sizes) == 1:
            return next(iter(self.all_node_sizes))

        max_pw_split = max(self.pw_split_options.keys())
        for pw_split_len in range(max_pw_split, 0, -1):
            for pw_split in self.pw_split_options[pw_split_len]:
                if out := self.try_split(pw_split, self.reduction_split):
                    return out

            # combine dims for next round
            for pw_split in self.pw_split_options[pw_split_len]:
                for i in range(len(pw_split) - 1):
                    new_split = tuple(
                        pw_split[0:i]
                        + (sympy_product(pw_split[i : i + 2]),)
                        + pw_split[i + 2 :]
                    )
                    self.pw_split_options[len(new_split)].add(new_split)

        # if for whatever reason we couldn't split above, return default split
        return ((self.pointwise_numel,), (self.red_numel,))

    def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]:
        """
        See if this split is compatible, and potentially returning a longer split
        than the input.

        Returns ``None`` if ``pw`` was already tried or if any recorded node
        size cannot be split into ``pw + red``.
        """
        from torch._inductor.codegen.simd import CantSplit, SIMDKernel

        if pw in self.seen_pw_splits:
            return None
        self.seen_pw_splits.add(pw)

        for n_pw, n_red in self.all_node_sizes:
            try:
                groups = pw + red
                lengths = (n_pw, n_red)
                splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths)
            except CantSplit:
                return None

            assert len(getters) == 2
            pw_group_splits = splits[: len(pw)]
            # if we had to divide a variable into two to do this split,
            # then lets try the larger, induced split.
            # e.g. splitting (12, 2) into (2, 12) will split the first var into:
            # (2, 6) and produce an overall split of (2, 6, 2)
            flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits))
            if flattened_pw_splits != pw:
                if out := self.try_split(flattened_pw_splits, red):
                    return out

        return pw, red


if sys.version_info >= (3, 10):
    # On Python 3.10+ we can use zip(strict=True)
    zip_equal = functools.partial(zip, strict=True)
else:
    # Fallback for older versions
    def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]:
        """
        Zip two iterables, raising ValueError if their lengths differ.

        Both arguments must support ``len()``.
        """
        if len(it1) != len(it2):
            raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}")
        return zip(it1, it2)


def apply_var_mapping(
    iter_vars: list[sympy.Symbol],
    red_vars: list[sympy.Symbol],
    norm_pw_vars: list[sympy.Symbol],
    norm_red_vars: list[sympy.Symbol],
    new_ranges: list[list[sympy.Expr]],
    return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]],
) -> dict[sympy.Symbol, sympy.Expr]:
    """Maps original variables to expressions using normalized variables."""

    # the output of split_iteration_range is a new_ranges, return_getters_groups
    # new_ranges is a flattened list of ranges corresponding to the new pw and red vars
    # for example, taking in pw vars of range (6, 6) to normalized range [36],
    # new_ranges would be [[6, 6]]
    # There is a return_getter callable for each input iter_var and red_vars.
    # if you flatten out all of the ranges, and create a variable for each index,
    # then applying the flattening vars to the callables in return_getters_groups
    # gives you the mapping from input vars -> flattened vars.
    # From there, we can compute the output, normalized variables.
    # For instance [6, 6] corresponding to flat vars v0, v1 will be
    # v0 + 6 * v1

    # Create flattened iteration variables
    num_vars = sum(len(s) for s in new_ranges)
    flat_vars = sympy.symbols(f"v_0:{num_vars}")

    if len(iter_vars) == 0 and len(red_vars) == 0:
        return {}

    assert len(new_ranges) == len(norm_pw_vars + norm_red_vars)
    apply_groups = []
    for group in return_getters_groups:
        apply_groups.append([g(flat_vars) for g in group])

    iter_vars_to_flat_vars = {}
    for i, (group, var_group) in enumerate(
        zip_equal(apply_groups, ((iter_vars, red_vars)))
    ):
        # if the node has sizes (p0, 1) and the fused node is (p0, r0)
        # the reduction var gets filled in for split_iteration_range
        if len(group) != len(var_group):
            assert i == 1
            assert len(var_group) == 0
            continue

        iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)})

    count = 0
    flat_vars_to_new_vars = {}
    for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars):
        # consume the next len(new_range) flat vars, in order
        range_vars = []
        for _ in range(len(new_range)):
            range_vars.append(flat_vars[count])
            count += 1

        # innermost (last) flat var gets multiplier 1; each earlier one is
        # scaled by the product of the ranges after it
        prod = 1
        for i in range(len(new_range) - 1, -1, -1):
            flat_vars_to_new_vars[range_vars[i]] = new_var * prod
            prod = new_range[i] * prod

    return {
        k: sympy_subs(v, flat_vars_to_new_vars)
        for k, v in iter_vars_to_flat_vars.items()
    }


def extract_normalized_read_writes(
    node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> FusedNormalizedReadsWrites:
    """Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node."""
    reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
    writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

    all_output_names = node.get_buffer_names()
    op_names = node.get_operation_names()
    # outputs the scheduler reports as removable through fusion are not counted
    outputs = OrderedSet(
        buf
        for buf in all_output_names
        if not V.graph.scheduler.can_buffer_be_removed_through_fusion(buf, op_names)
    )
    inputs = OrderedSet(dep.name for dep in node.read_writes.reads)

    pw_splits, red_splits = NodeSplitGetter(node).get_node_splits()

    # lets use different prefix (`n`) to distinguish
    (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze(
        pw_splits, red_splits, prefix="n"
    )

    pointwise_numel: sympy.Expr = node.group[1][0]
    red_numel: sympy.Expr = node.group[1][1]

    for n in node.get_nodes():
        if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
            continue

        body = n._body
        n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
        n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)

        for inp in inputs:
            for expr in body.get_all_read_expr(inp):
                n_reads[expr].add(inp)

        for out in outputs:
            for expr in body.get_all_write_expr(out):
                n_writes[expr].add(out)

        if not n_reads and not n_writes:
            continue

        (iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits(
            n, pointwise_numel, red_numel
        )

        groups = pw_splits + red_splits
        lengths = (n_pw_splits, (n_red_splits))
        lengths = (
            torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
                groups, lengths, red_numel
            )
        )
        new_ranges, return_getters_groups = (
            torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges(
                groups, lengths
            )
        )
        var_map = apply_var_mapping(
            iter_vars,
            red_vars,
            norm_pw_vars,
            norm_red_vars,
            new_ranges,
            return_getters_groups,
        )

        # rewrite this node's expressions in terms of the normalized variables
        n_reads_new = {sympy_subs(read, var_map): v for read, v in n_reads.items()}
        n_writes_new = {sympy_subs(write, var_map): v for write, v in n_writes.items()}

        for expr, buf_names in n_reads_new.items():
            reads[expr] |= buf_names

        for expr, buf_names in n_writes_new.items():
            writes[expr] |= buf_names

    reads = {
        V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items()
    }
    writes = {
        V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items()
    }

    fused_out = FusedNormalizedReadsWrites(
        norm_pw_vars,  # type: ignore[arg-type]
        norm_red_vars,  # type: ignore[arg-type]
        reads,
        writes,
        ranges,
    )
    loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
    return fused_out


def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
    """
    Score addr according to its approximate size

    The score is the size hint of the product of the ranges of every free
    symbol of ``addr`` that has an entry in ``var_ranges`` and is not an
    indirect symbol.
    """

    # TODO - deduplicate with candidate_tilings
    # TODO - reason about indirect vars
    var_sizes = [
        var_ranges[v]
        for v in addr.free_symbols
        if not symbol_is_type(v, SymT.INDIRECT) and var_ranges.get(v, None) is not None
    ]

    return V.graph.sizevars.atomically_apply_size_hint(
        sympy_product(var_sizes), fallback=config.unbacked_symint_fallback
    )


def get_hint(v: Union[sympy.Expr, int]) -> int:
    """Return ``v`` if it is an int, otherwise its size hint (with the unbacked-symint fallback)."""
    if isinstance(v, int):
        return v
    else:
        return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback)


@dataclasses.dataclass(frozen=True)
class CoalesceVarAnalysis:
    """Result of analyze_memory_coalescing."""

    # variable -> accumulated score (size * summed dtype itemsize) of the
    # memory expressions it coalesces
    coalesced_by_var: dict[sympy.Expr, int]

    norm_read_writes: FusedNormalizedReadsWrites


def analyze_memory_coalescing(
    fused_node: Union["FusedSchedulerNode", "SchedulerNode"],
) -> CoalesceVarAnalysis:
    """
    Find variables that coalesce the reads and writes and score the total size.

    If uncoalesced memory expressions are found, look for additionally tiling of variables
    which will coalesce memory accesses.

    For instance - for the following expression:

    (32*p0) // 2048

    Tiling p0 by 64 will make this expression coalesced.
    """

    norm_read_writes = extract_normalized_read_writes(fused_node)

    reads = norm_read_writes.reads
    writes = norm_read_writes.writes
    var_ranges = norm_read_writes.var_ranges

    coalesced_by_var: dict[sympy.Symbol, int] = Counter()
    # NOTE: accumulated below but not part of the returned analysis here
    uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

    for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
        size = get_score(memory_expr, var_ranges)
        # TODO - handle indirect
        if size == 0:
            continue

        # todo - dtype size
        maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)

        # sum of the itemsize of every buffer that could be resolved by name
        byte_multiplier = 0
        for buf_name in buf_names:
            if buf := V.graph.try_get_buffer(buf_name):
                byte_multiplier += buf.dtype.itemsize

        if maybe_coalesced_var:
            coalesced_by_var[maybe_coalesced_var] += size * byte_multiplier
        else:
            uncoalesced_addrs[memory_expr] += size * byte_multiplier

    return CoalesceVarAnalysis(
        coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
    )