mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Migrate Inductor scheduler, dependencies, ir, and codegen/common to use OrderedSet (#130004)
Python's set is non deterministic. There is an internal failure which we recently ran into which did not consistently fail. See, repro here: P1453035092. Now, with these changes, it does consistently fail. In follow ups we could also consider adding a lintrule for uses of either set() or set literals. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004 Approved by: https://github.com/oulgen
This commit is contained in:
committed by
PyTorch MergeBot
parent
bcd1d2e832
commit
f32ab3b9e3
@ -1394,9 +1394,10 @@ class CompiledFxGraph:
|
||||
with open(graph.cache_path) as f:
|
||||
self.source_code = f.read()
|
||||
self.cache_linemap = graph.cache_linemap
|
||||
self.device_types = graph.device_types
|
||||
self.device_idxs = graph.device_idxs
|
||||
self.mutated_inputs = graph.mutated_inputs
|
||||
# TODO - ordered set
|
||||
self.device_types = set(graph.device_types)
|
||||
self.device_idxs = set(graph.device_idxs)
|
||||
self.mutated_inputs = set(graph.mutated_inputs)
|
||||
self.mutated_input_idxs = set(graph.mutated_input_idxs)
|
||||
self.constants = graph.constants
|
||||
self.torchbind_constants = graph.torchbind_constants
|
||||
|
@ -17,7 +17,6 @@ from typing import (
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
@ -29,6 +28,7 @@ import torch
|
||||
import torch.fx
|
||||
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||
@ -1407,7 +1407,7 @@ class KernelArgs:
|
||||
# after you do a call into this kernel, which buffers actually contain
|
||||
# updated data? Modeled off of python_argdefs.
|
||||
def live_output_buffers(self):
|
||||
live_outs = set()
|
||||
live_outs = OrderedSet() # type: ignore[var-annotated]
|
||||
for inplaced in unique(self.inplace_buffers.values()):
|
||||
if self._buffer_is_marked_removed(inplaced):
|
||||
continue
|
||||
@ -1483,10 +1483,10 @@ class CSE:
|
||||
self.store_cache = store_cache or {}
|
||||
self.reduction_cache = reduction_cache or {}
|
||||
self.iter_buffer_ids = iter_buffers or itertools.count()
|
||||
self.invalidated_stores = set()
|
||||
self.invalidated_stores = OrderedSet() # type: ignore[var-annotated]
|
||||
self.varname_map = varname_map or {}
|
||||
|
||||
def invalidate(self, keep_vars: Set[str]):
|
||||
def invalidate(self, keep_vars: OrderedSet[str]):
|
||||
for name, tmp in list(self.store_cache.items()):
|
||||
if tmp not in keep_vars:
|
||||
del self.store_cache[name]
|
||||
@ -1615,16 +1615,16 @@ class Kernel(CodeGen):
|
||||
self.num_reduction = 0
|
||||
|
||||
self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
|
||||
self.must_keep_buffers = set()
|
||||
self.store_buffer_names = set()
|
||||
self.must_keep_buffers = OrderedSet() # type: ignore[var-annotated]
|
||||
self.store_buffer_names = OrderedSet() # type: ignore[var-annotated]
|
||||
self._load_mask = None
|
||||
self._load_other = None
|
||||
# set in set_current_node
|
||||
# OrderedSet in set_current_node
|
||||
self.current_node = None
|
||||
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
|
||||
|
||||
self.removed_buffers = set()
|
||||
self.inplaced_to_remove = set()
|
||||
self.removed_buffers = OrderedSet() # type: ignore[var-annotated]
|
||||
self.inplaced_to_remove = OrderedSet() # type: ignore[var-annotated]
|
||||
|
||||
# key: the buffer to write
|
||||
# value: the buffer to read and whose memory can be reused for
|
||||
|
@ -15,7 +15,6 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
@ -60,6 +59,8 @@ from .simd import constant_repr, SIMDKernel, SIMDScheduling
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ..ops_handler import ReductionType, StoreMode
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -665,7 +666,7 @@ class HalideKernel(SIMDKernel):
|
||||
self,
|
||||
*groups,
|
||||
index_dtype: str,
|
||||
mutations: Optional[Set[str]] = None,
|
||||
mutations: Optional[OrderedSet[str]] = None,
|
||||
pid_cache=None,
|
||||
reduction_hint=ReductionHint.DEFAULT,
|
||||
override_persistent_reduction=None,
|
||||
@ -717,7 +718,7 @@ class HalideKernel(SIMDKernel):
|
||||
assert not (
|
||||
self.index_replacements or self.halide_vars or self.reduction_renames
|
||||
)
|
||||
size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf)
|
||||
size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type]
|
||||
indices = dict.fromkeys(map(super().prepare_indexing, indices))
|
||||
all_used_symbols = set()
|
||||
sym_to_node = {
|
||||
@ -1647,7 +1648,7 @@ class HalideScheduling(SIMDScheduling):
|
||||
int32_type = "hl.Int(32)"
|
||||
# TODO(jansel): Halide doesn't actually support 64 bit indexing...
|
||||
int64_type = "hl.Int(64)"
|
||||
kernel_type = HalideKernel
|
||||
kernel_type = HalideKernel # type: ignore[arg-type]
|
||||
|
||||
@classmethod
|
||||
def get_backend_features(cls, device: torch.device):
|
||||
|
@ -5,6 +5,7 @@ import pathlib
|
||||
from typing import Any, List
|
||||
|
||||
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .. import config
|
||||
from ..codecache import get_path, TritonFuture
|
||||
@ -219,11 +220,11 @@ class MultiKernel:
|
||||
|
||||
@property
|
||||
def removed_buffers(self):
|
||||
return set.intersection(*[k.removed_buffers for k in self.kernels])
|
||||
return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels])
|
||||
|
||||
@property
|
||||
def inplaced_to_remove(self):
|
||||
return set.intersection(*[k.inplaced_to_remove for k in self.kernels])
|
||||
return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels])
|
||||
|
||||
@property
|
||||
@cache_on_self
|
||||
|
@ -19,7 +19,6 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
@ -28,6 +27,7 @@ import sympy
|
||||
|
||||
import torch
|
||||
import torch._logging
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
|
||||
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
||||
|
||||
@ -311,7 +311,7 @@ class SIMDKernel(Kernel):
|
||||
self,
|
||||
*groups,
|
||||
index_dtype: str,
|
||||
mutations: Optional[Set[str]] = None,
|
||||
mutations: Optional[OrderedSet[str]] = None,
|
||||
pid_cache=None,
|
||||
reduction_hint=ReductionHint.DEFAULT,
|
||||
override_persistent_reduction=None,
|
||||
@ -322,14 +322,16 @@ class SIMDKernel(Kernel):
|
||||
self.body = IndentedBuffer()
|
||||
self.indexing_code = IndentedBuffer()
|
||||
self.numels = [V.graph.sizevars.simplify(s) for s in groups]
|
||||
self.mutations: Set[str] = mutations if mutations is not None else set()
|
||||
self.mutations: OrderedSet[str] = (
|
||||
mutations if mutations is not None else OrderedSet()
|
||||
)
|
||||
self.range_trees: List[IterationRangesRoot] = []
|
||||
self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {}
|
||||
self.iter_vars_count = itertools.count()
|
||||
self.inside_reduction = self.numels[-1] != 1
|
||||
self.reduction_hint = reduction_hint
|
||||
self.index_dtype: str = index_dtype
|
||||
self.last_usage: Set[str] = set()
|
||||
self.last_usage: OrderedSet[str] = OrderedSet()
|
||||
self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
|
||||
self.persistent_reduction: bool = (
|
||||
override_persistent_reduction
|
||||
@ -483,7 +485,7 @@ class SIMDKernel(Kernel):
|
||||
def set_last_usage(self, nodes):
|
||||
if not self.inside_reduction or self.persistent_reduction:
|
||||
return
|
||||
self.last_usage = set(
|
||||
self.last_usage = OrderedSet(
|
||||
itertools.chain.from_iterable(
|
||||
n.last_usage for n in nodes if n is not EnableReduction
|
||||
)
|
||||
@ -832,7 +834,7 @@ class SIMDKernel(Kernel):
|
||||
# This arg points to a buf that has been sliced.
|
||||
# We need to count each individual slice to have
|
||||
# a better estimation.
|
||||
indices: Set[Any] = set()
|
||||
indices: OrderedSet[Any] = OrderedSet()
|
||||
no_index_dep_count = 0
|
||||
for dep in self.buf_accesses[arg]:
|
||||
if isinstance(dep, (StarDep, WeakDep)):
|
||||
@ -1063,13 +1065,13 @@ class SIMDScheduling(BaseScheduling):
|
||||
|
||||
def generate_node_schedule(self, nodes, numel, rnumel):
|
||||
node_schedule: List[Any] = []
|
||||
current_loop_writes: Set[str] = set()
|
||||
current_loop_writes: OrderedSet[str] = OrderedSet()
|
||||
|
||||
# Writes with a reduced shape, meaning they are only present once the
|
||||
# reduction loop has ended
|
||||
current_loop_reduced_writes = set()
|
||||
current_loop_reduced_writes: OrderedSet[str] = OrderedSet()
|
||||
current_loop_has_writes = False
|
||||
done = set()
|
||||
done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()
|
||||
|
||||
def fits_in_main_body(n):
|
||||
_, (node_numel, node_rnumel) = n.group
|
||||
@ -1221,7 +1223,7 @@ class SIMDScheduling(BaseScheduling):
|
||||
@classmethod
|
||||
def select_index_dtype(cls, node_schedule, numel, reduction_numel):
|
||||
# Gather all used buffer names
|
||||
buffer_names = set()
|
||||
buffer_names: OrderedSet[str] = OrderedSet()
|
||||
for node in node_schedule:
|
||||
if not isinstance(node, scheduler.BaseSchedulerNode):
|
||||
continue
|
||||
@ -1296,7 +1298,7 @@ class SIMDScheduling(BaseScheduling):
|
||||
else:
|
||||
reduction_hint_val = ReductionHint.DEFAULT
|
||||
|
||||
mutations = set()
|
||||
mutations: OrderedSet[str] = OrderedSet()
|
||||
for node in node_schedule:
|
||||
if node in (DisableReduction, EnableReduction):
|
||||
continue
|
||||
@ -1653,7 +1655,7 @@ class SIMDScheduling(BaseScheduling):
|
||||
break
|
||||
return (numel, reduction_numel)
|
||||
|
||||
seen_names = set()
|
||||
seen_names: OrderedSet[str] = OrderedSet()
|
||||
candidate_tiles: Counter[Any] = collections.Counter()
|
||||
for node in EnableReduction.filter(node_schedule):
|
||||
for tiling in cls.candidate_tilings(node):
|
||||
@ -1720,7 +1722,7 @@ class SIMDScheduling(BaseScheduling):
|
||||
|
||||
# empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
|
||||
for n in nodes:
|
||||
n.last_usage = set()
|
||||
n.last_usage = OrderedSet()
|
||||
|
||||
if not nodes[0].is_template():
|
||||
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
|
||||
|
@ -16,7 +16,6 @@ from typing import (
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
@ -29,6 +28,7 @@ import torch._logging
|
||||
from torch._dynamo.utils import preserve_rng_state
|
||||
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
|
||||
from torch._prims_common import is_integer_dtype
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
|
||||
from torch.utils._triton import has_triton_package
|
||||
|
||||
@ -136,7 +136,7 @@ block_sizes = {
|
||||
@dataclasses.dataclass
|
||||
class IndexingOptions:
|
||||
index_str: str
|
||||
mask_vars: Set[str]
|
||||
mask_vars: OrderedSet[str]
|
||||
mask_str: str
|
||||
expand_str: Optional[str]
|
||||
_has_rindex: bool
|
||||
@ -163,7 +163,7 @@ class BlockPtrOptions:
|
||||
params: BlockParameters
|
||||
constant_offset: sympy.Expr
|
||||
order: List[int]
|
||||
mask_vars: Set[str]
|
||||
mask_vars: OrderedSet[str]
|
||||
reshape_suffix: List[str]
|
||||
|
||||
@property
|
||||
@ -188,7 +188,7 @@ class BlockPtrOptions:
|
||||
params: BlockParameters,
|
||||
constant_offset: sympy.Expr,
|
||||
range_trees: List[IterationRangesEntry],
|
||||
mask_vars: Set[str],
|
||||
mask_vars: OrderedSet[str],
|
||||
) -> BlockPtrOptions:
|
||||
"""Helper to create a BlockPtrOptions instance"""
|
||||
reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
|
||||
@ -454,7 +454,7 @@ class TritonPrinter(PythonPrinter):
|
||||
# Use a macro so we can propagate constexprs.
|
||||
# https://github.com/triton-lang/triton/issues/3815
|
||||
a, b = tuple(f"({x})" for x in (a, b))
|
||||
assert cmp in {">", "<"}, f"Unexpected comparator: '{cmp}'"
|
||||
assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'"
|
||||
return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))"
|
||||
|
||||
def _print_Min(self, expr):
|
||||
@ -575,7 +575,7 @@ class TritonCSEVariable(CSEVariable):
|
||||
def __init__(self, name, bounds: ValueRanges[Any]):
|
||||
super().__init__(name, bounds)
|
||||
# We'll use this to track which masks the variable needs when used for indirect indexing
|
||||
self.mask_vars: Set[str] = set()
|
||||
self.mask_vars: OrderedSet[str] = OrderedSet()
|
||||
|
||||
def update_on_args(self, name, args, kwargs):
|
||||
for arg in args:
|
||||
@ -608,10 +608,10 @@ class TritonOverrides(OpOverrides):
|
||||
# fp8 data type conversions has min_elem_per_thread requirements.
|
||||
# Refer to Triton implementations here:
|
||||
# https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.
|
||||
fp8_dtypes = {
|
||||
fp8_dtypes = (
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
}
|
||||
)
|
||||
# Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2.
|
||||
assert not (
|
||||
src_dtype in fp8_dtypes
|
||||
@ -1037,7 +1037,7 @@ class TritonKernelOverrides(TritonOverrides):
|
||||
V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
|
||||
)
|
||||
|
||||
if dtype not in {torch.int32, torch.int64}:
|
||||
if dtype not in (torch.int32, torch.int64):
|
||||
var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
|
||||
var.mask_vars = indexing.mask_vars
|
||||
return var
|
||||
@ -1176,7 +1176,7 @@ class TritonKernel(SIMDKernel):
|
||||
self,
|
||||
*groups,
|
||||
index_dtype: str,
|
||||
mutations: Optional[Set[str]] = None,
|
||||
mutations: Optional[OrderedSet[str]] = None,
|
||||
pid_cache=None,
|
||||
reduction_hint=ReductionHint.DEFAULT,
|
||||
min_elem_per_thread=0,
|
||||
@ -1191,13 +1191,13 @@ class TritonKernel(SIMDKernel):
|
||||
override_persistent_reduction=override_persistent_reduction,
|
||||
)
|
||||
self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment]
|
||||
self.outside_loop_vars: Set[Any] = set()
|
||||
self.outside_loop_vars: OrderedSet[Any] = OrderedSet()
|
||||
self.min_elem_per_thread = min_elem_per_thread
|
||||
self.block_ptr_id = itertools.count()
|
||||
self.helper_functions = HelperFunctions()
|
||||
|
||||
# A set of autotuning hints to pass as part of triton_meta
|
||||
self.autotune_hints: Set[AutotuneHint] = set()
|
||||
self.autotune_hints: OrderedSet[AutotuneHint] = OrderedSet()
|
||||
self.triton_meta: Optional[Dict[str, object]] = None
|
||||
|
||||
self.codegen_range_tree()
|
||||
@ -1285,7 +1285,7 @@ class TritonKernel(SIMDKernel):
|
||||
index_vars = index.free_symbols
|
||||
has_rindex = False
|
||||
|
||||
mask_vars: Set[str] = set()
|
||||
mask_vars: OrderedSet[str] = OrderedSet()
|
||||
for var in index_vars:
|
||||
assert isinstance(var, sympy.Symbol)
|
||||
has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX)
|
||||
@ -1322,7 +1322,7 @@ class TritonKernel(SIMDKernel):
|
||||
|
||||
have_dense = True
|
||||
have_loop_vars = False
|
||||
dense_mask_vars = set()
|
||||
dense_mask_vars: OrderedSet[str] = OrderedSet()
|
||||
|
||||
for tree in self.active_range_trees():
|
||||
if index_vars.intersection(tree.var_list):
|
||||
@ -1584,7 +1584,7 @@ class TritonKernel(SIMDKernel):
|
||||
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
|
||||
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
|
||||
return IndexingOptions(
|
||||
index_str, set(), "None", expand_str, has_rindex, index
|
||||
index_str, OrderedSet(), "None", expand_str, has_rindex, index
|
||||
)
|
||||
|
||||
if need_dense and not have_dense:
|
||||
@ -1596,7 +1596,7 @@ class TritonKernel(SIMDKernel):
|
||||
mask_vars = dense_mask_vars
|
||||
|
||||
if override_mask:
|
||||
mask_vars = {override_mask}
|
||||
mask_vars = OrderedSet([override_mask])
|
||||
|
||||
if self._load_mask:
|
||||
mask_vars.add(self._load_mask)
|
||||
@ -1721,9 +1721,11 @@ class TritonKernel(SIMDKernel):
|
||||
ep = ", eviction_policy='evict_last'"
|
||||
elif self.inside_reduction and self.range_trees[-1].is_loop:
|
||||
if name in self.args.inplace_buffers:
|
||||
names = set(self.args.inplace_buffers[name].other_names)
|
||||
names: OrderedSet[str] = OrderedSet(
|
||||
self.args.inplace_buffers[name].other_names
|
||||
)
|
||||
else:
|
||||
names = {name}
|
||||
names = OrderedSet([name])
|
||||
last_use = len(names & self.last_usage) > 0
|
||||
evict_last = not last_use and (has_rindex or indirect_indexing)
|
||||
if evict_last:
|
||||
@ -1881,7 +1883,7 @@ class TritonKernel(SIMDKernel):
|
||||
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
|
||||
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
|
||||
assert self.inside_reduction
|
||||
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
|
||||
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
|
||||
self.filter_masks(masks)
|
||||
masks = sorted(masks)
|
||||
if self._load_mask:
|
||||
@ -1929,7 +1931,7 @@ class TritonKernel(SIMDKernel):
|
||||
dim = self.triton_tensor_ndim() - 1
|
||||
acc_type = triton_acc_type(src_dtype)
|
||||
result_var: Any = self.cse.newvar()
|
||||
result_var.mask_vars = {var for var in masks if var[0] != "r"}
|
||||
result_var.mask_vars = OrderedSet(var for var in masks if var[0] != "r")
|
||||
cond = " & ".join(masks)
|
||||
|
||||
def where_cond(tval, fval):
|
||||
@ -2090,7 +2092,7 @@ class TritonKernel(SIMDKernel):
|
||||
|
||||
if isinstance(result_var, tuple):
|
||||
assert all(isinstance(x, TritonCSEVariable) for x in result_var)
|
||||
self.outside_loop_vars |= set(result_var)
|
||||
self.outside_loop_vars |= OrderedSet(result_var)
|
||||
else:
|
||||
assert isinstance(result_var, TritonCSEVariable)
|
||||
self.outside_loop_vars.add(result_var)
|
||||
@ -2172,7 +2174,7 @@ class TritonKernel(SIMDKernel):
|
||||
values: Tuple[CSEVariable, ...],
|
||||
) -> Tuple[CSEVariable, ...]:
|
||||
assert self.inside_reduction
|
||||
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
|
||||
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
|
||||
self.filter_masks(masks)
|
||||
masks = sorted(masks)
|
||||
assert not self._load_mask, "ops.scan not supported inside ops.masked"
|
||||
@ -2284,7 +2286,7 @@ class TritonKernel(SIMDKernel):
|
||||
descending: bool,
|
||||
) -> Tuple[CSEVariable, ...]:
|
||||
assert self.inside_reduction
|
||||
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
|
||||
masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
|
||||
self.filter_masks(masks)
|
||||
masks = sorted(masks)
|
||||
assert not self._load_mask, "ops.sort not supported inside ops.masked"
|
||||
@ -2586,7 +2588,7 @@ class TritonKernel(SIMDKernel):
|
||||
arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
|
||||
)
|
||||
|
||||
mutated_args = set()
|
||||
mutated_args: OrderedSet[str] = OrderedSet()
|
||||
for mutation in self.mutations:
|
||||
if mutation in self.args.input_buffers:
|
||||
mutated_args.add(self.args.input_buffers[mutation])
|
||||
|
@ -1,12 +1,13 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from typing import Optional, Set
|
||||
from typing import Optional
|
||||
|
||||
import torch._inductor.runtime.hints
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.simd import IterationRangesRoot
|
||||
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
|
||||
from torch._prims_common import prod
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import CeilDiv
|
||||
|
||||
|
||||
@ -30,7 +31,7 @@ class TritonSplitScanKernel(TritonKernel):
|
||||
self,
|
||||
*groups,
|
||||
index_dtype: str,
|
||||
mutations: Optional[Set[str]] = None,
|
||||
mutations: Optional[OrderedSet[str]] = None,
|
||||
reduction_hint=torch._inductor.runtime.hints.ReductionHint.DEFAULT,
|
||||
min_elem_per_thread=0,
|
||||
):
|
||||
|
@ -6,13 +6,14 @@ import itertools
|
||||
import logging
|
||||
import re
|
||||
import typing
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .codegen.common import index_prevent_reordering
|
||||
from .utils import (
|
||||
@ -123,7 +124,7 @@ class MemoryDep(Dep):
|
||||
if self.is_indirect():
|
||||
numel = V.graph.get_numel(self.name)
|
||||
else:
|
||||
vars = set(self.index.free_symbols)
|
||||
vars: OrderedSet[sympy.Expr] = OrderedSet(self.index.free_symbols)
|
||||
numel = sympy.Integer(1)
|
||||
for var, size in zip(self.var_names, self.size):
|
||||
if var in vars:
|
||||
@ -272,9 +273,9 @@ class IndexExprDep:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReadWrites:
|
||||
reads: Set[Dep]
|
||||
writes: Set[Dep]
|
||||
index_exprs: Set[IndexExprDep]
|
||||
reads: OrderedSet[Dep]
|
||||
writes: OrderedSet[Dep]
|
||||
index_exprs: OrderedSet[IndexExprDep]
|
||||
range_vars: Optional[List[sympy.Expr]] = None
|
||||
var_ranges: Optional[VarRanges] = None
|
||||
op_counts: typing.Counter[str] = dataclasses.field(
|
||||
@ -283,8 +284,8 @@ class ReadWrites:
|
||||
|
||||
def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
|
||||
return ReadWrites(
|
||||
{dep.rename(renames) for dep in self.reads},
|
||||
{dep.rename(renames) for dep in self.writes},
|
||||
OrderedSet(dep.rename(renames) for dep in self.reads),
|
||||
OrderedSet(dep.rename(renames) for dep in self.writes),
|
||||
self.index_exprs,
|
||||
self.range_vars,
|
||||
self.var_ranges,
|
||||
@ -294,7 +295,7 @@ class ReadWrites:
|
||||
def with_read(self, dep: Dep) -> "ReadWrites":
|
||||
assert isinstance(dep, (WeakDep, StarDep))
|
||||
return ReadWrites(
|
||||
set.union(self.reads, {dep}),
|
||||
OrderedSet.union(self.reads, [dep]),
|
||||
self.writes,
|
||||
self.index_exprs,
|
||||
self.range_vars,
|
||||
@ -303,18 +304,18 @@ class ReadWrites:
|
||||
)
|
||||
|
||||
def merge(self, other: "ReadWrites"):
|
||||
reads = set.union(self.reads, other.reads)
|
||||
writes = set.union(self.writes, other.writes)
|
||||
index_exprs = set.union(self.index_exprs, other.index_exprs)
|
||||
reads = OrderedSet.union(self.reads, other.reads)
|
||||
writes = OrderedSet.union(self.writes, other.writes)
|
||||
index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs)
|
||||
op_counts = collections.Counter(self.op_counts)
|
||||
op_counts.update(other.op_counts)
|
||||
return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts)
|
||||
|
||||
@staticmethod
|
||||
def merge_list(read_writes: List["ReadWrites"]):
|
||||
all_writes = set.union(*[rw.writes for rw in read_writes])
|
||||
all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes
|
||||
all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes])
|
||||
all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
|
||||
all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
|
||||
all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
|
||||
|
||||
op_counts: typing.Counter[Any] = collections.Counter()
|
||||
for rw in read_writes:
|
||||
@ -339,7 +340,7 @@ class ReadWrites:
|
||||
"""
|
||||
Integer index is used for load_seed.
|
||||
"""
|
||||
names = set()
|
||||
names: OrderedSet[str] = OrderedSet()
|
||||
for dep in self.reads_and_writes():
|
||||
if not isinstance(dep, MemoryDep):
|
||||
continue
|
||||
@ -353,9 +354,9 @@ class ReadWrites:
|
||||
class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
|
||||
def __init__(self, var_ranges: VarRanges, normalize: bool):
|
||||
super().__init__()
|
||||
self._reads: Set[Dep] = set()
|
||||
self._writes: Set[MemoryDep] = set()
|
||||
self._index_exprs: Set[IndexExprDep] = set()
|
||||
self._reads: OrderedSet[Dep] = OrderedSet()
|
||||
self._writes: OrderedSet[MemoryDep] = OrderedSet()
|
||||
self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet()
|
||||
self._var_ranges: VarRanges = var_ranges
|
||||
self._normalize: bool = normalize
|
||||
|
||||
@ -509,8 +510,8 @@ def extract_read_writes(
|
||||
|
||||
inner = rw.parent_handler.parent_handler
|
||||
return ReadWrites(
|
||||
set(inner._reads),
|
||||
set(inner._writes),
|
||||
OrderedSet(inner._reads),
|
||||
OrderedSet(inner._writes),
|
||||
inner._index_exprs,
|
||||
range_vars,
|
||||
var_ranges,
|
||||
@ -550,7 +551,7 @@ def extract_input_node_reduction_ranges(
|
||||
reduction_size = None
|
||||
size = None
|
||||
while reduction_size is None and len(reads) > 0:
|
||||
seen = set()
|
||||
seen: OrderedSet[str] = OrderedSet()
|
||||
new_reads = []
|
||||
for read in reads:
|
||||
if not isinstance(read, MemoryDep):
|
||||
@ -586,10 +587,10 @@ def canonicalization_prefix():
|
||||
|
||||
# ops handler which computes all the free unbacked symbols for an IR
|
||||
class FreeUnbackedSymbolsOpsHandler:
|
||||
symbols: Set[sympy.Symbol]
|
||||
symbols: OrderedSet[sympy.Symbol]
|
||||
|
||||
def __init__(self):
|
||||
self.symbols = set()
|
||||
self.symbols = OrderedSet()
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
def inner(*args, **kwargs):
|
||||
|
@ -19,7 +19,6 @@ from typing import (
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
@ -51,6 +50,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.node import Node
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
|
||||
from . import config, ir
|
||||
@ -345,35 +345,39 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.ras_by_symbol: Dict[
|
||||
sympy.Symbol, List[RuntimeAssert]
|
||||
] = shape_env.deferred_runtime_asserts.copy()
|
||||
self.bound_unbacked_symbols: Set[sympy.Symbol] = set()
|
||||
self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
self.sizevars = SizeVarAllocator(shape_env)
|
||||
self.graph_input_names: List[str] = []
|
||||
self.graph_inputs: Dict[str, TensorBox] = {}
|
||||
self.graph_inputs_original: Dict[str, InputBuffer] = {}
|
||||
self.device_types: Set[str] = (
|
||||
const_module.device_types if const_module else set()
|
||||
self.device_types: OrderedSet[str] = (
|
||||
const_module.device_types if const_module else OrderedSet()
|
||||
)
|
||||
self.device_idxs: OrderedSet[int] = (
|
||||
const_module.device_idxs if const_module else OrderedSet()
|
||||
)
|
||||
self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
|
||||
self.cuda = False
|
||||
self.buffers: List[ir.Buffer] = []
|
||||
self.operations: List[ir.Operation] = []
|
||||
self.const_output_index: Dict[str, int] = (
|
||||
const_output_index if const_output_index else {}
|
||||
)
|
||||
self.folded_constants: Set[str] = (
|
||||
set(const_output_index.keys()) if const_output_index else set()
|
||||
self.folded_constants: OrderedSet[str] = (
|
||||
OrderedSet(const_output_index.keys())
|
||||
if const_output_index
|
||||
else OrderedSet()
|
||||
)
|
||||
self.constants: Dict[str, torch.Tensor] = (
|
||||
const_module.constants if const_module else {}
|
||||
)
|
||||
self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
|
||||
self.constant_reprs: Dict[str, str] = {}
|
||||
self.removed_operations: Set[str] = set()
|
||||
self.removed_buffers: Set[str] = set()
|
||||
self.removed_inplace_buffers: Set[str] = set()
|
||||
self.mutated_buffers: Set[str] = set()
|
||||
self.never_reuse_buffers: Set[str] = set()
|
||||
self.inplaced_to_remove: Set[str] = set()
|
||||
self.removed_operations: OrderedSet[str] = OrderedSet()
|
||||
self.removed_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.mutated_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
|
||||
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
|
||||
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
|
||||
self.wrapper_code: WrapperCodeGen = None # type: ignore[assignment]
|
||||
# See `ProxyExecutor Design Note` in ir.py for more details
|
||||
@ -389,7 +393,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
self.current_node: torch.fx.Node = None # type: ignore[assignment]
|
||||
self.lists: Dict[str, List[str]] = {}
|
||||
self.mutated_inputs: Set[str] = set()
|
||||
self.mutated_inputs: OrderedSet[str] = OrderedSet()
|
||||
self.mutated_input_idxs: List[int] = []
|
||||
self.name_to_buffer: Dict[str, ir.Buffer] = {}
|
||||
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
|
||||
@ -400,7 +404,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
# record multi_kernel choice for cpp_wrapper so the second pass knows
|
||||
# which sub-kernel is picked. Copy cpp_wrapper to another variable
|
||||
# since cpp_wrapper flag is set to false for the first pass of codegen.
|
||||
# since cpp_wrapper flag is OrderedSet to false for the first pass of codegen.
|
||||
self.record_multi_kernel_choice = cpp_wrapper
|
||||
self.multi_kernel_to_choice: Dict[str, int] = {}
|
||||
|
||||
@ -409,7 +413,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.post_grad_graph_id = next(_post_grad_graph_counter)
|
||||
self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment]
|
||||
self.nodes_prefer_channels_last = (
|
||||
self.find_nodes_prefer_channels_last() if self.layout_opt else set()
|
||||
self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
|
||||
)
|
||||
mark_nodes_dislike_padding(gm.graph)
|
||||
self._warned_fallback = {"aten.convolution_backward"}
|
||||
@ -439,8 +443,8 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.get_backend_features = functools.lru_cache(None)(get_backend_features)
|
||||
|
||||
self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
|
||||
self.aligned_inputs: Set[str] = set()
|
||||
self.no_fuse_buffer_names: Set[str] = set()
|
||||
self.aligned_inputs: OrderedSet[str] = OrderedSet()
|
||||
self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()
|
||||
|
||||
def has_feature(
|
||||
self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature
|
||||
@ -643,7 +647,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
name=self.qualify_name(subgraph_name),
|
||||
)
|
||||
|
||||
def find_nodes_prefer_channels_last(self) -> Set[Node]:
|
||||
def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]:
|
||||
"""
|
||||
The rule to decide if an node prefer channels last is simple.
|
||||
1. if it's input/output of a convolution
|
||||
@ -662,7 +666,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies
|
||||
can be saved.
|
||||
"""
|
||||
output_set = set()
|
||||
output_set: OrderedSet[Node] = OrderedSet()
|
||||
for n in reversed(self.module.graph.nodes):
|
||||
if n.target == torch.ops.aten.convolution.default:
|
||||
output_set.add(n)
|
||||
@ -896,7 +900,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
if self.constants[name].device == device_override or device_override is None:
|
||||
return name
|
||||
with torch.utils._python_dispatch._disable_current_modes():
|
||||
# caller might have set fake tensor mode which will create a fake tensor
|
||||
# caller might have OrderedSet fake tensor mode which will create a fake tensor
|
||||
# when calling .to, so unset modes here
|
||||
return self.allocate_non_dup_const_name(
|
||||
f"{name}_{device_override.type}{device_override.index or 0}",
|
||||
@ -980,7 +984,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# arguments' fx strides
|
||||
layout_constraint = None
|
||||
if torch._C.Tag.needs_fixed_stride_order in target.tags:
|
||||
# We have to set the current args because call_function will immediately
|
||||
# We have to OrderedSet the current args because call_function will immediately
|
||||
# evaluate this lowering after creating the fallback, without evaluating
|
||||
# the layout constraint
|
||||
args, kwargs = constrain_to_fx_strides(
|
||||
@ -1238,9 +1242,11 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
if n.op == "call_function":
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||
origins |= gather_origins(args, kwargs)
|
||||
with ir.IRNode.current_origins(origins), self.set_current_node(
|
||||
with ir.IRNode.current_origins(origins), self.set_current_node( # type: ignore[arg-type]
|
||||
n
|
||||
), V.set_current_node(n):
|
||||
), V.set_current_node(
|
||||
n
|
||||
):
|
||||
if (
|
||||
n.op == "call_function"
|
||||
and n.target is not operator.getitem
|
||||
@ -1337,7 +1343,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
# Realize if (1) any user need inputs realized, or (2) there is
|
||||
# already too many reads and rematerializing can be bad.
|
||||
num_users = len(set(n.users))
|
||||
num_users = len(OrderedSet(n.users))
|
||||
if num_users > 1 and isinstance(result, TensorBox):
|
||||
for user in n.users:
|
||||
if user.target in needs_realized_inputs:
|
||||
@ -1446,7 +1452,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
self.register_users_of(result)
|
||||
|
||||
new_unbacked_defs = set()
|
||||
new_unbacked_defs: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
for buf in self.buffers[buffer_watermark:]:
|
||||
new_unbacked_defs |= buf.get_unbacked_symbol_defs()
|
||||
for op in self.operations[operation_watermark:]:
|
||||
@ -1540,10 +1546,10 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# end up needing to test equalities on the symbols, and a fresh
|
||||
# symbol is likely to hit lots of GuardOnDataDependent errors that
|
||||
# we already know facts for.
|
||||
renamed_unbacked_bindings = {
|
||||
renamed_unbacked_bindings = OrderedSet(
|
||||
V.fake_mode.shape_env.unbacked_renamings.get(s, s)
|
||||
for s in unbacked_bindings.keys()
|
||||
}
|
||||
)
|
||||
assert new_unbacked_defs >= renamed_unbacked_bindings, (
|
||||
f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
|
||||
f"fx node is: {n.format_node()}\n"
|
||||
@ -1611,9 +1617,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
if "cuda" in self.device_types:
|
||||
# first pass
|
||||
self.cpp_wrapper = False
|
||||
# Although triton.store_cubin was set in compile_fx, the backward pass didn't pick
|
||||
# Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick
|
||||
# that up. In theory it should work by only setting triton.store_cubin to True here,
|
||||
# but that will cause a problem when use_runtime_constant_folding is set.
|
||||
# but that will cause a problem when use_runtime_constant_folding is OrderedSet.
|
||||
with config.patch({"triton.store_cubin": True}):
|
||||
compiled = self.compile_to_module().call
|
||||
|
||||
@ -1652,7 +1658,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
for x in itertools.chain(params_flat, V.real_inputs)
|
||||
]
|
||||
else:
|
||||
# In the backward pass, V.real_inputs is not set.
|
||||
# In the backward pass, V.real_inputs is not OrderedSet.
|
||||
# Generating random inputs based on self.example_inputs sometimes can be problematic,
|
||||
# e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
|
||||
real_inputs = [
|
||||
|
@ -23,7 +23,6 @@ from typing import (
|
||||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
@ -62,6 +61,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
resolve_unbacked_bindings,
|
||||
SymTypes,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing
|
||||
from torch.utils._sympy.symbol import SymT
|
||||
|
||||
@ -325,11 +325,11 @@ def is_cpu(x: object) -> bool:
|
||||
|
||||
|
||||
class IRNode:
|
||||
_current_origins: ClassVar[Set[Any]] = set()
|
||||
_current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def current_origins(origins: Set[torch.fx.Node]):
|
||||
def current_origins(origins: OrderedSet[torch.fx.Node]):
|
||||
old = IRNode._current_origins
|
||||
IRNode._current_origins = old | origins
|
||||
try:
|
||||
@ -338,10 +338,10 @@ class IRNode:
|
||||
IRNode._current_origins = old
|
||||
|
||||
def __post_init__(self):
|
||||
self.origins = set(self._current_origins)
|
||||
self.origins = OrderedSet(self._current_origins)
|
||||
self.traceback = traceback.format_stack() if config.debug_ir_traceback else None
|
||||
|
||||
def get_read_names(self) -> Set[str]:
|
||||
def get_read_names(self) -> OrderedSet[str]:
|
||||
raise NotImplementedError(f"NYI on {type(self)}")
|
||||
|
||||
def get_traceback(self):
|
||||
@ -420,7 +420,7 @@ class IRNode:
|
||||
make_indexer: Callable[[], Callable[[Any], Any]]
|
||||
mark_reuse: Callable[[int], None]
|
||||
realize_hint: Callable[[], None]
|
||||
get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]]
|
||||
get_unbacked_symbol_uses: Callable[[], OrderedSet[sympy.Symbol]]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -455,8 +455,8 @@ class Operation:
|
||||
def is_user_of(self, name):
|
||||
return name in self.get_read_names()
|
||||
|
||||
def get_read_names(self) -> Set[str]:
|
||||
return {dep.name for dep in self.get_reads()}
|
||||
def get_read_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet(dep.name for dep in self.get_reads())
|
||||
|
||||
def get_reads(self):
|
||||
return self.get_read_writes().reads
|
||||
@ -464,10 +464,10 @@ class Operation:
|
||||
def get_outputs(self) -> List[Buffer]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
"""
|
||||
Returns the unbacked symbols which are required to be in scope in
|
||||
order to successfully perform codegen for this buffer. For example,
|
||||
@ -482,7 +482,7 @@ class Operation:
|
||||
on that buffer, which will eventually have a dependency on i0 if
|
||||
necessary.
|
||||
"""
|
||||
return set()
|
||||
return OrderedSet()
|
||||
|
||||
def get_workspace_size(self):
|
||||
"""
|
||||
@ -499,8 +499,8 @@ class Loops(IRNode):
|
||||
inner_fn: Callable[..., Any]
|
||||
ranges: List[Expr]
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
return set().union(
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet().union(
|
||||
*(free_unbacked_symbols(e) for e in self.ranges),
|
||||
self.inner_fn_free_unbacked_symbols(),
|
||||
)
|
||||
@ -594,8 +594,8 @@ class Loops(IRNode):
|
||||
self.get_size(),
|
||||
).reads
|
||||
|
||||
def get_read_names(self) -> Set[str]:
|
||||
return {dep.name for dep in self.get_reads()}
|
||||
def get_read_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet(dep.name for dep in self.get_reads())
|
||||
|
||||
def get_reduction_size(self):
|
||||
raise NotImplementedError(
|
||||
@ -689,7 +689,7 @@ def get_reduction_combine_fn(
|
||||
if reduction_type in REDUCTION_COMBINE_FN:
|
||||
return REDUCTION_COMBINE_FN[reduction_type]
|
||||
|
||||
elif reduction_type in {"argmax", "argmin"}:
|
||||
elif reduction_type in ("argmax", "argmin"):
|
||||
|
||||
def argmax_combine_fn(
|
||||
a: Tuple[object, object], b: Tuple[object, object]
|
||||
@ -754,16 +754,16 @@ class Reduction(Loops):
|
||||
src_dtype: torch.dtype
|
||||
reduction_hint: ReductionHint
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str: # type: ignore[override]
|
||||
return Loops.__str__( # type: ignore[call-arg]
|
||||
self, names=("ranges", "reduction_ranges", "reduction_type")
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str: # type: ignore[override]
|
||||
return self.__str__()
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
return super().get_unbacked_symbol_uses() | set().union(
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return super().get_unbacked_symbol_uses() | OrderedSet().union(
|
||||
*(free_unbacked_symbols(e) for e in self.reduction_ranges)
|
||||
)
|
||||
|
||||
@ -831,10 +831,10 @@ class Reduction(Loops):
|
||||
should_split = (
|
||||
not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
|
||||
and reduction_type
|
||||
not in {
|
||||
not in (
|
||||
"argmax",
|
||||
"argmin",
|
||||
}
|
||||
)
|
||||
and config.split_reductions
|
||||
# We don't support unbacked symints
|
||||
and _is_static(reduction_numel_hint)
|
||||
@ -1221,14 +1221,14 @@ class Reduction(Loops):
|
||||
|
||||
@staticmethod
|
||||
def default_accumulator(reduction_type, dtype):
|
||||
if reduction_type in {"max", "argmax"}:
|
||||
if reduction_type in ("max", "argmax"):
|
||||
if is_float_dtype(dtype):
|
||||
return float("-inf")
|
||||
elif is_boolean_dtype(dtype):
|
||||
return 0
|
||||
else:
|
||||
return torch.iinfo(dtype).min
|
||||
if reduction_type in {"min", "argmin"}:
|
||||
if reduction_type in ("min", "argmin"):
|
||||
if is_float_dtype(dtype):
|
||||
return float("inf")
|
||||
elif is_boolean_dtype(dtype):
|
||||
@ -1472,10 +1472,6 @@ class Reduction(Loops):
|
||||
)
|
||||
|
||||
|
||||
def num_reduction_outputs(reduction_type: Set[str]) -> int:
|
||||
return 3 if "welford" in reduction_type else 1
|
||||
|
||||
|
||||
class WelfordReduction(Reduction):
|
||||
output_index: int
|
||||
|
||||
@ -1530,7 +1526,7 @@ class WelfordReduction(Reduction):
|
||||
reduction_type: str,
|
||||
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
||||
):
|
||||
assert reduction_type in {"welford_reduce", "welford_combine"}
|
||||
assert reduction_type in ("welford_reduce", "welford_combine")
|
||||
|
||||
reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
|
||||
|
||||
@ -1745,14 +1741,14 @@ class Scan(Loops):
|
||||
|
||||
# HACK we mimick reduction
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
# TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
|
||||
# need to explicitly represent the closure so we can pull out unbacked
|
||||
# symbols here
|
||||
return (
|
||||
super().get_unbacked_symbol_uses()
|
||||
| set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges))
|
||||
| set().union(*(free_unbacked_symbols(e) for e in self.size))
|
||||
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.scan_ranges))
|
||||
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size))
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@ -1940,11 +1936,11 @@ class Sort(Loops):
|
||||
|
||||
# HACK we mimick reduction
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return (
|
||||
super().get_unbacked_symbol_uses()
|
||||
| set().union(*(free_unbacked_symbols(e) for e in self.sort_ranges))
|
||||
| set().union(*(free_unbacked_symbols(e) for e in self.size))
|
||||
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.sort_ranges))
|
||||
| OrderedSet().union(*(free_unbacked_symbols(e) for e in self.size))
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@ -2211,7 +2207,7 @@ class BaseView(IRNode):
|
||||
def is_module_buffer(self):
|
||||
return self.data.is_module_buffer() # type: ignore[attr-defined]
|
||||
|
||||
def get_read_names(self) -> Set[str]:
|
||||
def get_read_names(self) -> OrderedSet[str]:
|
||||
return self.data.get_read_names()
|
||||
|
||||
def get_reads(self):
|
||||
@ -2312,7 +2308,7 @@ class PermuteView(BaseView):
|
||||
@classmethod
|
||||
def create(cls, x, dims):
|
||||
dims = cls._map_neg_dims(dims)
|
||||
assert set(dims) == set(range(len(dims)))
|
||||
assert OrderedSet(dims) == OrderedSet(range(len(dims)))
|
||||
|
||||
if is_storage_and_layout(x):
|
||||
storage, old_layout = as_storage_and_layout(x)
|
||||
@ -2332,14 +2328,16 @@ class PermuteView(BaseView):
|
||||
return [dim if dim >= 0 else len(dims) + dim for dim in dims]
|
||||
|
||||
def get_size(self):
|
||||
assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
|
||||
assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet(
|
||||
range(len(self.dims))
|
||||
)
|
||||
size = self.data.get_size()
|
||||
return [size[i] for i in self.dims]
|
||||
|
||||
def make_reindexer(self):
|
||||
inv = {j: i for i, j in enumerate(self.dims)}
|
||||
inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index]
|
||||
assert set(inv) == set(range(len(self.dims)))
|
||||
assert OrderedSet(inv) == OrderedSet(range(len(self.dims)))
|
||||
|
||||
def reindex(index):
|
||||
return [index[i] for i in inv]
|
||||
@ -2420,7 +2418,7 @@ class GenericView(BaseView):
|
||||
index_new = list(self.reindex(index_old))
|
||||
return f"lambda {', '.join(map(str, index_old))}: {index_new}"
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.str_helper(
|
||||
[self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
|
||||
)
|
||||
@ -2594,7 +2592,7 @@ class ReinterpretView(BaseView):
|
||||
if isinstance(self.data, BaseView):
|
||||
self.data = self.data.unwrap_view()
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.str_helper(
|
||||
[
|
||||
self.data,
|
||||
@ -2643,7 +2641,7 @@ class ReinterpretView(BaseView):
|
||||
def freeze_layout(self):
|
||||
pass
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return (
|
||||
free_unbacked_symbols(self.layout.size)
|
||||
| free_unbacked_symbols(self.layout.stride)
|
||||
@ -2684,7 +2682,7 @@ class DtypeView(BaseView):
|
||||
return ReinterpretView(storage, new_layout)
|
||||
return DtypeView(x, new_dtype)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.str_helper([self.data, self.target_dtype])
|
||||
|
||||
__repr__ = __str__
|
||||
@ -2885,7 +2883,7 @@ class Layout(IRNode):
|
||||
def stride(self):
|
||||
return self._stride
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
offset = ""
|
||||
if self.offset != 0:
|
||||
offset = f", offset={self.offset}"
|
||||
@ -3133,7 +3131,7 @@ class FlexibleLayout(Layout):
|
||||
In this format, channels last would be:
|
||||
[1, 3, 2, 0]
|
||||
"""
|
||||
assert set(range(len(sizes))) == set(order), (sizes, order)
|
||||
assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order)
|
||||
next_stride = sympy.Integer(1)
|
||||
strides = [None] * len(order)
|
||||
|
||||
@ -3150,7 +3148,7 @@ class FlexibleLayout(Layout):
|
||||
In this format, channels last would be:
|
||||
[3, 0, 2, 1]
|
||||
"""
|
||||
assert set(range(len(sizes))) == set(order)
|
||||
assert OrderedSet(range(len(sizes))) == OrderedSet(order)
|
||||
fill_order = stride_order2fill_order(order)
|
||||
return FlexibleLayout.fill_ordered(sizes, fill_order)
|
||||
|
||||
@ -3461,14 +3459,14 @@ class Buffer(IRNode):
|
||||
return [self.layout.target.get_name()]
|
||||
return ()
|
||||
|
||||
def get_read_names(self) -> Set[str]:
|
||||
return {self.get_name()}
|
||||
def get_read_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet([self.get_name()])
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def realize(self):
|
||||
pass
|
||||
@ -3516,8 +3514,8 @@ class ConstantBuffer(InputBuffer):
|
||||
|
||||
|
||||
class NoneAsConstantBuffer(IRNode):
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def codegen_reference(self, writer=None):
|
||||
return V.graph.wrapper_code.none_str
|
||||
@ -3532,7 +3530,7 @@ class ShapeAsConstantBuffer(IRNode):
|
||||
def shape(self):
|
||||
return self._shape
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return free_unbacked_symbols(self.shape)
|
||||
|
||||
def codegen_reference(self, writer=None):
|
||||
@ -3572,7 +3570,7 @@ class ComputedBuffer(OperationBuffer):
|
||||
self.data.get_size(),
|
||||
)
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
# Ordinarily, we'd like to just peek at the arguments list,
|
||||
# but ComputedBuffers have no argument list.
|
||||
#
|
||||
@ -3861,7 +3859,7 @@ class TemplateBuffer(OperationBuffer):
|
||||
deps = dependencies.extract_read_writes(
|
||||
dummy, self.get_size(), (), normalize=True
|
||||
)
|
||||
deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
|
||||
deps.reads = OrderedSet(dependencies.StarDep(x.get_name()) for x in self.inputs)
|
||||
return deps
|
||||
|
||||
def get_reduction_size(self):
|
||||
@ -3913,10 +3911,10 @@ class TritonTemplateBuffer(TemplateBuffer):
|
||||
self.outputs: List[Buffer] = [self]
|
||||
if mutated_inputs is not None:
|
||||
# Ensure that the mutated inputs are only allowed for certain nodes
|
||||
allowed_set = {
|
||||
allowed_set = (
|
||||
torch.ops.higher_order.flex_attention,
|
||||
torch.ops.higher_order.flex_attention_backward,
|
||||
}
|
||||
)
|
||||
current_node = V.graph.current_node.target
|
||||
assert (
|
||||
current_node in allowed_set
|
||||
@ -3929,7 +3927,7 @@ class TritonTemplateBuffer(TemplateBuffer):
|
||||
def get_outputs(self) -> List[Buffer]:
|
||||
return self.outputs
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})"
|
||||
return out
|
||||
|
||||
@ -4060,7 +4058,7 @@ class InputsKernel(OperationBuffer):
|
||||
inputs: List[Buffer]
|
||||
|
||||
def get_read_writes(self):
|
||||
reads: Set[dependencies.Dep] = set()
|
||||
reads: OrderedSet[dependencies.Dep] = OrderedSet()
|
||||
StarDep = dependencies.StarDep
|
||||
for input in self.inputs:
|
||||
if isinstance(input, list):
|
||||
@ -4068,14 +4066,14 @@ class InputsKernel(OperationBuffer):
|
||||
else:
|
||||
reads.add(StarDep(input.get_name()))
|
||||
|
||||
writes: Set[dependencies.Dep] = {
|
||||
writes: OrderedSet[dependencies.Dep] = OrderedSet(
|
||||
StarDep(buf.get_name()) for buf in self.get_outputs()
|
||||
}
|
||||
)
|
||||
|
||||
return dependencies.ReadWrites(
|
||||
reads=reads,
|
||||
writes=writes,
|
||||
index_exprs=set(),
|
||||
index_exprs=OrderedSet(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -4331,8 +4329,8 @@ class ExternKernel(InputsKernel):
|
||||
def get_outputs(self) -> List[Buffer]:
|
||||
return [self, *self.mutation_outputs]
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def collect_arg_kwarg_properties(self):
|
||||
# if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional
|
||||
@ -4371,7 +4369,7 @@ class ExternKernel(InputsKernel):
|
||||
def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
|
||||
# Previously, we want to maintain forward-compatibility by skipping
|
||||
# default args in the serialized artifacts in fbcode. However,
|
||||
# some of our shim interfaces require default values being set.
|
||||
# some of our shim interfaces require default values being OrderedSet.
|
||||
# Discussed with Sherlock offline and we decided to allow serializing
|
||||
# default args into the C++ wrapper code for now. We will refine this
|
||||
# part if we see real FC requirement. More details related to FC
|
||||
@ -4860,17 +4858,17 @@ class ExternKernel(InputsKernel):
|
||||
index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type]
|
||||
return index, tuple(new_sizes)
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
# NB: It's not necessary to check regular inputs as we automatically
|
||||
# have dependencies on them
|
||||
r = set()
|
||||
r: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
for arg in self.constant_args:
|
||||
r |= maybe_free_unbacked_symbols(arg)
|
||||
for arg in self.kwargs.values():
|
||||
r |= maybe_free_unbacked_symbols(arg)
|
||||
return r
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
kernel_name = getattr(self, "python_kernel_name", None)
|
||||
lines = [
|
||||
f"python_kernel_name={kernel_name!r}",
|
||||
@ -5056,13 +5054,13 @@ class UserDefinedTritonKernel(ExternKernel):
|
||||
new_name, raw_args, self.grid, configs, triton_meta, kernel.constexprs
|
||||
)
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
# add unbacked symbols used in the grid to the ones used
|
||||
# in the kwargs (the latter is generated by ExternKernel)
|
||||
return super().get_unbacked_symbol_uses() | free_unbacked_symbols(self.grid)
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def __init__(self, *, kernel_idx, grid, kernel_args):
|
||||
inputs = []
|
||||
@ -5144,8 +5142,8 @@ class InplaceBernoulliFallback(ExternKernel):
|
||||
def get_mutation_names(self):
|
||||
return [self.inputs[0].get_name()]
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def __init__(self, op_overload, x, *constant_args):
|
||||
super().__init__(
|
||||
@ -5182,8 +5180,8 @@ class InplaceCopyFallback(ExternKernel):
|
||||
def get_mutation_names(self):
|
||||
return [self.inputs[0].get_name()]
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -5237,8 +5235,8 @@ class MutatingFirstArgExternKernel(ExternKernel):
|
||||
def get_mutation_names(self):
|
||||
return [self.inputs[0].get_name()]
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def has_side_effects(self):
|
||||
return True
|
||||
@ -5319,8 +5317,8 @@ class ScatterFallback(ExternKernel):
|
||||
def get_mutation_names(self):
|
||||
return [self.inputs[0].get_name()]
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -5384,8 +5382,8 @@ class IndexPutFallback(ExternKernel):
|
||||
def get_mutation_names(self):
|
||||
return [self.inputs[0].get_name()]
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def __init__(self, op_overload, x, indices, values, accumulate):
|
||||
self.indices = indices
|
||||
@ -5457,8 +5455,8 @@ class DynamicScalar(ExternKernel):
|
||||
self.sym = sym
|
||||
self.keypath = keypath
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return {self.sym}
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet([self.sym])
|
||||
|
||||
def codegen(self, wrapper):
|
||||
wrapper.codegen_dynamic_scalar(self)
|
||||
@ -5518,22 +5516,24 @@ class ExternKernelNode:
|
||||
node: export_schema.Node
|
||||
|
||||
|
||||
has_c_shim = {
|
||||
aten._embedding_bag.default,
|
||||
aten._fft_c2c.default,
|
||||
aten._scaled_dot_product_efficient_attention.default,
|
||||
aten._scaled_dot_product_flash_attention.default,
|
||||
aten._scaled_dot_product_cudnn_attention.default,
|
||||
aten._scaled_mm.default,
|
||||
aten.addmm.out,
|
||||
aten.bmm.out,
|
||||
aten.copy_.default,
|
||||
aten.mm.out,
|
||||
aten.repeat_interleave.Tensor,
|
||||
aten.nonzero.default,
|
||||
aten.view.dtype,
|
||||
aten.view_as_real.default,
|
||||
}
|
||||
has_c_shim = OrderedSet(
|
||||
[
|
||||
aten._embedding_bag.default,
|
||||
aten._fft_c2c.default,
|
||||
aten._scaled_dot_product_efficient_attention.default,
|
||||
aten._scaled_dot_product_flash_attention.default,
|
||||
aten._scaled_dot_product_cudnn_attention.default,
|
||||
aten._scaled_mm.default,
|
||||
aten.addmm.out,
|
||||
aten.bmm.out,
|
||||
aten.copy_.default,
|
||||
aten.mm.out,
|
||||
aten.repeat_interleave.Tensor,
|
||||
aten.nonzero.default,
|
||||
aten.view.dtype,
|
||||
aten.view_as_real.default,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FallbackKernel(ExternKernelAlloc):
|
||||
@ -5719,13 +5719,13 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
f"{wrapper.codegen_unbacked_symbol_decl(s)} = {go_outer()}{wrapper.ending}"
|
||||
)
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
if unbacked_bindings := getattr(self, "unbacked_bindings", None):
|
||||
return resolve_unbacked_bindings(
|
||||
V.graph.sizevars.shape_env, unbacked_bindings
|
||||
).keys()
|
||||
else:
|
||||
return set()
|
||||
return OrderedSet()
|
||||
|
||||
def set_cpp_kernel(self, kernel):
|
||||
from .codegen.wrapper import get_cpp_op_schema
|
||||
@ -5758,7 +5758,7 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
class Shim:
|
||||
ref: Any
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.ref
|
||||
|
||||
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
|
||||
@ -5784,7 +5784,9 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
if isinstance(example_output, torch.Tensor):
|
||||
return example_output.device
|
||||
if isinstance(example_output, (list, tuple)):
|
||||
device_set = {FallbackKernel.find_device(None, x) for x in example_output}
|
||||
device_set = OrderedSet(
|
||||
FallbackKernel.find_device(None, x) for x in example_output
|
||||
)
|
||||
# Remove None
|
||||
devices = [device for device in device_set if device]
|
||||
if len(devices) == 1:
|
||||
@ -6108,7 +6110,7 @@ class MultiOutput(ExternKernel):
|
||||
V.graph.register_operation(self)
|
||||
self.indices = indices
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return self.inputs[0].get_unbacked_symbol_uses()
|
||||
|
||||
def should_allocate(self):
|
||||
@ -6140,10 +6142,10 @@ class MutableBox(IRNode):
|
||||
def realize(self):
|
||||
return self.data.realize()
|
||||
|
||||
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
|
||||
def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]:
|
||||
return self.data.get_unbacked_symbol_uses()
|
||||
|
||||
def get_read_names(self) -> Set[str]:
|
||||
def get_read_names(self) -> OrderedSet[str]:
|
||||
return self.data.get_read_names()
|
||||
|
||||
def get_defining_op(self):
|
||||
@ -6166,7 +6168,7 @@ class MutableBox(IRNode):
|
||||
def dtype(self):
|
||||
return self.data.dtype
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.data, MutableBox):
|
||||
line0 = f"{type(self).__name__}({type(self.data).__name__}("
|
||||
endl = "))"
|
||||
@ -6326,7 +6328,7 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
|
||||
for buffer in buffers
|
||||
]
|
||||
# assuming the same buffer is represented by the same IRNode object
|
||||
return len({id(buffer) for buffer in buffers}) < len(buffers)
|
||||
return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -7143,12 +7145,12 @@ class _WaitKernel(_CollectiveKernel):
|
||||
# NB: recursive structure here reflects val_to_arg_str, avoid
|
||||
# calling free_unbacked_symbols on "exotic" types that don't get pexpr
|
||||
# treatment
|
||||
def maybe_free_unbacked_symbols(s: object) -> Set[Symbol]:
|
||||
def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]:
|
||||
if isinstance(s, (SymTypes, Expr)):
|
||||
# This branch should be impossible in return position
|
||||
return free_unbacked_symbols(s)
|
||||
elif isinstance(s, (tuple, list)):
|
||||
r = set()
|
||||
r: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
for t in s:
|
||||
r |= maybe_free_unbacked_symbols(t)
|
||||
return r
|
||||
@ -7156,4 +7158,4 @@ def maybe_free_unbacked_symbols(s: object) -> Set[Symbol]:
|
||||
# This branch is impossible in constant-args position
|
||||
return free_unbacked_symbols(s)
|
||||
else:
|
||||
return set()
|
||||
return OrderedSet()
|
||||
|
@ -1,10 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, List, Optional, Set
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._prims_common import make_channels_last_strides_for
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .ir import (
|
||||
ExternKernelAlloc,
|
||||
@ -437,8 +438,8 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
|
||||
self.cpp_kernel_overload_name,
|
||||
)
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@ -862,8 +863,8 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
|
||||
def get_mutation_names(self):
|
||||
return [self.inputs[self.idx_for_inplace_sum].get_name()]
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
|
||||
return set()
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
|
@ -34,6 +34,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
|
||||
from torch._dynamo.utils import counters, dynamo_timed
|
||||
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.symbol import free_symbol_is_type, SymT
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
@ -153,13 +154,13 @@ class SchedulerBuffer:
|
||||
class BaseSchedulerNode:
|
||||
group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
|
||||
read_writes: dependencies.ReadWrites
|
||||
unmet_dependencies: Set[Dep]
|
||||
unmet_dependencies: OrderedSet[Dep]
|
||||
|
||||
def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
|
||||
self.scheduler: Scheduler = scheduler
|
||||
self.node: Optional[ir.Operation] = node
|
||||
self.set_read_writes(node.get_read_writes())
|
||||
self.ancestors: Set[str] = set()
|
||||
self.ancestors: OrderedSet[str] = OrderedSet()
|
||||
# .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
|
||||
# e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
|
||||
# in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
|
||||
@ -167,9 +168,9 @@ class BaseSchedulerNode:
|
||||
# .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
|
||||
self.min_order: int
|
||||
self.max_order: int
|
||||
self.last_usage: Set[
|
||||
self.last_usage: OrderedSet[
|
||||
str
|
||||
] = set() # buffers that won't be used after this kernel
|
||||
] = OrderedSet() # buffers that won't be used after this kernel
|
||||
self.written = False
|
||||
|
||||
self.outputs: List[SchedulerBuffer] = [
|
||||
@ -255,10 +256,10 @@ class BaseSchedulerNode:
|
||||
self.prune_deps()
|
||||
|
||||
def set_last_usage(
|
||||
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
|
||||
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
|
||||
) -> None:
|
||||
used_buffers = self.used_or_aliased_buffer_names()
|
||||
used_buffers = {mutation_real_name.get(k, k) for k in used_buffers}
|
||||
used_buffers = OrderedSet([mutation_real_name.get(k, k) for k in used_buffers])
|
||||
self.last_usage = used_buffers - future_used_buffers
|
||||
|
||||
def mark_run(self) -> None:
|
||||
@ -268,14 +269,14 @@ class BaseSchedulerNode:
|
||||
def op_counts(self) -> Counter[str]:
|
||||
return self.read_writes.op_counts
|
||||
|
||||
def used_buffer_names(self) -> Set[str]:
|
||||
return {
|
||||
def used_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet(
|
||||
dep.name
|
||||
for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
|
||||
}
|
||||
)
|
||||
|
||||
def used_or_aliased_buffer_names(self) -> Set[str]:
|
||||
used_names = set()
|
||||
def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
|
||||
used_names: OrderedSet[str] = OrderedSet()
|
||||
|
||||
deps = [
|
||||
dep.name
|
||||
@ -291,11 +292,11 @@ class BaseSchedulerNode:
|
||||
return used_names
|
||||
|
||||
def prune_deps(self) -> None:
|
||||
self.unmet_dependencies = {
|
||||
self.unmet_dependencies = OrderedSet(
|
||||
dep
|
||||
for dep in self.unmet_dependencies
|
||||
if dep.name not in self.scheduler.available_buffer_names
|
||||
}
|
||||
)
|
||||
|
||||
def prune_weak_deps(self) -> None:
|
||||
# Prune weak dependencies on operations that have been removed
|
||||
@ -305,7 +306,9 @@ class BaseSchedulerNode:
|
||||
op = self.scheduler.name_to_buf[dep.name].defining_op
|
||||
return op.get_name() in V.graph.removed_operations
|
||||
|
||||
to_remove = {dep for dep in self.read_writes.reads if should_prune(dep)}
|
||||
to_remove = OrderedSet(
|
||||
dep for dep in self.read_writes.reads if should_prune(dep)
|
||||
)
|
||||
self.set_read_writes(self.read_writes.remove_reads(to_remove))
|
||||
|
||||
def prune_redundant_deps(
|
||||
@ -320,11 +323,11 @@ class BaseSchedulerNode:
|
||||
def get_first_name(self) -> str:
|
||||
return self.get_name()
|
||||
|
||||
def get_operation_names(self) -> Set[str]:
|
||||
return {node.get_name() for node in self.get_nodes()}
|
||||
def get_operation_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet(node.get_name() for node in self.get_nodes())
|
||||
|
||||
def get_buffer_names(self) -> Set[str]:
|
||||
return {out.get_name() for out in self.outputs}
|
||||
def get_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet(out.get_name() for out in self.outputs)
|
||||
|
||||
def get_nodes(self) -> Sequence[BaseSchedulerNode]:
|
||||
return [self]
|
||||
@ -538,18 +541,18 @@ class BaseSchedulerNode:
|
||||
for dep in self.read_writes.reads | self.read_writes.writes:
|
||||
buf_accesses[dep.name].append(dep)
|
||||
|
||||
reads = {dep.name for dep in self.read_writes.reads}
|
||||
writes = {dep.name for dep in self.read_writes.writes}
|
||||
reads = OrderedSet(dep.name for dep in self.read_writes.reads)
|
||||
writes = OrderedSet(dep.name for dep in self.read_writes.writes)
|
||||
|
||||
def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool:
|
||||
users = self.scheduler.name_to_buf[buf].users
|
||||
buf_uses = {user.node for user in users}
|
||||
return len(buf_uses - set(snodes)) > 0
|
||||
buf_uses = OrderedSet(user.node for user in users)
|
||||
return len(buf_uses - OrderedSet(snodes)) > 0
|
||||
|
||||
if isinstance(self, FusedSchedulerNode):
|
||||
removed_buffers = {
|
||||
removed_buffers = OrderedSet(
|
||||
dep for dep in writes if not is_materialized(dep, self.snodes)
|
||||
}
|
||||
)
|
||||
writes = writes - removed_buffers
|
||||
reads = reads - removed_buffers
|
||||
node_bytes = 0
|
||||
@ -714,7 +717,7 @@ class WhyNoFuse:
|
||||
|
||||
|
||||
def pformat(obj: Any) -> str:
|
||||
if isinstance(obj, set):
|
||||
if isinstance(obj, OrderedSet):
|
||||
# pformat has trouble with sets of sympy exprs
|
||||
obj = sorted(obj, key=str)
|
||||
result = pprint.pformat(obj, indent=4)
|
||||
@ -725,7 +728,7 @@ def pformat(obj: Any) -> str:
|
||||
|
||||
class OutputNode:
|
||||
def __init__(self, dep: StarDep) -> None:
|
||||
self.unmet_dependencies = {dep}
|
||||
self.unmet_dependencies = OrderedSet([dep])
|
||||
|
||||
def is_reduction(self) -> bool:
|
||||
return False
|
||||
@ -771,14 +774,16 @@ def _prune_redundant_deps(
|
||||
else:
|
||||
return False
|
||||
|
||||
deps_to_prune = {dep for dep in node.unmet_dependencies if should_prune(dep)}
|
||||
deps_to_prune = OrderedSet(
|
||||
dep for dep in node.unmet_dependencies if should_prune(dep)
|
||||
)
|
||||
|
||||
if deps_to_prune:
|
||||
node.unmet_dependencies = node.unmet_dependencies - deps_to_prune
|
||||
node.set_read_writes(node.read_writes.remove_reads(deps_to_prune))
|
||||
|
||||
|
||||
# TODO(xmfan): reuse an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel
|
||||
# TODO(xmfan): reuse: an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel
|
||||
kernel_name_to_op = {
|
||||
"extern_kernels.convolution": torch.ops.aten.convolution,
|
||||
"extern_kernels.mm": torch.ops.aten.mm,
|
||||
@ -937,8 +942,8 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
return False
|
||||
|
||||
@cache_on_self
|
||||
def _get_atomic_add_buffers(self) -> Set[str]:
|
||||
buffers_store_as_atomic_add = set()
|
||||
def _get_atomic_add_buffers(self) -> OrderedSet[str]:
|
||||
buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
|
||||
if isinstance(self._body, ir.LoopBody):
|
||||
for node in self._body.get_nodes():
|
||||
if (
|
||||
@ -966,7 +971,7 @@ def init_group_node(
|
||||
group_snode.snodes = snodes
|
||||
group_snode.scheduler = scheduler
|
||||
group_snode.node = None
|
||||
group_snode.ancestors = set.union(
|
||||
group_snode.ancestors = OrderedSet.union(
|
||||
*[x.ancestors for x in snodes if x.ancestors is not None]
|
||||
)
|
||||
|
||||
@ -974,11 +979,14 @@ def init_group_node(
|
||||
dependencies.ReadWrites.merge_list([x.read_writes for x in snodes])
|
||||
)
|
||||
|
||||
group_snode.unmet_dependencies = {
|
||||
dep
|
||||
for dep in set.union(*[x.unmet_dependencies for x in snodes])
|
||||
if dep.name not in group_snode.get_buffer_names()
|
||||
} - group_snode.read_writes.writes
|
||||
group_snode.unmet_dependencies = (
|
||||
OrderedSet(
|
||||
dep
|
||||
for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes])
|
||||
if dep.name not in group_snode.get_buffer_names()
|
||||
)
|
||||
- group_snode.read_writes.writes
|
||||
)
|
||||
|
||||
group_snode.min_order = min(x.min_order for x in group_snode.snodes)
|
||||
group_snode.max_order = max(x.max_order for x in group_snode.snodes)
|
||||
@ -1020,8 +1028,8 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
return self.snodes[0].get_name()
|
||||
|
||||
@cache_on_self
|
||||
def get_buffer_names(self) -> Set[str]:
|
||||
return set.union(*[x.get_buffer_names() for x in self.snodes])
|
||||
def get_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
|
||||
|
||||
def get_outputs(self) -> List[SchedulerBuffer]:
|
||||
result: List[SchedulerBuffer] = []
|
||||
@ -1047,25 +1055,27 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
return f"{self}, snodes: {snodes_str}"
|
||||
|
||||
def set_last_usage(
|
||||
self, future_used_buffers: Set[str], mutation_real_name: Dict[str, str]
|
||||
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
|
||||
) -> None:
|
||||
# Set self.last_usage using the global information
|
||||
# This will be used for inter-kernel optimisations
|
||||
super().set_last_usage(future_used_buffers, mutation_real_name)
|
||||
# Set self.last_usage on the snodes
|
||||
# This will be used for optimisations within the kernel
|
||||
future_used_buffers: Set[str] = set()
|
||||
future_used_buffers: OrderedSet[str] = OrderedSet()
|
||||
for node in reversed(self.snodes):
|
||||
node.set_last_usage(future_used_buffers, mutation_real_name)
|
||||
future_used_buffers.update(node.last_usage)
|
||||
|
||||
@cache_on_self
|
||||
def used_buffer_names(self) -> Set[str]:
|
||||
return set.union(*[x.used_buffer_names() for x in self.snodes])
|
||||
def used_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes])
|
||||
|
||||
@cache_on_self
|
||||
def used_or_aliased_buffer_names(self) -> Set[str]:
|
||||
return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes])
|
||||
def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet.union(
|
||||
*[x.used_or_aliased_buffer_names() for x in self.snodes]
|
||||
)
|
||||
|
||||
def get_nodes(self) -> Sequence[BaseSchedulerNode]:
|
||||
return self.snodes
|
||||
@ -1300,13 +1310,16 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
)
|
||||
)
|
||||
|
||||
self.unmet_dependencies = {
|
||||
dep
|
||||
for dep in set.union(
|
||||
prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
|
||||
self.unmet_dependencies = (
|
||||
OrderedSet(
|
||||
dep
|
||||
for dep in OrderedSet.union(
|
||||
prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
|
||||
)
|
||||
if dep.name not in self.get_buffer_names()
|
||||
)
|
||||
if dep.name not in self.get_buffer_names()
|
||||
} - self.read_writes.writes
|
||||
- self.read_writes.writes
|
||||
)
|
||||
|
||||
self.min_order = min([prev_node_1.min_order, prev_node_2.min_order])
|
||||
self.max_order = max([prev_node_1.max_order, prev_node_2.max_order])
|
||||
@ -1327,7 +1340,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
|
||||
self.group = (nodes[0].get_device(), ((sympy.Expr("foreach"),),))
|
||||
|
||||
self.origins: Set[torch.fx.Node] = set()
|
||||
self.origins: OrderedSet[torch.fx.Node] = OrderedSet()
|
||||
|
||||
def mark_run(self) -> None:
|
||||
raise NotImplementedError
|
||||
@ -1409,8 +1422,8 @@ class GroupedSchedulerNode(BaseSchedulerNode):
|
||||
return self.snodes[0].get_name()
|
||||
|
||||
@cache_on_self
|
||||
def get_buffer_names(self) -> Set[str]:
|
||||
return set.union(*[x.get_buffer_names() for x in self.snodes])
|
||||
def get_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
|
||||
|
||||
def get_outputs(self) -> List[SchedulerBuffer]:
|
||||
result: List[SchedulerBuffer] = []
|
||||
@ -1519,12 +1532,14 @@ class Scheduler:
|
||||
self.backends: Dict[torch.device, BaseScheduling] = {}
|
||||
self.post_grad_graph_id = next(_post_grad_graph_counter)
|
||||
|
||||
self.completed_operations: Set[str] = set()
|
||||
self.available_buffer_names = {
|
||||
*V.graph.graph_inputs.keys(),
|
||||
*V.graph.constants.keys(),
|
||||
*V.graph.torchbind_constants.keys(),
|
||||
}
|
||||
self.completed_operations: OrderedSet[str] = OrderedSet()
|
||||
self.available_buffer_names = OrderedSet(
|
||||
[
|
||||
*V.graph.graph_inputs.keys(),
|
||||
*V.graph.constants.keys(),
|
||||
*V.graph.torchbind_constants.keys(),
|
||||
]
|
||||
)
|
||||
|
||||
self.nodes = [self.create_scheduler_node(n) for n in nodes]
|
||||
|
||||
@ -1575,7 +1590,7 @@ class Scheduler:
|
||||
self.num_orig_nodes = len(self.nodes)
|
||||
self.create_foreach_nodes()
|
||||
self.nodes = self.topological_sort_schedule(self.nodes)
|
||||
self.logged_slow_fusion: Set[Tuple[str, str]] = set()
|
||||
self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet()
|
||||
if config._pre_fusion_custom_pass is not None:
|
||||
self.nodes = config._pre_fusion_custom_pass(self.nodes)
|
||||
self.nodes = self.fuse_nodes(self.nodes)
|
||||
@ -1590,7 +1605,7 @@ class Scheduler:
|
||||
|
||||
# used during codegen:
|
||||
self.current_device: Optional[torch.device] = None
|
||||
self.buffer_names_to_free: Set[str] = set()
|
||||
self.buffer_names_to_free: OrderedSet[str] = OrderedSet()
|
||||
|
||||
# fx graph node to the position it appears in the graph
|
||||
# for debug attribution
|
||||
@ -1637,7 +1652,7 @@ class Scheduler:
|
||||
raise NotImplementedError(node)
|
||||
|
||||
def create_foreach_nodes(self) -> None:
|
||||
removed_node_names = set()
|
||||
removed_node_names: OrderedSet[str] = OrderedSet()
|
||||
fe_nodes = []
|
||||
kept_node_names = self.name_to_fused_node.keys()
|
||||
|
||||
@ -1678,7 +1693,7 @@ class Scheduler:
|
||||
"""
|
||||
This data structure behaves like a list except it makes sure the
|
||||
elements remain unique.
|
||||
Normally one could use a set/dict for this purpose however
|
||||
Normally one could use a OrderedSet/dict for this purpose however
|
||||
the list in question gets elements appended as it is being
|
||||
iterated over which means that we need to keep the list
|
||||
semantics.
|
||||
@ -1687,10 +1702,10 @@ class Scheduler:
|
||||
def __init__(
|
||||
self,
|
||||
items: Optional[List[T]] = None,
|
||||
membership: Optional[Set[T]] = None,
|
||||
membership: Optional[OrderedSet[T]] = None,
|
||||
) -> None:
|
||||
self.items = items or []
|
||||
self.membership = membership or set()
|
||||
self.membership = membership or OrderedSet()
|
||||
|
||||
def append(self, node_user: T) -> None:
|
||||
if node_user in self.membership:
|
||||
@ -1699,7 +1714,7 @@ class Scheduler:
|
||||
self.membership.add(node_user)
|
||||
|
||||
def __add__(self, other: DedupList[T]) -> DedupList[T]:
|
||||
new_membership = set.union(self.membership, other.membership)
|
||||
new_membership = OrderedSet.union(self.membership, other.membership)
|
||||
new_items = self.items + [
|
||||
x for x in other.items if x not in self.membership
|
||||
]
|
||||
@ -1915,7 +1930,7 @@ class Scheduler:
|
||||
"""
|
||||
Ensure nodes is in topologically sorted order
|
||||
"""
|
||||
seen: Set[BaseSchedulerNode] = set()
|
||||
seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
|
||||
name_to_node: Dict[str, BaseSchedulerNode] = dict()
|
||||
result: List[BaseSchedulerNode] = []
|
||||
|
||||
@ -1941,9 +1956,9 @@ class Scheduler:
|
||||
Populate each node.ancestors
|
||||
"""
|
||||
# note self.nodes is topologically sorted
|
||||
name_to_ancestors: Dict[str, Set[str]] = {}
|
||||
name_to_ancestors: Dict[str, OrderedSet[str]] = {}
|
||||
for node in self.nodes:
|
||||
ancestors = set()
|
||||
ancestors: OrderedSet[str] = OrderedSet()
|
||||
for dep in node.unmet_dependencies:
|
||||
dep_node_name = self.name_to_buf[dep.name].defining_op.get_name()
|
||||
ancestors.add(dep_node_name)
|
||||
@ -2229,7 +2244,7 @@ class Scheduler:
|
||||
- self.can_fuse(): checks if a fusion is legal
|
||||
- self.score_fusion(): assigns priority to a given fusion
|
||||
"""
|
||||
fused_nodes = set(nodes)
|
||||
fused_nodes = OrderedSet(nodes)
|
||||
if fusion_log.isEnabledFor(logging.DEBUG):
|
||||
fusion_log.debug("fuse_nodes_once, candidates:")
|
||||
for node in fused_nodes:
|
||||
@ -2271,7 +2286,7 @@ class Scheduler:
|
||||
Helper to find all legal fusion opportunities, sorted by self.score_fusion()
|
||||
"""
|
||||
possible_fusions = []
|
||||
seen = set()
|
||||
seen: OrderedSet[Tuple[BaseSchedulerNode, BaseSchedulerNode]] = OrderedSet()
|
||||
|
||||
def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None:
|
||||
for node1_index, node1 in enumerate(nodes):
|
||||
@ -2319,8 +2334,8 @@ class Scheduler:
|
||||
Finds whether there's a path from node1 to node2 (or vice-versa)
|
||||
caused indirectly by other fusions.
|
||||
"""
|
||||
|
||||
visited = set()
|
||||
# since we are just returning boolean here, use slightly faster, unordered set
|
||||
visited: Set[FusedSchedulerNode] = set()
|
||||
|
||||
def found_path(node: BaseSchedulerNode) -> bool:
|
||||
# only fused nodes can introduce new ancestors.
|
||||
@ -2345,8 +2360,14 @@ class Scheduler:
|
||||
)
|
||||
return False
|
||||
|
||||
combined_names = node1.get_operation_names() | node2.get_operation_names()
|
||||
combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names
|
||||
# as above - use slightly faster, unordered set
|
||||
combined_names = (
|
||||
node1.get_operation_names()._dict.keys()
|
||||
| node2.get_operation_names()._dict.keys()
|
||||
)
|
||||
combined_ancestors = (
|
||||
node1.ancestors._dict.keys() | node2.ancestors._dict.keys()
|
||||
) - combined_names
|
||||
cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
|
||||
if cycle:
|
||||
WhyNoFuse(node1, node2)("will create cycle")
|
||||
@ -2555,7 +2576,7 @@ class Scheduler:
|
||||
"""
|
||||
node1_buf_names = node1.get_buffer_names()
|
||||
node1_op_names = node1.get_operation_names()
|
||||
computed_deps = set()
|
||||
computed_deps: OrderedSet[Dep] = OrderedSet()
|
||||
why = WhyNoFuse(node1, node2)
|
||||
|
||||
for cd in node1.read_writes.writes:
|
||||
@ -2569,7 +2590,9 @@ class Scheduler:
|
||||
if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2):
|
||||
computed_deps.add(dep)
|
||||
|
||||
remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps}
|
||||
remaining_deps = OrderedSet(
|
||||
dep.name for dep in node2.unmet_dependencies - computed_deps
|
||||
)
|
||||
if remaining_deps & node1_buf_names:
|
||||
# MemoryDeps didn't match and read different locations of the same buffer.
|
||||
# Examples here include:
|
||||
@ -2695,6 +2718,23 @@ class Scheduler:
|
||||
The first term in our fusion score that estimates number of saved
|
||||
memory operations.
|
||||
"""
|
||||
node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes)
|
||||
node2_dep_len = len(node1.read_writes.reads) + len(node2.read_writes.writes)
|
||||
|
||||
# optimization: iter over smaller set
|
||||
if max(node1_dep_len, node2_dep_len) * 4 > min(node1_dep_len, node2_dep_len):
|
||||
if node1_dep_len > node2_dep_len:
|
||||
tmp = node1
|
||||
node1 = node2
|
||||
node2 = tmp
|
||||
|
||||
deps = []
|
||||
for dep in node1.read_writes.reads | node1.read_writes.writes:
|
||||
if dep in node2.read_writes.reads or dep in node2.read_writes.writes:
|
||||
deps.append(dep)
|
||||
|
||||
return sum(self.dep_size_hint(dep) for dep in deps)
|
||||
|
||||
common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
|
||||
node2.read_writes.reads | node2.read_writes.writes
|
||||
)
|
||||
@ -2746,7 +2786,7 @@ class Scheduler:
|
||||
Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
|
||||
"""
|
||||
|
||||
future_used_buffers = set(V.graph.get_output_names())
|
||||
future_used_buffers: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
||||
|
||||
for node in reversed(self.nodes):
|
||||
node.set_last_usage(future_used_buffers, self.mutation_real_name)
|
||||
@ -2776,11 +2816,11 @@ class Scheduler:
|
||||
same kernel can be removed.
|
||||
"""
|
||||
|
||||
fused_node_names = {
|
||||
fused_node_names = OrderedSet(
|
||||
self.name_to_buf[buf].defining_op.get_name()
|
||||
for buf in V.kernel.store_buffer_names
|
||||
if buf in self.name_to_buf
|
||||
}
|
||||
)
|
||||
names_to_remove = []
|
||||
for out_buf in V.kernel.store_buffer_names:
|
||||
if out_buf not in self.name_to_buf:
|
||||
@ -2789,7 +2829,7 @@ class Scheduler:
|
||||
continue
|
||||
users = self.name_to_buf[out_buf].users
|
||||
assert users is not None
|
||||
users = {user.get_name() for user in users if not user.is_weak}
|
||||
users = OrderedSet(user.get_name() for user in users if not user.is_weak)
|
||||
if users.issubset(fused_node_names):
|
||||
names_to_remove.append(out_buf)
|
||||
|
||||
|
@ -400,7 +400,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
self.body.writeline(f"{output_name} = {out.value}")
|
||||
|
||||
body_val = self.body.getvalue()
|
||||
self.cse.invalidate(set())
|
||||
self.cse.invalidate(set()) # type: ignore[arg-type]
|
||||
return body_val
|
||||
|
||||
def store_output(
|
||||
@ -600,7 +600,7 @@ class TritonTemplate(KernelTemplate):
|
||||
self.all_templates[name] = self
|
||||
self.debug = debug
|
||||
|
||||
def generate(
|
||||
def generate( # type: ignore[override]
|
||||
self,
|
||||
input_nodes,
|
||||
layout,
|
||||
|
Reference in New Issue
Block a user