Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)

Python's built-in set does not have a deterministic iteration order. We recently ran into an internal failure caused by this which did not fail consistently.

See the repro here: P1453035092.

Now, with these changes, it fails consistently. As a follow-up, we could also consider adding a lint rule for uses of either set() or set literals (see the sketches below).
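For context, a minimal sketch of the non-determinism (not from this PR; it assumes only stock CPython behavior and the OrderedSet helper this PR migrates to): CPython salts str hashes per process, so iterating a set of buffer names can yield a different order on each interpreter run, while an insertion-ordered set is stable.

# A minimal sketch (not from this PR): str hashes are salted per process
# (PYTHONHASHSEED), so set iteration order over strings can change between
# interpreter runs; OrderedSet preserves insertion order.
from torch.utils._ordered_set import OrderedSet

names = [f"buf{i}" for i in range(8)]
print(list(set(names)))         # order may differ from run to run
print(list(OrderedSet(names)))  # always ['buf0', 'buf1', ..., 'buf7']

And a rough, hypothetical sketch of what the suggested lint rule could look like (the SetUseLinter class and its messages are illustrative, not an existing tool):

# Hypothetical lint-rule sketch: walk a file's AST and flag set() calls,
# set literals, and set comprehensions.
import ast
import sys

class SetUseLinter(ast.NodeVisitor):
    def __init__(self, filename: str) -> None:
        self.filename = filename

    def report(self, node: ast.AST, what: str) -> None:
        print(f"{self.filename}:{node.lineno}: prefer OrderedSet over {what}")

    def visit_Call(self, node: ast.Call) -> None:
        if isinstance(node.func, ast.Name) and node.func.id == "set":
            self.report(node, "set()")
        self.generic_visit(node)

    def visit_Set(self, node: ast.Set) -> None:
        self.report(node, "a set literal")
        self.generic_visit(node)

    def visit_SetComp(self, node: ast.SetComp) -> None:
        self.report(node, "a set comprehension")
        self.generic_visit(node)

if __name__ == "__main__":
    with open(sys.argv[1]) as f:
        SetUseLinter(sys.argv[1]).visit(ast.parse(f.read()))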

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
eellison
2024-07-31 15:30:15 -07:00
committed by PyTorch MergeBot
parent bcd1d2e832
commit f32ab3b9e3
13 changed files with 371 additions and 313 deletions

View File

@ -1394,9 +1394,10 @@ class CompiledFxGraph:
with open(graph.cache_path) as f:
self.source_code = f.read()
self.cache_linemap = graph.cache_linemap
self.device_types = graph.device_types
self.device_idxs = graph.device_idxs
self.mutated_inputs = graph.mutated_inputs
# TODO - ordered set
self.device_types = set(graph.device_types)
self.device_idxs = set(graph.device_idxs)
self.mutated_inputs = set(graph.mutated_inputs)
self.mutated_input_idxs = set(graph.mutated_input_idxs)
self.constants = graph.constants
self.torchbind_constants = graph.torchbind_constants

View File

@ -17,7 +17,6 @@ from typing import (
List,
NamedTuple,
Optional,
Set,
Tuple,
Union,
)
@ -29,6 +28,7 @@ import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
@ -1407,7 +1407,7 @@ class KernelArgs:
# after you do a call into this kernel, which buffers actually contain
# updated data? Modeled off of python_argdefs.
def live_output_buffers(self):
live_outs = set()
live_outs = OrderedSet() # type: ignore[var-annotated]
for inplaced in unique(self.inplace_buffers.values()):
if self._buffer_is_marked_removed(inplaced):
continue
@ -1483,10 +1483,10 @@ class CSE:
self.store_cache = store_cache or {}
self.reduction_cache = reduction_cache or {}
self.iter_buffer_ids = iter_buffers or itertools.count()
self.invalidated_stores = set()
self.invalidated_stores = OrderedSet() # type: ignore[var-annotated]
self.varname_map = varname_map or {}
def invalidate(self, keep_vars: Set[str]):
def invalidate(self, keep_vars: OrderedSet[str]):
for name, tmp in list(self.store_cache.items()):
if tmp not in keep_vars:
del self.store_cache[name]
@ -1615,16 +1615,16 @@ class Kernel(CodeGen):
self.num_reduction = 0
self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
self.must_keep_buffers = set()
self.store_buffer_names = set()
self.must_keep_buffers = OrderedSet() # type: ignore[var-annotated]
self.store_buffer_names = OrderedSet() # type: ignore[var-annotated]
self._load_mask = None
self._load_other = None
# set in set_current_node
self.current_node = None
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
self.removed_buffers = set()
self.inplaced_to_remove = set()
self.removed_buffers = OrderedSet() # type: ignore[var-annotated]
self.inplaced_to_remove = OrderedSet() # type: ignore[var-annotated]
# key: the buffer to write
# value: the buffer to read and whose memory can be reused for

View File

@ -15,7 +15,6 @@ from typing import (
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
@ -60,6 +59,8 @@ from .simd import constant_repr, SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from torch.utils._ordered_set import OrderedSet
from ..ops_handler import ReductionType, StoreMode
log = logging.getLogger(__name__)
@ -665,7 +666,7 @@ class HalideKernel(SIMDKernel):
self,
*groups,
index_dtype: str,
mutations: Optional[Set[str]] = None,
mutations: Optional[OrderedSet[str]] = None,
pid_cache=None,
reduction_hint=ReductionHint.DEFAULT,
override_persistent_reduction=None,
@ -717,7 +718,7 @@ class HalideKernel(SIMDKernel):
assert not (
self.index_replacements or self.halide_vars or self.reduction_renames
)
size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf)
size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type]
indices = dict.fromkeys(map(super().prepare_indexing, indices))
all_used_symbols = set()
sym_to_node = {
@ -1647,7 +1648,7 @@ class HalideScheduling(SIMDScheduling):
int32_type = "hl.Int(32)"
# TODO(jansel): Halide doesn't actually support 64 bit indexing...
int64_type = "hl.Int(64)"
kernel_type = HalideKernel
kernel_type = HalideKernel # type: ignore[arg-type]
@classmethod
def get_backend_features(cls, device: torch.device):

View File

@ -5,6 +5,7 @@ import pathlib
from typing import Any, List
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.utils._ordered_set import OrderedSet
from .. import config
from ..codecache import get_path, TritonFuture
@ -219,11 +220,11 @@ class MultiKernel:
@property
def removed_buffers(self):
return set.intersection(*[k.removed_buffers for k in self.kernels])
return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels])
@property
def inplaced_to_remove(self):
return set.intersection(*[k.inplaced_to_remove for k in self.kernels])
return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels])
@property
@cache_on_self

View File

@ -19,7 +19,6 @@ from typing import (
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
@ -28,6 +27,7 @@ import sympy
import torch
import torch._logging
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
@ -311,7 +311,7 @@ class SIMDKernel(Kernel):
self,
*groups,
index_dtype: str,
mutations: Optional[Set[str]] = None,
mutations: Optional[OrderedSet[str]] = None,
pid_cache=None,
reduction_hint=ReductionHint.DEFAULT,
override_persistent_reduction=None,
@ -322,14 +322,16 @@ class SIMDKernel(Kernel):
self.body = IndentedBuffer()
self.indexing_code = IndentedBuffer()
self.numels = [V.graph.sizevars.simplify(s) for s in groups]
self.mutations: Set[str] = mutations if mutations is not None else set()
self.mutations: OrderedSet[str] = (
mutations if mutations is not None else OrderedSet()
)
self.range_trees: List[IterationRangesRoot] = []
self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {}
self.iter_vars_count = itertools.count()
self.inside_reduction = self.numels[-1] != 1
self.reduction_hint = reduction_hint
self.index_dtype: str = index_dtype
self.last_usage: Set[str] = set()
self.last_usage: OrderedSet[str] = OrderedSet()
self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
self.persistent_reduction: bool = (
override_persistent_reduction
@ -483,7 +485,7 @@ class SIMDKernel(Kernel):
def set_last_usage(self, nodes):
if not self.inside_reduction or self.persistent_reduction:
return
self.last_usage = set(
self.last_usage = OrderedSet(
itertools.chain.from_iterable(
n.last_usage for n in nodes if n is not EnableReduction
)
@ -832,7 +834,7 @@ class SIMDKernel(Kernel):
# This arg points to a buf that has been sliced.
# We need to count each individual slice to have
# a better estimation.
indices: Set[Any] = set()
indices: OrderedSet[Any] = OrderedSet()
no_index_dep_count = 0
for dep in self.buf_accesses[arg]:
if isinstance(dep, (StarDep, WeakDep)):
@ -1063,13 +1065,13 @@ class SIMDScheduling(BaseScheduling):
def generate_node_schedule(self, nodes, numel, rnumel):
node_schedule: List[Any] = []
current_loop_writes: Set[str] = set()
current_loop_writes: OrderedSet[str] = OrderedSet()
# Writes with a reduced shape, meaning they are only present once the
# reduction loop has ended
current_loop_reduced_writes = set()
current_loop_reduced_writes: OrderedSet[str] = OrderedSet()
current_loop_has_writes = False
done = set()
done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()
def fits_in_main_body(n):
_, (node_numel, node_rnumel) = n.group
@ -1221,7 +1223,7 @@ class SIMDScheduling(BaseScheduling):
@classmethod
def select_index_dtype(cls, node_schedule, numel, reduction_numel):
# Gather all used buffer names
buffer_names = set()
buffer_names: OrderedSet[str] = OrderedSet()
for node in node_schedule:
if not isinstance(node, scheduler.BaseSchedulerNode):
continue
@ -1296,7 +1298,7 @@ class SIMDScheduling(BaseScheduling):
else:
reduction_hint_val = ReductionHint.DEFAULT
mutations = set()
mutations: OrderedSet[str] = OrderedSet()
for node in node_schedule:
if node in (DisableReduction, EnableReduction):
continue
@ -1653,7 +1655,7 @@ class SIMDScheduling(BaseScheduling):
break
return (numel, reduction_numel)
seen_names = set()
seen_names: OrderedSet[str] = OrderedSet()
candidate_tiles: Counter[Any] = collections.Counter()
for node in EnableReduction.filter(node_schedule):
for tiling in cls.candidate_tilings(node):
@ -1720,7 +1722,7 @@ class SIMDScheduling(BaseScheduling):
# empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
for n in nodes:
n.last_usage = set()
n.last_usage = OrderedSet()
if not nodes[0].is_template():
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group

View File

@ -16,7 +16,6 @@ from typing import (
Iterable,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
@ -29,6 +28,7 @@ import torch._logging
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._triton import has_triton_package
@ -136,7 +136,7 @@ block_sizes = {
@dataclasses.dataclass
class IndexingOptions:
index_str: str
mask_vars: Set[str]
mask_vars: OrderedSet[str]
mask_str: str
expand_str: Optional[str]
_has_rindex: bool
@ -163,7 +163,7 @@ class BlockPtrOptions:
params: BlockParameters
constant_offset: sympy.Expr
order: List[int]
mask_vars: Set[str]
mask_vars: OrderedSet[str]
reshape_suffix: List[str]
@property
@ -188,7 +188,7 @@ class BlockPtrOptions:
params: BlockParameters,
constant_offset: sympy.Expr,
range_trees: List[IterationRangesEntry],
mask_vars: Set[str],
mask_vars: OrderedSet[str],
) -> BlockPtrOptions:
"""Helper to create a BlockPtrOptions instance"""
reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
@ -454,7 +454,7 @@ class TritonPrinter(PythonPrinter):
# Use a macro so we can propagate constexprs.
# https://github.com/triton-lang/triton/issues/3815
a, b = tuple(f"({x})" for x in (a, b))
assert cmp in {">", "<"}, f"Unexpected comparator: '{cmp}'"
assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'"
return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))"
def _print_Min(self, expr):
@ -575,7 +575,7 @@ class TritonCSEVariable(CSEVariable):
def __init__(self, name, bounds: ValueRanges[Any]):
super().__init__(name, bounds)
# We'll use this to track which masks the variable needs when used for indirect indexing
self.mask_vars: Set[str] = set()
self.mask_vars: OrderedSet[str] = OrderedSet()
def update_on_args(self, name, args, kwargs):
for arg in args:
@ -608,10 +608,10 @@ class TritonOverrides(OpOverrides):
# fp8 data type conversions has min_elem_per_thread requirements.
# Refer to Triton implementations here:
# https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.
fp8_dtypes = {
fp8_dtypes = (
torch.float8_e4m3fn,
torch.float8_e5m2,
}
)
# Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2.
assert not (
src_dtype in fp8_dtypes
@ -1037,7 +1037,7 @@ class TritonKernelOverrides(TritonOverrides):
V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
)
if dtype not in {torch.int32, torch.int64}:
if dtype not in (torch.int32, torch.int64):
var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
var.mask_vars = indexing.mask_vars
return var
@ -1176,7 +1176,7 @@ class TritonKernel(SIMDKernel):
self,
*groups,
index_dtype: str,
mutations: Optional[Set[str]] = None,
mutations: Optional[OrderedSet[str]] = None,
pid_cache=None,
reduction_hint=ReductionHint.DEFAULT,
min_elem_per_thread=0,
@ -1191,13 +1191,13 @@ class TritonKernel(SIMDKernel):
override_persistent_reduction=override_persistent_reduction,
)
self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment]
self.outside_loop_vars: Set[Any] = set()
self.outside_loop_vars: OrderedSet[Any] = OrderedSet()
self.min_elem_per_thread = min_elem_per_thread
self.block_ptr_id = itertools.count()
self.helper_functions = HelperFunctions()
# A set of autotuning hints to pass as part of triton_meta
self.autotune_hints: Set[AutotuneHint] = set()
self.autotune_hints: OrderedSet[AutotuneHint] = OrderedSet()
self.triton_meta: Optional[Dict[str, object]] = None
self.codegen_range_tree()
@ -1285,7 +1285,7 @@ class TritonKernel(SIMDKernel):
index_vars = index.free_symbols
has_rindex = False
mask_vars: Set[str] = set()
mask_vars: OrderedSet[str] = OrderedSet()
for var in index_vars:
assert isinstance(var, sympy.Symbol)
has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX)
@ -1322,7 +1322,7 @@ class TritonKernel(SIMDKernel):
have_dense = True
have_loop_vars = False
dense_mask_vars = set()
dense_mask_vars: OrderedSet[str] = OrderedSet()
for tree in self.active_range_trees():
if index_vars.intersection(tree.var_list):
@ -1584,7 +1584,7 @@ class TritonKernel(SIMDKernel):
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
return IndexingOptions(
index_str, set(), "None", expand_str, has_rindex, index
index_str, OrderedSet(), "None", expand_str, has_rindex, index
)
if need_dense and not have_dense:
@ -1596,7 +1596,7 @@ class TritonKernel(SIMDKernel):
mask_vars = dense_mask_vars
if override_mask:
mask_vars = {override_mask}
mask_vars = OrderedSet([override_mask])
if self._load_mask:
mask_vars.add(self._load_mask)
@ -1721,9 +1721,11 @@ class TritonKernel(SIMDKernel):
ep = ", eviction_policy='evict_last'"
elif self.inside_reduction and self.range_trees[-1].is_loop:
if name in self.args.inplace_buffers:
names = set(self.args.inplace_buffers[name].other_names)
names: OrderedSet[str] = OrderedSet(
self.args.inplace_buffers[name].other_names
)
else:
names = {name}
names = OrderedSet([name])
last_use = len(names & self.last_usage) > 0
evict_last = not last_use and (has_rindex or indirect_indexing)
if evict_last:
@ -1881,7 +1883,7 @@ class TritonKernel(SIMDKernel):
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
assert self.inside_reduction
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
self.filter_masks(masks)
masks = sorted(masks)
if self._load_mask:
@ -1929,7 +1931,7 @@ class TritonKernel(SIMDKernel):
dim = self.triton_tensor_ndim() - 1
acc_type = triton_acc_type(src_dtype)
result_var: Any = self.cse.newvar()
result_var.mask_vars = {var for var in masks if var[0] != "r"}
result_var.mask_vars = OrderedSet(var for var in masks if var[0] != "r")
cond = " & ".join(masks)
def where_cond(tval, fval):
@ -2090,7 +2092,7 @@ class TritonKernel(SIMDKernel):
if isinstance(result_var, tuple):
assert all(isinstance(x, TritonCSEVariable) for x in result_var)
self.outside_loop_vars |= set(result_var)
self.outside_loop_vars |= OrderedSet(result_var)
else:
assert isinstance(result_var, TritonCSEVariable)
self.outside_loop_vars.add(result_var)
@ -2172,7 +2174,7 @@ class TritonKernel(SIMDKernel):
values: Tuple[CSEVariable, ...],
) -> Tuple[CSEVariable, ...]:
assert self.inside_reduction
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
self.filter_masks(masks)
masks = sorted(masks)
assert not self._load_mask, "ops.scan not supported inside ops.masked"
@ -2284,7 +2286,7 @@ class TritonKernel(SIMDKernel):
descending: bool,
) -> Tuple[CSEVariable, ...]:
assert self.inside_reduction
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
self.filter_masks(masks)
masks = sorted(masks)
assert not self._load_mask, "ops.sort not supported inside ops.masked"
@ -2586,7 +2588,7 @@ class TritonKernel(SIMDKernel):
arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
)
mutated_args = set()
mutated_args: OrderedSet[str] = OrderedSet()
for mutation in self.mutations:
if mutation in self.args.input_buffers:
mutated_args.add(self.args.input_buffers[mutation])

View File

@ -1,12 +1,13 @@
# mypy: allow-untyped-defs
import functools
from typing import Optional, Set
from typing import Optional
import torch._inductor.runtime.hints
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._prims_common import prod
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv
@ -30,7 +31,7 @@ class TritonSplitScanKernel(TritonKernel):
self,
*groups,
index_dtype: str,
mutations: Optional[Set[str]] = None,
mutations: Optional[OrderedSet[str]] = None,
reduction_hint=torch._inductor.runtime.hints.ReductionHint.DEFAULT,
min_elem_per_thread=0,
):

View File

@ -6,13 +6,14 @@ import itertools
import logging
import re
import typing
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import sympy
import torch
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
from .codegen.common import index_prevent_reordering
from .utils import (
@ -123,7 +124,7 @@ class MemoryDep(Dep):
if self.is_indirect():
numel = V.graph.get_numel(self.name)
else:
vars = set(self.index.free_symbols)
vars: OrderedSet[sympy.Expr] = OrderedSet(self.index.free_symbols)
numel = sympy.Integer(1)
for var, size in zip(self.var_names, self.size):
if var in vars:
@ -272,9 +273,9 @@ class IndexExprDep:
@dataclasses.dataclass
class ReadWrites:
reads: Set[Dep]
writes: Set[Dep]
index_exprs: Set[IndexExprDep]
reads: OrderedSet[Dep]
writes: OrderedSet[Dep]
index_exprs: OrderedSet[IndexExprDep]
range_vars: Optional[List[sympy.Expr]] = None
var_ranges: Optional[VarRanges] = None
op_counts: typing.Counter[str] = dataclasses.field(
@ -283,8 +284,8 @@ class ReadWrites:
def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
return ReadWrites(
{dep.rename(renames) for dep in self.reads},
{dep.rename(renames) for dep in self.writes},
OrderedSet(dep.rename(renames) for dep in self.reads),
OrderedSet(dep.rename(renames) for dep in self.writes),
self.index_exprs,
self.range_vars,
self.var_ranges,
@ -294,7 +295,7 @@ class ReadWrites:
def with_read(self, dep: Dep) -> "ReadWrites":
assert isinstance(dep, (WeakDep, StarDep))
return ReadWrites(
set.union(self.reads, {dep}),
OrderedSet.union(self.reads, [dep]),
self.writes,
self.index_exprs,
self.range_vars,
@ -303,18 +304,18 @@ class ReadWrites:
)
def merge(self, other: "ReadWrites"):
reads = set.union(self.reads, other.reads)
writes = set.union(self.writes, other.writes)
index_exprs = set.union(self.index_exprs, other.index_exprs)
reads = OrderedSet.union(self.reads, other.reads)
writes = OrderedSet.union(self.writes, other.writes)
index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs)
op_counts = collections.Counter(self.op_counts)
op_counts.update(other.op_counts)
return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts)
@staticmethod
def merge_list(read_writes: List["ReadWrites"]):
all_writes = set.union(*[rw.writes for rw in read_writes])
all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes
all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes])
all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
op_counts: typing.Counter[Any] = collections.Counter()
for rw in read_writes:
@ -339,7 +340,7 @@ class ReadWrites:
"""
Integer index is used for load_seed.
"""
names = set()
names: OrderedSet[str] = OrderedSet()
for dep in self.reads_and_writes():
if not isinstance(dep, MemoryDep):
continue
@ -353,9 +354,9 @@ class ReadWrites:
class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
def __init__(self, var_ranges: VarRanges, normalize: bool):
super().__init__()
self._reads: Set[Dep] = set()
self._writes: Set[MemoryDep] = set()
self._index_exprs: Set[IndexExprDep] = set()
self._reads: OrderedSet[Dep] = OrderedSet()
self._writes: OrderedSet[MemoryDep] = OrderedSet()
self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet()
self._var_ranges: VarRanges = var_ranges
self._normalize: bool = normalize
@ -509,8 +510,8 @@ def extract_read_writes(
inner = rw.parent_handler.parent_handler
return ReadWrites(
set(inner._reads),
set(inner._writes),
OrderedSet(inner._reads),
OrderedSet(inner._writes),
inner._index_exprs,
range_vars,
var_ranges,
@ -550,7 +551,7 @@ def extract_input_node_reduction_ranges(
reduction_size = None
size = None
while reduction_size is None and len(reads) > 0:
seen = set()
seen: OrderedSet[str] = OrderedSet()
new_reads = []
for read in reads:
if not isinstance(read, MemoryDep):
@ -586,10 +587,10 @@ def canonicalization_prefix():
# ops handler which computes all the free unbacked symbols for an IR
class FreeUnbackedSymbolsOpsHandler:
symbols: Set[sympy.Symbol]
symbols: OrderedSet[sympy.Symbol]
def __init__(self):
self.symbols = set()
self.symbols = OrderedSet()
def __getattr__(self, name: str) -> Callable[..., Any]:
def inner(*args, **kwargs):

View File

@ -19,7 +19,6 @@ from typing import (
NoReturn,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
@ -51,6 +50,7 @@ from torch.fx.experimental.symbolic_shapes import (
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.utils._mode_utils import no_dispatch
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from . import config, ir
@ -345,35 +345,39 @@ class GraphLowering(torch.fx.Interpreter):
self.ras_by_symbol: Dict[
sympy.Symbol, List[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy()
self.bound_unbacked_symbols: Set[sympy.Symbol] = set()
self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: List[str] = []
self.graph_inputs: Dict[str, TensorBox] = {}
self.graph_inputs_original: Dict[str, InputBuffer] = {}
self.device_types: Set[str] = (
const_module.device_types if const_module else set()
self.device_types: OrderedSet[str] = (
const_module.device_types if const_module else OrderedSet()
)
self.device_idxs: OrderedSet[int] = (
const_module.device_idxs if const_module else OrderedSet()
)
self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
self.cuda = False
self.buffers: List[ir.Buffer] = []
self.operations: List[ir.Operation] = []
self.const_output_index: Dict[str, int] = (
const_output_index if const_output_index else {}
)
self.folded_constants: Set[str] = (
set(const_output_index.keys()) if const_output_index else set()
self.folded_constants: OrderedSet[str] = (
OrderedSet(const_output_index.keys())
if const_output_index
else OrderedSet()
)
self.constants: Dict[str, torch.Tensor] = (
const_module.constants if const_module else {}
)
self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
self.constant_reprs: Dict[str, str] = {}
self.removed_operations: Set[str] = set()
self.removed_buffers: Set[str] = set()
self.removed_inplace_buffers: Set[str] = set()
self.mutated_buffers: Set[str] = set()
self.never_reuse_buffers: Set[str] = set()
self.inplaced_to_remove: Set[str] = set()
self.removed_operations: OrderedSet[str] = OrderedSet()
self.removed_buffers: OrderedSet[str] = OrderedSet()
self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
self.mutated_buffers: OrderedSet[str] = OrderedSet()
self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details
@ -389,7 +393,7 @@ class GraphLowering(torch.fx.Interpreter):
self.current_node: torch.fx.Node = None # type: ignore[assignment]
self.lists: Dict[str, List[str]] = {}
self.mutated_inputs: Set[str] = set()
self.mutated_inputs: OrderedSet[str] = OrderedSet()
self.mutated_input_idxs: List[int] = []
self.name_to_buffer: Dict[str, ir.Buffer] = {}
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
@ -400,7 +404,7 @@ class GraphLowering(torch.fx.Interpreter):
# record multi_kernel choice for cpp_wrapper so the second pass knows
# which sub-kernel is picked. Copy cpp_wrapper to another variable
# since cpp_wrapper flag is set to false for the first pass of codegen.
# since cpp_wrapper flag is set to false for the first pass of codegen.
self.record_multi_kernel_choice = cpp_wrapper
self.multi_kernel_to_choice: Dict[str, int] = {}
@ -409,7 +413,7 @@ class GraphLowering(torch.fx.Interpreter):
self.post_grad_graph_id = next(_post_grad_graph_counter)
self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment]
self.nodes_prefer_channels_last = (
self.find_nodes_prefer_channels_last() if self.layout_opt else set()
self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
)
mark_nodes_dislike_padding(gm.graph)
self._warned_fallback = {"aten.convolution_backward"}
@ -439,8 +443,8 @@ class GraphLowering(torch.fx.Interpreter):
self.get_backend_features = functools.lru_cache(None)(get_backend_features)
self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
self.aligned_inputs: Set[str] = set()
self.no_fuse_buffer_names: Set[str] = set()
self.aligned_inputs: OrderedSet[str] = OrderedSet()
self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()
def has_feature(
self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature
@ -643,7 +647,7 @@ class GraphLowering(torch.fx.Interpreter):
name=self.qualify_name(subgraph_name),
)
def find_nodes_prefer_channels_last(self) -> Set[Node]:
def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]:
"""
The rule to decide if an node prefer channels last is simple.
1. if it's input/output of a convolution
@ -662,7 +666,7 @@ class GraphLowering(torch.fx.Interpreter):
With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies
can be saved.
"""
output_set = set()
output_set: OrderedSet[Node] = OrderedSet()
for n in reversed(self.module.graph.nodes):
if n.target == torch.ops.aten.convolution.default:
output_set.add(n)
@ -896,7 +900,7 @@ class GraphLowering(torch.fx.Interpreter):
if self.constants[name].device == device_override or device_override is None:
return name
with torch.utils._python_dispatch._disable_current_modes():
# caller might have set fake tensor mode which will create a fake tensor
# when calling .to, so unset modes here
return self.allocate_non_dup_const_name(
f"{name}_{device_override.type}{device_override.index or 0}",
@ -980,7 +984,7 @@ class GraphLowering(torch.fx.Interpreter):
# arguments' fx strides
layout_constraint = None
if torch._C.Tag.needs_fixed_stride_order in target.tags:
# We have to set the current args because call_function will immediately
# evaluate this lowering after creating the fallback, without evaluating
# the layout constraint
args, kwargs = constrain_to_fx_strides(
@ -1238,9 +1242,11 @@ class GraphLowering(torch.fx.Interpreter):
if n.op == "call_function":
args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs)
with ir.IRNode.current_origins(origins), self.set_current_node(
with ir.IRNode.current_origins(origins), self.set_current_node( # type: ignore[arg-type]
n
), V.set_current_node(n):
), V.set_current_node(
n
):
if (
n.op == "call_function"
and n.target is not operator.getitem
@ -1337,7 +1343,7 @@ class GraphLowering(torch.fx.Interpreter):
# Realize if (1) any user need inputs realized, or (2) there is
# already too many reads and rematerializing can be bad.
num_users = len(set(n.users))
num_users = len(OrderedSet(n.users))
if num_users > 1 and isinstance(result, TensorBox):
for user in n.users:
if user.target in needs_realized_inputs:
@ -1446,7 +1452,7 @@ class GraphLowering(torch.fx.Interpreter):
self.register_users_of(result)
new_unbacked_defs = set()
new_unbacked_defs: OrderedSet[sympy.Symbol] = OrderedSet()
for buf in self.buffers[buffer_watermark:]:
new_unbacked_defs |= buf.get_unbacked_symbol_defs()
for op in self.operations[operation_watermark:]:
@ -1540,10 +1546,10 @@ class GraphLowering(torch.fx.Interpreter):
# end up needing to test equalities on the symbols, and a fresh
# symbol is likely to hit lots of GuardOnDataDependent errors that
# we already know facts for.
renamed_unbacked_bindings = {
renamed_unbacked_bindings = OrderedSet(
V.fake_mode.shape_env.unbacked_renamings.get(s, s)
for s in unbacked_bindings.keys()
}
)
assert new_unbacked_defs >= renamed_unbacked_bindings, (
f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
f"fx node is: {n.format_node()}\n"
@ -1611,9 +1617,9 @@ class GraphLowering(torch.fx.Interpreter):
if "cuda" in self.device_types:
# first pass
self.cpp_wrapper = False
# Although triton.store_cubin was set in compile_fx, the backward pass didn't pick
# that up. In theory it should work by only setting triton.store_cubin to True here,
# but that will cause a problem when use_runtime_constant_folding is set.
with config.patch({"triton.store_cubin": True}):
compiled = self.compile_to_module().call
@ -1652,7 +1658,7 @@ class GraphLowering(torch.fx.Interpreter):
for x in itertools.chain(params_flat, V.real_inputs)
]
else:
# In the backward pass, V.real_inputs is not set.
# Generating random inputs based on self.example_inputs sometimes can be problematic,
# e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
real_inputs = [

View File

@ -23,7 +23,6 @@ from typing import (
Optional,
overload,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
@ -62,6 +61,7 @@ from torch.fx.experimental.symbolic_shapes import (
resolve_unbacked_bindings,
SymTypes,
)
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import SymT
@ -325,11 +325,11 @@ def is_cpu(x: object) -> bool:
class IRNode:
_current_origins: ClassVar[Set[Any]] = set()
_current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()
@staticmethod
@contextlib.contextmanager
def current_origins(origins: Set[torch.fx.Node]):
def current_origins(origins: OrderedSet[torch.fx.Node]):
old = IRNode._current_origins
IRNode._current_origins = old | origins
try:
@ -338,10 +338,10 @@ class IRNode:
IRNode._current_origins = old
def __post_init__(self):
self.origins = set(self._current_origins)
self.origins = OrderedSet(self._current_origins)
self.traceback = traceback.format_stack() if config.debug_ir_traceback else None
def get_read_names(self) -> Set[str]:
def get_read_names(self) -> OrderedSet[str]:
raise NotImplementedError(f"NYI on {type(self)}")
def get_traceback(self):
@ -420,7 +420,7 @@ class IRNode:
make_indexer: Callable[[], Callable[[Any], Any]]
mark_reuse: Callable[[int], None]
realize_hint: Callable[[], None]
get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]]
get_unbacked_symbol_uses: Callable[[], OrderedSet[sympy.Symbol]]
@dataclasses.dataclass
@ -455,8 +455,8 @@ class Operation:
def is_user_of(self, name):
return name in self.get_read_names()
def get_read_names(self) -> Set[str]:
return {dep.name for dep in self.get_reads()}
def get_read_names(self) -> OrderedSet[str]:
return OrderedSet(dep.name for dep in self.get_reads())
def get_reads(self):
return self.get_read_writes().reads
@ -464,10 +464,10 @@ class Operation:
def get_outputs(self) -> List[Buffer]:
raise NotImplementedError
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
"""
Returns the unbacked symbols which are required to be in scope in
order to successfully perform codegen for this buffer. For example,
@ -482,7 +482,7 @@ class Operation:
on that buffer, which will eventually have a dependency on i0 if
necessary.
"""
return set()
return OrderedSet()
def get_workspace_size(self):
"""
@ -499,8 +499,8 @@ class Loops(IRNode):
inner_fn: Callable[..., Any]
ranges: List[Expr]
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
return set().union(
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet().union(
*(free_unbacked_symbols(e) for e in self.ranges),
self.inner_fn_free_unbacked_symbols(),
)
@ -594,8 +594,8 @@ class Loops(IRNode):
self.get_size(),
).reads
def get_read_names(self) -> Set[str]:
return {dep.name for dep in self.get_reads()}
def get_read_names(self) -> OrderedSet[str]:
return OrderedSet(dep.name for dep in self.get_reads())
def get_reduction_size(self):
raise NotImplementedError(
@ -689,7 +689,7 @@ def get_reduction_combine_fn(
if reduction_type in REDUCTION_COMBINE_FN:
return REDUCTION_COMBINE_FN[reduction_type]
elif reduction_type in {"argmax", "argmin"}:
elif reduction_type in ("argmax", "argmin"):
def argmax_combine_fn(
a: Tuple[object, object], b: Tuple[object, object]
@ -754,16 +754,16 @@ class Reduction(Loops):
src_dtype: torch.dtype
reduction_hint: ReductionHint
def __str__(self):
def __str__(self) -> str: # type: ignore[override]
return Loops.__str__( # type: ignore[call-arg]
self, names=("ranges", "reduction_ranges", "reduction_type")
)
def __repr__(self):
def __repr__(self) -> str: # type: ignore[override]
return self.__str__()
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
return super().get_unbacked_symbol_uses() | set().union(
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return super().get_unbacked_symbol_uses() | OrderedSet().union(
*(free_unbacked_symbols(e) for e in self.reduction_ranges)
)
@ -831,10 +831,10 @@ class Reduction(Loops):
should_split = (
not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
and reduction_type
not in {
not in (
"argmax",
"argmin",
}
)
and config.split_reductions
# We don't support unbacked symints
and _is_static(reduction_numel_hint)
@ -1221,14 +1221,14 @@ class Reduction(Loops):
@staticmethod
def default_accumulator(reduction_type, dtype):
if reduction_type in {"max", "argmax"}:
if reduction_type in ("max", "argmax"):
if is_float_dtype(dtype):
return float("-inf")
elif is_boolean_dtype(dtype):
return 0
else:
return torch.iinfo(dtype).min
if reduction_type in {"min", "argmin"}:
if reduction_type in ("min", "argmin"):
if is_float_dtype(dtype):
return float("inf")
elif is_boolean_dtype(dtype):
@ -1472,10 +1472,6 @@ class Reduction(Loops):
)
def num_reduction_outputs(reduction_type: Set[str]) -> int:
return 3 if "welford" in reduction_type else 1
class WelfordReduction(Reduction):
output_index: int
@ -1530,7 +1526,7 @@ class WelfordReduction(Reduction):
reduction_type: str,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
):
assert reduction_type in {"welford_reduce", "welford_combine"}
assert reduction_type in ("welford_reduce", "welford_combine")
reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
@ -1745,14 +1741,14 @@ class Scan(Loops):
# HACK we mimick reduction
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
# TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
# need to explicitly represent the closure so we can pull out unbacked
# symbols here
return (
super().get_unbacked_symbol_uses()
| set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges))
| set().union(*(free_unbacked_symbols(e) for e in self.size))
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.scan_ranges))
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size))
)
def __post_init__(self):
@ -1940,11 +1936,11 @@ class Sort(Loops):
# HACK we mimick reduction
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return (
super().get_unbacked_symbol_uses()
| set().union(*(free_unbacked_symbols(e) for e in self.sort_ranges))
| set().union(*(free_unbacked_symbols(e) for e in self.size))
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.sort_ranges))
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size))
)
def __post_init__(self):
@ -2211,7 +2207,7 @@ class BaseView(IRNode):
def is_module_buffer(self):
return self.data.is_module_buffer() # type: ignore[attr-defined]
def get_read_names(self) -> Set[str]:
def get_read_names(self) -> OrderedSet[str]:
return self.data.get_read_names()
def get_reads(self):
@ -2312,7 +2308,7 @@ class PermuteView(BaseView):
@classmethod
def create(cls, x, dims):
dims = cls._map_neg_dims(dims)
assert set(dims) == set(range(len(dims)))
assert OrderedSet(dims) == OrderedSet(range(len(dims)))
if is_storage_and_layout(x):
storage, old_layout = as_storage_and_layout(x)
@ -2332,14 +2328,16 @@ class PermuteView(BaseView):
return [dim if dim >= 0 else len(dims) + dim for dim in dims]
def get_size(self):
assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet(
range(len(self.dims))
)
size = self.data.get_size()
return [size[i] for i in self.dims]
def make_reindexer(self):
inv = {j: i for i, j in enumerate(self.dims)}
inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index]
assert set(inv) == set(range(len(self.dims)))
assert OrderedSet(inv) == OrderedSet(range(len(self.dims)))
def reindex(index):
return [index[i] for i in inv]
@ -2420,7 +2418,7 @@ class GenericView(BaseView):
index_new = list(self.reindex(index_old))
return f"lambda {', '.join(map(str, index_old))}: {index_new}"
def __str__(self):
def __str__(self) -> str:
return self.str_helper(
[self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
)
@ -2594,7 +2592,7 @@ class ReinterpretView(BaseView):
if isinstance(self.data, BaseView):
self.data = self.data.unwrap_view()
def __str__(self):
def __str__(self) -> str:
return self.str_helper(
[
self.data,
@ -2643,7 +2641,7 @@ class ReinterpretView(BaseView):
def freeze_layout(self):
pass
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return (
free_unbacked_symbols(self.layout.size)
| free_unbacked_symbols(self.layout.stride)
@ -2684,7 +2682,7 @@ class DtypeView(BaseView):
return ReinterpretView(storage, new_layout)
return DtypeView(x, new_dtype)
def __str__(self):
def __str__(self) -> str:
return self.str_helper([self.data, self.target_dtype])
__repr__ = __str__
@ -2885,7 +2883,7 @@ class Layout(IRNode):
def stride(self):
return self._stride
def __str__(self):
def __str__(self) -> str:
offset = ""
if self.offset != 0:
offset = f", offset={self.offset}"
@ -3133,7 +3131,7 @@ class FlexibleLayout(Layout):
In this format, channels last would be:
[1, 3, 2, 0]
"""
assert set(range(len(sizes))) == set(order), (sizes, order)
assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order)
next_stride = sympy.Integer(1)
strides = [None] * len(order)
@ -3150,7 +3148,7 @@ class FlexibleLayout(Layout):
In this format, channels last would be:
[3, 0, 2, 1]
"""
assert set(range(len(sizes))) == set(order)
assert OrderedSet(range(len(sizes))) == OrderedSet(order)
fill_order = stride_order2fill_order(order)
return FlexibleLayout.fill_ordered(sizes, fill_order)
@ -3461,14 +3459,14 @@ class Buffer(IRNode):
return [self.layout.target.get_name()]
return ()
def get_read_names(self) -> Set[str]:
return {self.get_name()}
def get_read_names(self) -> OrderedSet[str]:
return OrderedSet([self.get_name()])
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def realize(self):
pass
@ -3516,8 +3514,8 @@ class ConstantBuffer(InputBuffer):
class NoneAsConstantBuffer(IRNode):
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def codegen_reference(self, writer=None):
return V.graph.wrapper_code.none_str
@ -3532,7 +3530,7 @@ class ShapeAsConstantBuffer(IRNode):
def shape(self):
return self._shape
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return free_unbacked_symbols(self.shape)
def codegen_reference(self, writer=None):
@ -3572,7 +3570,7 @@ class ComputedBuffer(OperationBuffer):
self.data.get_size(),
)
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
# Ordinarily, we'd like to just peek at the arguments list,
# but ComputedBuffers have no argument list.
#
@ -3861,7 +3859,7 @@ class TemplateBuffer(OperationBuffer):
deps = dependencies.extract_read_writes(
dummy, self.get_size(), (), normalize=True
)
deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
deps.reads = OrderedSet(dependencies.StarDep(x.get_name()) for x in self.inputs)
return deps
def get_reduction_size(self):
@ -3913,10 +3911,10 @@ class TritonTemplateBuffer(TemplateBuffer):
self.outputs: List[Buffer] = [self]
if mutated_inputs is not None:
# Ensure that the mutated inputs are only allowed for certain nodes
allowed_set = {
allowed_set = (
torch.ops.higher_order.flex_attention,
torch.ops.higher_order.flex_attention_backward,
}
)
current_node = V.graph.current_node.target
assert (
current_node in allowed_set
@ -3929,7 +3927,7 @@ class TritonTemplateBuffer(TemplateBuffer):
def get_outputs(self) -> List[Buffer]:
return self.outputs
def __str__(self):
def __str__(self) -> str:
out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})"
return out
@ -4060,7 +4058,7 @@ class InputsKernel(OperationBuffer):
inputs: List[Buffer]
def get_read_writes(self):
reads: Set[dependencies.Dep] = set()
reads: OrderedSet[dependencies.Dep] = OrderedSet()
StarDep = dependencies.StarDep
for input in self.inputs:
if isinstance(input, list):
@ -4068,14 +4066,14 @@ class InputsKernel(OperationBuffer):
else:
reads.add(StarDep(input.get_name()))
writes: Set[dependencies.Dep] = {
writes: OrderedSet[dependencies.Dep] = OrderedSet(
StarDep(buf.get_name()) for buf in self.get_outputs()
}
)
return dependencies.ReadWrites(
reads=reads,
writes=writes,
index_exprs=set(),
index_exprs=OrderedSet(),
)
@classmethod
@ -4331,8 +4329,8 @@ class ExternKernel(InputsKernel):
def get_outputs(self) -> List[Buffer]:
return [self, *self.mutation_outputs]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def collect_arg_kwarg_properties(self):
# if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional
@ -4371,7 +4369,7 @@ class ExternKernel(InputsKernel):
def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
# Previously, we want to maintain forward-compatibility by skipping
# default args in the serialized artifacts in fbcode. However,
# some of our shim interfaces require default values being set.
# Discussed with Sherlock offline and we decided to allow serializing
# default args into the C++ wrapper code for now. We will refine this
# part if we see real FC requirement. More details related to FC
@ -4860,17 +4858,17 @@ class ExternKernel(InputsKernel):
index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type]
return index, tuple(new_sizes)
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
# NB: It's not necessary to check regular inputs as we automatically
# have dependencies on them
r = set()
r: OrderedSet[sympy.Symbol] = OrderedSet()
for arg in self.constant_args:
r |= maybe_free_unbacked_symbols(arg)
for arg in self.kwargs.values():
r |= maybe_free_unbacked_symbols(arg)
return r
def __str__(self):
def __str__(self) -> str:
kernel_name = getattr(self, "python_kernel_name", None)
lines = [
f"python_kernel_name={kernel_name!r}",
@ -5056,13 +5054,13 @@ class UserDefinedTritonKernel(ExternKernel):
new_name, raw_args, self.grid, configs, triton_meta, kernel.constexprs
)
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
# add unbacked symbols used in the grid to the ones used
# in the kwargs (the latter is generated by ExternKernel)
return super().get_unbacked_symbol_uses() | free_unbacked_symbols(self.grid)
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def __init__(self, *, kernel_idx, grid, kernel_args):
inputs = []
@ -5144,8 +5142,8 @@ class InplaceBernoulliFallback(ExternKernel):
def get_mutation_names(self):
return [self.inputs[0].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def __init__(self, op_overload, x, *constant_args):
super().__init__(
@ -5182,8 +5180,8 @@ class InplaceCopyFallback(ExternKernel):
def get_mutation_names(self):
return [self.inputs[0].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def __init__(
self,
@ -5237,8 +5235,8 @@ class MutatingFirstArgExternKernel(ExternKernel):
def get_mutation_names(self):
return [self.inputs[0].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def has_side_effects(self):
return True
@ -5319,8 +5317,8 @@ class ScatterFallback(ExternKernel):
def get_mutation_names(self):
return [self.inputs[0].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def __init__(
self,
@ -5384,8 +5382,8 @@ class IndexPutFallback(ExternKernel):
def get_mutation_names(self):
return [self.inputs[0].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def __init__(self, op_overload, x, indices, values, accumulate):
self.indices = indices
@ -5457,8 +5455,8 @@ class DynamicScalar(ExternKernel):
self.sym = sym
self.keypath = keypath
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return {self.sym}
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet([self.sym])
def codegen(self, wrapper):
wrapper.codegen_dynamic_scalar(self)
@ -5518,22 +5516,24 @@ class ExternKernelNode:
node: export_schema.Node
has_c_shim = {
aten._embedding_bag.default,
aten._fft_c2c.default,
aten._scaled_dot_product_efficient_attention.default,
aten._scaled_dot_product_flash_attention.default,
aten._scaled_dot_product_cudnn_attention.default,
aten._scaled_mm.default,
aten.addmm.out,
aten.bmm.out,
aten.copy_.default,
aten.mm.out,
aten.repeat_interleave.Tensor,
aten.nonzero.default,
aten.view.dtype,
aten.view_as_real.default,
}
has_c_shim = OrderedSet(
[
aten._embedding_bag.default,
aten._fft_c2c.default,
aten._scaled_dot_product_efficient_attention.default,
aten._scaled_dot_product_flash_attention.default,
aten._scaled_dot_product_cudnn_attention.default,
aten._scaled_mm.default,
aten.addmm.out,
aten.bmm.out,
aten.copy_.default,
aten.mm.out,
aten.repeat_interleave.Tensor,
aten.nonzero.default,
aten.view.dtype,
aten.view_as_real.default,
]
)
class FallbackKernel(ExternKernelAlloc):
@ -5719,13 +5719,13 @@ class FallbackKernel(ExternKernelAlloc):
f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}"
)
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
if unbacked_bindings := getattr(self, "unbacked_bindings", None):
return resolve_unbacked_bindings(
V.graph.sizevars.shape_env, unbacked_bindings
).keys()
else:
return set()
return OrderedSet()
def set_cpp_kernel(self, kernel):
from .codegen.wrapper import get_cpp_op_schema
@ -5758,7 +5758,7 @@ class FallbackKernel(ExternKernelAlloc):
class Shim:
ref: Any
def __repr__(self):
def __repr__(self) -> str:
return self.ref
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
@ -5784,7 +5784,9 @@ class FallbackKernel(ExternKernelAlloc):
if isinstance(example_output, torch.Tensor):
return example_output.device
if isinstance(example_output, (list, tuple)):
device_set = {FallbackKernel.find_device(None, x) for x in example_output}
device_set = OrderedSet(
FallbackKernel.find_device(None, x) for x in example_output
)
# Remove None
devices = [device for device in device_set if device]
if len(devices) == 1:
@ -6108,7 +6110,7 @@ class MultiOutput(ExternKernel):
V.graph.register_operation(self)
self.indices = indices
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return self.inputs[0].get_unbacked_symbol_uses()
def should_allocate(self):
@ -6140,10 +6142,10 @@ class MutableBox(IRNode):
def realize(self):
return self.data.realize()
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
return self.data.get_unbacked_symbol_uses()
def get_read_names(self) -> Set[str]:
def get_read_names(self) -> OrderedSet[str]:
return self.data.get_read_names()
def get_defining_op(self):
@ -6166,7 +6168,7 @@ class MutableBox(IRNode):
def dtype(self):
return self.data.dtype
def __str__(self):
def __str__(self) -> str:
if isinstance(self.data, MutableBox):
line0 = f"{type(self).__name__}({type(self.data).__name__}("
endl = "))"
@ -6326,7 +6328,7 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
for buffer in buffers
]
# assuming the same buffer is represented by the same IRNode object
return len({id(buffer) for buffer in buffers}) < len(buffers)
return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers)
@dataclasses.dataclass
@ -7143,12 +7145,12 @@ class _WaitKernel(_CollectiveKernel):
# NB: recursive structure here reflects val_to_arg_str, avoid
# calling free_unbacked_symbols on "exotic" types that don't get pexpr
# treatment
def maybe_free_unbacked_symbols(s: object) -> Set[Symbol]:
def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]:
if isinstance(s, (SymTypes, Expr)):
# This branch should be impossible in return position
return free_unbacked_symbols(s)
elif isinstance(s, (tuple, list)):
r = set()
r: OrderedSet[sympy.Symbol] = OrderedSet()
for t in s:
r |= maybe_free_unbacked_symbols(t)
return r
@ -7156,4 +7158,4 @@ def maybe_free_unbacked_symbols(s: object) -> Set[Symbol]:
# This branch is impossible in constant-args position
return free_unbacked_symbols(s)
else:
return set()
return OrderedSet()

View File

@ -1,10 +1,11 @@
# mypy: allow-untyped-defs
from typing import Any, List, Optional, Set
from typing import Any, List, Optional
import sympy
import torch
from torch._prims_common import make_channels_last_strides_for
from torch.utils._ordered_set import OrderedSet
from .ir import (
ExternKernelAlloc,
@ -437,8 +438,8 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
self.cpp_kernel_overload_name,
)
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
@classmethod
def create(
@ -862,8 +863,8 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
def get_mutation_names(self):
return [self.inputs[self.idx_for_inplace_sum].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
return set()
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
@classmethod
def create(

View File

@ -34,6 +34,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import free_symbol_is_type, SymT
from torch.utils._triton import has_triton
@ -153,13 +154,13 @@ class SchedulerBuffer:
class BaseSchedulerNode:
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
read_writes: dependencies.ReadWrites
unmet_dependencies: Set[Dep]
unmet_dependencies: OrderedSet[Dep]
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
self.scheduler: Scheduler = scheduler
self.node: Optional[ir.Operation] = node
self.set_read_writes(node.get_read_writes())
self.ancestors: Set[str] = set()
self.ancestors: OrderedSet[str] = OrderedSet()
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
@ -167,9 +168,9 @@ class BaseSchedulerNode:
# .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
self.min_order: int
self.max_order: int
self.last_usage: Set[
self.last_usage: OrderedSet[
str
] = set() # buffers that won't be used after this kernel
] = OrderedSet() # buffers that won't be used after this kernel
self.written = False
self.outputs: List[SchedulerBuffer] = [
@ -255,10 +256,10 @@ class BaseSchedulerNode:
self.prune_deps()
def set_last_usage(
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
) -> None:
used_buffers = self.used_or_aliased_buffer_names()
used_buffers = {mutation_real_name.get(k, k) for k in used_buffers}
used_buffers = OrderedSet([mutation_real_name.get(k, k) for k in used_buffers])
self.last_usage = used_buffers - future_used_buffers
def mark_run(self) -> None:
@ -268,14 +269,14 @@ class BaseSchedulerNode:
def op_counts(self) -> Counter[str]:
return self.read_writes.op_counts
def used_buffer_names(self) -> Set[str]:
return {
def used_buffer_names(self) -> OrderedSet[str]:
return OrderedSet(
dep.name
for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
}
)
def used_or_aliased_buffer_names(self) -> Set[str]:
used_names = set()
def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
used_names: OrderedSet[str] = OrderedSet()
deps = [
dep.name
@ -291,11 +292,11 @@ class BaseSchedulerNode:
return used_names
def prune_deps(self) -> None:
self.unmet_dependencies = {
self.unmet_dependencies = OrderedSet(
dep
for dep in self.unmet_dependencies
if dep.name not in self.scheduler.available_buffer_names
}
)
def prune_weak_deps(self) -> None:
# Prune weak dependencies on operations that have been removed
@ -305,7 +306,9 @@ class BaseSchedulerNode:
op = self.scheduler.name_to_buf[dep.name].defining_op
return op.get_name() in V.graph.removed_operations
to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)}
to_remove = OrderedSet(
dep for dep in self.read_writes.reads if should_prune(dep)
)
self.set_read_writes(self.read_writes.remove_reads(to_remove))
def prune_redundant_deps(
@ -320,11 +323,11 @@ class BaseSchedulerNode:
def get_first_name(self) -> str:
return self.get_name()
def get_operation_names(self) -> Set[str]:
return {node.get_name() for node in self.get_nodes()}
def get_operation_names(self) -> OrderedSet[str]:
return OrderedSet(node.get_name() for node in self.get_nodes())
def get_buffer_names(self) -> Set[str]:
return {out.get_name() for out in self.outputs}
def get_buffer_names(self) -> OrderedSet[str]:
return OrderedSet(out.get_name() for out in self.outputs)
def get_nodes(self) -> Sequence[BaseSchedulerNode]:
return [self]
@ -538,18 +541,18 @@ class BaseSchedulerNode:
for dep in self.read_writes.reads | self.read_writes.writes:
buf_accesses[dep.name].append(dep)
reads = {dep.name for dep in self.read_writes.reads}
writes = {dep.name for dep in self.read_writes.writes}
reads = OrderedSet(dep.name for dep in self.read_writes.reads)
writes = OrderedSet(dep.name for dep in self.read_writes.writes)
def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool:
users = self.scheduler.name_to_buf[buf].users
buf_uses = {user.node for user in users}
return len(buf_uses - set(snodes)) > 0
buf_uses = OrderedSet(user.node for user in users)
return len(buf_uses - OrderedSet(snodes)) > 0
if isinstance(self, FusedSchedulerNode):
removed_buffers = {
removed_buffers = OrderedSet(
dep for dep in writes if not is_materialized(dep, self.snodes)
}
)
writes = writes - removed_buffers
reads = reads - removed_buffers
node_bytes = 0
@ -714,7 +717,7 @@ class WhyNoFuse:
def pformat(obj: Any) -> str:
if isinstance(obj, set):
if isinstance(obj, OrderedSet):
# pformat has trouble with sets of sympy exprs
obj = sorted(obj, key=str)
result = pprint.pformat(obj, indent=4)
@ -725,7 +728,7 @@ def pformat(obj: Any) -> str:
class OutputNode:
def __init__(self, dep: StarDep) -> None:
self.unmet_dependencies = {dep}
self.unmet_dependencies = OrderedSet([dep])
def is_reduction(self) -> bool:
return False
@ -771,14 +774,16 @@ def _prune_redundant_deps(
else:
return False
deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)}
deps_to_prune = OrderedSet(
dep for dep in node.unmet_dependencies if should_prune(dep)
)
if deps_to_prune:
node.unmet_dependencies = node.unmet_dependencies - deps_to_prune
node.set_read_writes(node.read_writes.remove_reads(deps_to_prune))
# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel
kernel_name_to_op = {
"extern_kernels.convolution": torch.ops.aten.convolution,
"extern_kernels.mm": torch.ops.aten.mm,
@ -937,8 +942,8 @@ class SchedulerNode(BaseSchedulerNode):
return False
@cache_on_self
def _get_atomic_add_buffers(self) -> Set[str]:
buffers_store_as_atomic_add = set()
def _get_atomic_add_buffers(self) -> OrderedSet[str]:
buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
if isinstance(self._body, ir.LoopBody):
for node in self._body.get_nodes():
if (
@ -966,7 +971,7 @@ def init_group_node(
group_snode.snodes = snodes
group_snode.scheduler = scheduler
group_snode.node = None
group_snode.ancestors = set.union(
group_snode.ancestors = OrderedSet.union(
*[x.ancestors for x in snodes if x.ancestors is not None]
)
@ -974,11 +979,14 @@ def init_group_node(
dependencies.ReadWrites.merge_list([x.read_writes for x in snodes])
)
group_snode.unmet_dependencies = {
dep
for dep in set.union(*[x.unmet_dependencies for x in snodes])
if dep.name not in group_snode.get_buffer_names()
} - group_snode.read_writes.writes
group_snode.unmet_dependencies = (
OrderedSet(
dep
for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes])
if dep.name not in group_snode.get_buffer_names()
)
- group_snode.read_writes.writes
)
group_snode.min_order = min(x.min_order for x in group_snode.snodes)
group_snode.max_order = max(x.max_order for x in group_snode.snodes)
@ -1020,8 +1028,8 @@ class FusedSchedulerNode(BaseSchedulerNode):
return self.snodes[0].get_name()
@cache_on_self
def get_buffer_names(self) -> Set[str]:
return set.union(*[x.get_buffer_names() for x in self.snodes])
def get_buffer_names(self) -> OrderedSet[str]:
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
def get_outputs(self) -> List[SchedulerBuffer]:
result: List[SchedulerBuffer] = []
@ -1047,25 +1055,27 @@ class FusedSchedulerNode(BaseSchedulerNode):
return f"{self}, snodes: {snodes_str}"
def set_last_usage(
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
) -> None:
# Set self.last_usage using the global information
# This will be used for inter-kernel optimisations
super().set_last_usage(future_used_buffers, mutation_real_name)
# Set self.last_usage on the snodes
# This will be used for optimisations within the kernel
future_used_buffers: Set[str] = set()
future_used_buffers: OrderedSet[str] = OrderedSet()
for node in reversed(self.snodes):
node.set_last_usage(future_used_buffers, mutation_real_name)
future_used_buffers.update(node.last_usage)
@cache_on_self
def used_buffer_names(self) -> Set[str]:
return set.union(*[x.used_buffer_names() for x in self.snodes])
def used_buffer_names(self) -> OrderedSet[str]:
return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes])
@cache_on_self
def used_or_aliased_buffer_names(self) -> Set[str]:
return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes])
def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
return OrderedSet.union(
*[x.used_or_aliased_buffer_names() for x in self.snodes]
)
def get_nodes(self) -> Sequence[BaseSchedulerNode]:
return self.snodes
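set_last_usage is a backward liveness pass: walking nodes in reverse execution order, a buffer belongs to a node's last_usage exactly when that node touches it and no later node does. A minimal standalone sketch of the recurrence, with hypothetical kernel names and per-kernel used-buffer sets (the scheduler seeds the future set with the graph outputs, omitted here):

    from typing import Dict, List, Set

    def compute_last_usage(
        kernels: List[str], used: Dict[str, Set[str]]
    ) -> Dict[str, Set[str]]:
        """For each kernel, the buffers it uses that no later kernel needs."""
        future: Set[str] = set()  # buffers some later kernel still touches
        last_usage: Dict[str, Set[str]] = {}
        for k in reversed(kernels):  # reverse execution order
            last_usage[k] = used[k] - future
            future |= used[k]
        return last_usage

    out = compute_last_usage(
        ["k0", "k1", "k2"],
        {"k0": {"buf0"}, "k1": {"buf0"}, "k2": {"buf1"}},
    )
    # buf0 dies after k1, buf1 after k2; k0's use of buf0 is not last
    assert out == {"k2": {"buf1"}, "k1": {"buf0"}, "k0": set()}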
@ -1300,13 +1310,16 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
)
)
self.unmet_dependencies = {
dep
for dep in set.union(
prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
self.unmet_dependencies = (
OrderedSet(
dep
for dep in OrderedSet.union(
prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
)
if dep.name not in self.get_buffer_names()
)
if dep.name not in self.get_buffer_names()
} - self.read_writes.writes
- self.read_writes.writes
)
self.min_order = min([prev_node_1.min_order, prev_node_2.min_order])
self.max_order = max([prev_node_1.max_order, prev_node_2.max_order])
@ -1327,7 +1340,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
self.group = (nodes[0].get_device(), ((sympy.Expr("foreach"),),))
self.origins: Set[torch.fx.Node] = set()
self.origins: OrderedSet[torch.fx.Node] = OrderedSet()
def mark_run(self) -> None:
raise NotImplementedError
@ -1409,8 +1422,8 @@ class GroupedSchedulerNode(BaseSchedulerNode):
return self.snodes[0].get_name()
@cache_on_self
def get_buffer_names(self) -> Set[str]:
return set.union(*[x.get_buffer_names() for x in self.snodes])
def get_buffer_names(self) -> OrderedSet[str]:
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
def get_outputs(self) -> List[SchedulerBuffer]:
result: List[SchedulerBuffer] = []
@ -1519,12 +1532,14 @@ class Scheduler:
self.backends: Dict[torch.device, BaseScheduling] = {}
self.post_grad_graph_id = next(_post_grad_graph_counter)
self.completed_operations: Set[str] = set()
self.available_buffer_names = {
*V.graph.graph_inputs.keys(),
*V.graph.constants.keys(),
*V.graph.torchbind_constants.keys(),
}
self.completed_operations: OrderedSet[str] = OrderedSet()
self.available_buffer_names = OrderedSet(
[
*V.graph.graph_inputs.keys(),
*V.graph.constants.keys(),
*V.graph.torchbind_constants.keys(),
]
)
self.nodes = [self.create_scheduler_node(n) for n in nodes]
@ -1575,7 +1590,7 @@ class Scheduler:
self.num_orig_nodes = len(self.nodes)
self.create_foreach_nodes()
self.nodes = self.topological_sort_schedule(self.nodes)
self.logged_slow_fusion: Set[Tuple[str, str]] = set()
self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet()
if config._pre_fusion_custom_pass is not None:
self.nodes = config._pre_fusion_custom_pass(self.nodes)
self.nodes = self.fuse_nodes(self.nodes)
@ -1590,7 +1605,7 @@ class Scheduler:
# used during codegen:
self.current_device: Optional[torch.device] = None
self.buffer_names_to_free: Set[str] = set()
self.buffer_names_to_free: OrderedSet[str] = OrderedSet()
# fx graph node to the position it appears in the graph
# for debug attribution
@ -1637,7 +1652,7 @@ class Scheduler:
raise NotImplementedError(node)
def create_foreach_nodes(self) -> None:
removed_node_names = set()
removed_node_names: OrderedSet[str] = OrderedSet()
fe_nodes = []
kept_node_names = self.name_to_fused_node.keys()
@ -1678,7 +1693,7 @@ class Scheduler:
"""
This data structure behaves like a list except it makes sure the
elements remain unique.
Normally one could use a set/dict for this purpose however
Normally one could use an OrderedSet/dict for this purpose; however,
the list in question gets elements appended as it is being
iterated over, which means that we need to keep the list
semantics.
@ -1687,10 +1702,10 @@ class Scheduler:
def __init__(
self,
items: Optional[List[T]] = None,
membership: Optional[Set[T]] = None,
membership: Optional[OrderedSet[T]] = None,
) -> None:
self.items = items or []
self.membership = membership or set()
self.membership = membership or OrderedSet()
def append(self, node_user: T) -> None:
if node_user in self.membership:
@ -1699,7 +1714,7 @@ class Scheduler:
self.membership.add(node_user)
def __add__(self, other: DedupList[T]) -> DedupList[T]:
new_membership = set.union(self.membership, other.membership)
new_membership = OrderedSet.union(self.membership, other.membership)
new_items = self.items + [
x for x in other.items if x not in self.membership
]
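The docstring above is the whole contract; below is a minimal sketch of such a structure under that description (the real DedupList also supports `__add__` and injectable storage, elided here), assuming nothing beyond a list plus a side membership set:

    from typing import Generic, List, Set, TypeVar

    T = TypeVar("T")

    class DedupListSketch(Generic[T]):
        """List semantics (stable order, safe to append while other code
        walks it by index) plus O(1) duplicate suppression."""

        def __init__(self) -> None:
            self.items: List[T] = []
            self.membership: Set[T] = set()

        def append(self, item: T) -> None:
            if item not in self.membership:
                self.items.append(item)
                self.membership.add(item)

    users: DedupListSketch[str] = DedupListSketch()
    for name in ["n1", "n2", "n1", "n3"]:
        users.append(name)
    assert users.items == ["n1", "n2", "n3"]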
@ -1915,7 +1930,7 @@ class Scheduler:
"""
Ensure nodes is in topologically sorted order
"""
seen: Set[BaseSchedulerNode] = set()
seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
name_to_node: Dict[str, BaseSchedulerNode] = dict()
result: List[BaseSchedulerNode] = []
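topological_sort_schedule is a standard dependencies-first DFS with a seen set; a compact recursive sketch with a hypothetical name -> deps map (the scheduler's version resolves dependency names through name_to_node rather than a plain dict):

    from typing import Dict, List, Set

    def topo_sort(nodes: List[str], deps: Dict[str, List[str]]) -> List[str]:
        """DFS post-order: a node is emitted only after all of its deps."""
        seen: Set[str] = set()
        result: List[str] = []

        def visit(n: str) -> None:
            if n in seen:
                return
            seen.add(n)
            for d in deps.get(n, []):  # dependencies first
                visit(d)
            result.append(n)

        for n in nodes:
            visit(n)
        return result

    assert topo_sort(["c", "a", "b"], {"c": ["b"], "b": ["a"]}) == ["a", "b", "c"]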
@ -1941,9 +1956,9 @@ class Scheduler:
Populate each node.ancestors
"""
# note self.nodes is topologically sorted
name_to_ancestors: Dict[str, Set[str]] = {}
name_to_ancestors: Dict[str, OrderedSet[str]] = {}
for node in self.nodes:
ancestors = set()
ancestors: OrderedSet[str] = OrderedSet()
for dep in node.unmet_dependencies:
dep_node_name = self.name_to_buf[dep.name].defining_op.get_name()
ancestors.add(dep_node_name)
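compute_ancestors leans on that sorted order: by the time a node is visited, each dependency's ancestor set is already complete, so a single pass yields the transitive closure. A sketch under that assumption (hypothetical name -> direct-deps map, not the scheduler's types):

    from typing import Dict, List, Set

    def compute_ancestors(
        topo_nodes: List[str], direct_deps: Dict[str, List[str]]
    ) -> Dict[str, Set[str]]:
        name_to_ancestors: Dict[str, Set[str]] = {}
        for n in topo_nodes:  # every dep of n appears earlier in topo_nodes
            ancestors: Set[str] = set()
            for d in direct_deps.get(n, []):
                ancestors.add(d)
                ancestors |= name_to_ancestors[d]  # already complete
            name_to_ancestors[n] = ancestors
        return name_to_ancestors

    anc = compute_ancestors(["a", "b", "c"], {"b": ["a"], "c": ["b"]})
    assert anc["c"] == {"a", "b"}  # transitive, not just direct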
@ -2229,7 +2244,7 @@ class Scheduler:
- self.can_fuse(): checks if a fusion is legal
- self.score_fusion(): assigns priority to a given fusion
"""
fused_nodes = set(nodes)
fused_nodes = OrderedSet(nodes)
if fusion_log.isEnabledFor(logging.DEBUG):
fusion_log.debug("fuse_nodes_once, candidates:")
for node in fused_nodes:
@ -2271,7 +2286,7 @@ class Scheduler:
Helper to find all legal fusion opportunities, sorted by self.score_fusion()
"""
possible_fusions = []
seen = set()
seen: OrderedSet[Tuple[BaseSchedulerNode, BaseSchedulerNode]] = OrderedSet()
def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None:
for node1_index, node1 in enumerate(nodes):
@ -2319,8 +2334,8 @@ class Scheduler:
Finds whether there's a path from node1 to node2 (or vice-versa)
caused indirectly by other fusions.
"""
visited = set()
# since we are just returning a boolean here, use the slightly faster, unordered set
visited: Set[FusedSchedulerNode] = set()
def found_path(node: BaseSchedulerNode) -> bool:
# only fused nodes can introduce new ancestors.
@ -2345,8 +2360,14 @@ class Scheduler:
)
return False
combined_names = node1.get_operation_names() | node2.get_operation_names()
combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names
# as above - use the slightly faster, unordered set
combined_names = (
node1.get_operation_names()._dict.keys()
| node2.get_operation_names()._dict.keys()
)
combined_ancestors = (
node1.ancestors._dict.keys() | node2.ancestors._dict.keys()
) - combined_names
cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
if cycle:
WhyNoFuse(node1, node2)("will create cycle")
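The hazard being checked: fusing node1 and node2 is illegal if some third node sits on a path between them, because the merged kernel would then both feed and consume that intermediate. Given fully transitive ancestor sets, the check reduces to set algebra, sketched below with hypothetical node names. (The `_dict.keys()` accesses above are a deliberate escape hatch: dict key views are set-like with C-implemented `|` and `-`, and since only membership is consumed here, order can be traded for speed.)

    from typing import Dict, Set

    def fusion_creates_cycle(
        n1: str, n2: str, ancestors: Dict[str, Set[str]]
    ) -> bool:
        """ancestors[x] must be the full transitive ancestor set of x."""
        outside = (ancestors[n1] | ancestors[n2]) - {n1, n2}
        # a path n1 -> a -> n2 (or its mirror image) becomes a self-loop
        # once n1 and n2 are merged into a single kernel
        return any(n1 in ancestors[a] or n2 in ancestors[a] for a in outside)

    # mid reads n1's output, n2 reads mid's output: fusing n1+n2 is cyclic
    anc = {"n1": set(), "mid": {"n1"}, "n2": {"mid", "n1"}}
    assert fusion_creates_cycle("n1", "n2", anc)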
@ -2555,7 +2576,7 @@ class Scheduler:
"""
node1_buf_names = node1.get_buffer_names()
node1_op_names = node1.get_operation_names()
computed_deps = set()
computed_deps: OrderedSet[Dep] = OrderedSet()
why = WhyNoFuse(node1, node2)
for cd in node1.read_writes.writes:
@ -2569,7 +2590,9 @@ class Scheduler:
if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2):
computed_deps.add(dep)
remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps}
remaining_deps = OrderedSet(
dep.name for dep in node2.unmet_dependencies - computed_deps
)
if remaining_deps & node1_buf_names:
# MemoryDeps didn't match and read different locations of the same buffer.
# Examples here include:
@ -2695,6 +2718,23 @@ class Scheduler:
The first term in our fusion score that estimates the number of saved
memory operations.
"""
node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes)
node2_dep_len = len(node2.read_writes.reads) + len(node2.read_writes.writes)
# optimization: iterate over the smaller dep set, probing the larger one
if max(node1_dep_len, node2_dep_len) > min(node1_dep_len, node2_dep_len) * 4:
if node1_dep_len > node2_dep_len:
node1, node2 = node2, node1
deps = []
for dep in node1.read_writes.reads | node1.read_writes.writes:
if dep in node2.read_writes.reads or dep in node2.read_writes.writes:
deps.append(dep)
return sum(self.dep_size_hint(dep) for dep in deps)
common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
node2.read_writes.reads | node2.read_writes.writes
)
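The fast path above is the usual "probe with the smaller set" intersection: membership tests are O(1), so scanning the smaller dep set and testing against the larger one costs O(min(m, n)) and avoids materializing the two unions. The same idea in isolation, with a hypothetical size-hint table standing in for dep_size_hint:

    from typing import Dict, Set

    def common_dep_weight(
        a: Set[str], b: Set[str], size_hint: Dict[str, int]
    ) -> int:
        if len(a) > len(b):
            a, b = b, a  # probe with the smaller set
        # O(min(len(a), len(b))): scan the small side, O(1) lookups in the big
        return sum(size_hint[dep] for dep in a if dep in b)

    hints = {"x": 1, "y": 8, "z": 4}
    assert common_dep_weight({"x", "y"}, {"y", "z"}, hints) == 8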
@ -2746,7 +2786,7 @@ class Scheduler:
Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
"""
future_used_buffers = set(V.graph.get_output_names())
future_used_buffers: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
for node in reversed(self.nodes):
node.set_last_usage(future_used_buffers, self.mutation_real_name)
@ -2776,11 +2816,11 @@ class Scheduler:
same kernel can be removed.
"""
fused_node_names = {
fused_node_names = OrderedSet(
self.name_to_buf[buf].defining_op.get_name()
for buf in V.kernel.store_buffer_names
if buf in self.name_to_buf
}
)
names_to_remove = []
for out_buf in V.kernel.store_buffer_names:
if out_buf not in self.name_to_buf:
@ -2789,7 +2829,7 @@ class Scheduler:
continue
users = self.name_to_buf[out_buf].users
assert users is not None
users = {user.get_name() for user in users if not user.is_weak}
users = OrderedSet(user.get_name() for user in users if not user.is_weak)
if users.issubset(fused_node_names):
names_to_remove.append(out_buf)
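The rule implemented here: a buffer stored by the kernel can be removed when every non-weak user is fused into the same kernel, because no outside node can ever observe it. A sketch of the subset test with hypothetical buffer and op names:

    from typing import Dict, List, Set

    def removable_buffers(
        stores: List[str], users: Dict[str, Set[str]], fused_ops: Set[str]
    ) -> List[str]:
        """Buffers whose every user already lives inside the fused kernel."""
        return [buf for buf in stores if users[buf] <= fused_ops]

    # buf0 only feeds op1, which is fused -> removable; buf1 escapes to op9
    assert removable_buffers(
        ["buf0", "buf1"],
        {"buf0": {"op1"}, "buf1": {"op9"}},
        {"op0", "op1"},
    ) == ["buf0"]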

View File

@ -400,7 +400,7 @@ class TritonTemplateKernel(TritonKernel):
self.body.writeline(f"{output_name} = {out.value}")
body_val = self.body.getvalue()
self.cse.invalidate(set())
self.cse.invalidate(set()) # type: ignore[arg-type]
return body_val
def store_output(
@ -600,7 +600,7 @@ class TritonTemplate(KernelTemplate):
self.all_templates[name] = self
self.debug = debug
def generate(
def generate( # type: ignore[override]
self,
input_nodes,
layout,