Previously, merge_loops had to construct a LoopBody twice to make sure the result could use the same symbol prefix as before. This PR changes it to create the LoopBody only once, by allowing the new LoopBody to reuse the same symbol prefix.
It looks like it's OK to have duplicate symbols in a sympy replacement:
```
>>> x, y = sympy.symbols("x y")
>>> (x + y).xreplace({x: 0, y: x + 1})
x + 1
>>> (x + y).xreplace({x: y * y, y: x + 1})
x + y**2 + 1
>>> (x + y + x * x).xreplace({x: 0, y: x})
x
```
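For contrast (just an illustration of the hazard, not code from this PR): a *sequential* substitution such as sympy's `subs` with a list of pairs is not safe here, because an earlier replacement can feed into a later one, while `xreplace` applies the whole mapping in one simultaneous step:
```
>>> (x + y).subs([(y, x + 1), (x, 0)])
1
>>> (x + y).xreplace({y: x + 1, x: 0})
x + 1
```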
UPDATE: added the same optimization for LoopBody.reorder_iter_loops.
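Concretely, merge_loops now builds the merged body in a single construction (excerpt from the file below; previously an intermediate LoopBody with a temporary symbol prefix was needed):
```
new_body = LoopBody(
    old_body,
    [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
    var_ranges,
    iter_vars,
    reduce_vars,
    allow_same_symbol_in_index=True,
)
```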
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162101
Approved by: https://github.com/jansel, https://github.com/eellison
# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import functools
import itertools
import re
from enum import auto, Enum
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, TypeVar

import sympy

import torch.fx
from torch._dynamo.utils import identity
from torch.fx.proxy import Scope, TracerBase
from torch.utils._sympy.symbol import SymT

from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler
from .utils import (
    cache_on_self,
    reduction_num_outputs,
    sympy_index_symbol_with_prefix,
    sympy_subs,
)
from .virtualized import ops, V


if TYPE_CHECKING:
    from collections.abc import Sequence


T = TypeVar("T")


class InterpreterShim(torch.fx.Interpreter):
    @staticmethod
    @functools.cache
    def _dummy_gm():
        return torch.fx.symbolic_trace(identity)

    def __init__(self, graph, submodules):
        # call super() with a placeholder to avoid constructing a
        # GraphModule which is very expensive (it does codegen).
        super().__init__(self._dummy_gm(), garbage_collect_values=False)
        self.module = self  # type: ignore[assignment]
        self.graph = graph
        self.submodules = submodules
        self.extra_traceback = False
        self.fetch_attr = submodules.__getitem__  # type: ignore[method-assign]
        self.current_node = None

    def run_node(self, n: torch.fx.Node) -> Any:
        self.current_node = n
        return super().run_node(n)

    def run(self, *args, **kwargs):
        with V.set_interpreter_handler(self):
            return super().run(*args, **kwargs)
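
# Note: LoopBodyBlock.__call__ (below) creates a fresh InterpreterShim per
# evaluation and runs it against the current ops handler via
# V.get_ops_handler().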

# We don't need the nn.Module and constant handling in Tracer
class LightTracer(TracerBase):
    def __init__(self):
        super().__init__()
        self.graph = torch.fx.Graph(tracer_cls=self.__class__)  # type: ignore[arg-type]
        self.scope = Scope("", None)
        self.module_stack = {}  # type: ignore[assignment]
        self.node_name_to_scope = {}


class MemoryEntry(NamedTuple):
    index_name: str  # LoopBody.indexing_exprs[index_name]
    buffer_name: Optional[str]
    mode: Optional[str]  # V.ops.store(..., mode=mode)


class MemoryUsageType(Enum):
    # These are 1:1 with the opcode generating the usage
    LOAD = auto()
    LOAD_SEED = auto()
    STORE = auto()
    STORE_REDUCTION = auto()
    INDEX_EXPR = auto()
    CHECK_BOUNDS = auto()
    BUCKETIZE = auto()


class LoopBody:
    """
    Captures the body of a Loops subclass into an FX graph. Persists any
    indexing simplifications and makes it easier to analyze loop bodies.
    """

    indexing_exprs: dict[str, sympy.Expr]
    indexing_exprs_name: dict[sympy.Expr, str]
    submodules: dict[str, Any]
    subblocks: dict[str, LoopBodyBlock]
    indirect_vars: list[sympy.Symbol]
    indirect_var_ranges: dict[sympy.Symbol, sympy.Expr]
    root_block: LoopBodyBlock
    memory_usage: dict[MemoryUsageType, list[MemoryEntry]]
    op_counts: collections.Counter[str]

    def __init__(
        self,
        fn,
        args,
        var_ranges,
        iter_vars,
        reduce_vars,
        allow_same_symbol_in_index=False,
    ):
        super().__init__()

        _flat_sizes = tuple(var_ranges.values())
        self.sizes = (
            _flat_sizes[: len(iter_vars)],
            _flat_sizes[len(iter_vars) :],
        )

        self.iter_vars = iter_vars
        self.reduce_vars = reduce_vars
        self.var_ranges = var_ranges

        if isinstance(fn, LoopBody):
            self._init_with_copy(fn, args, allow_same_symbol_in_index)
        else:
            self._init_with_tracing(fn, args)

        self.indexing = None

    def _init_with_tracing(self, fn, args):
        """Do an FX trace of an arbitrary callable to construct self"""
        self.indexing_exprs = {}
        self.indexing_exprs_name = {}
        self.submodules = {"get_index": self.get_index}
        self.subblocks = {}
        self.indirect_vars = []
        self.indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] = {}
        self.memory_usage = {t: [] for t in MemoryUsageType}
        self.op_counts = collections.Counter()
        self.root_block = LoopBodyBlock(self, fn, args)  # traces
        del self.indexing_exprs_name  # not used after _init_with_tracing

    def _init_with_copy(self, other: LoopBody, args, allow_same_symbol_in_index):
        """
        _init_with_tracing() is slow, so this is a fast path in the case
        where we are just reordering/merging/splitting the args of an
        existing LoopBody.
        """
        indexing_exprs = other.indexing_from_args(args, allow_same_symbol_in_index)
        self.indexing_exprs = {
            name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges)
            for name, expr in indexing_exprs.items()
        }
        self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()}
        self.indirect_vars = other.indirect_vars
        self.indirect_var_ranges = other.indirect_var_ranges
        self.memory_usage = other.memory_usage
        self.op_counts = other.op_counts
        self.root_block = other.root_block.clone(self)

        submodules = {**other.submodules}
        submodules.pop("get_index")
        self.submodules = {
            "get_index": self.get_index,
            **{k: v.clone(self) for k, v in submodules.items()},  # type: ignore[attr-defined]
        }

    def has_op(self, name: str):
        return self.op_counts.get(name, 0) > 0
    def merge_loops(self) -> LoopBody:
        """
        Merge both iteration and reduction loops and return a new LoopBody.
        """
        old_body = self
        old_sizes = self.sizes
        old_iter_vars, old_reduce_vars = old_body.vars
        old_iter_sizes, old_reduce_sizes = old_sizes

        index_exprs = [*old_body.indexing_exprs.values()]

        iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops(
            old_iter_vars,
            old_iter_sizes,
            index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes),
        )

        reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops(
            old_reduce_vars,
            old_reduce_sizes,
            index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes),
        )

        if iter_sizes == old_iter_sizes and reduce_sizes == old_reduce_sizes:
            return old_body

        (
            (
                iter_vars,
                reduce_vars,
            ),
            var_ranges,
        ) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="p")
        new_body = LoopBody(
            old_body,
            [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
            var_ranges,
            iter_vars,
            reduce_vars,
            allow_same_symbol_in_index=True,
        )

        return new_body
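
    # merge_loops builds the new LoopBody directly with
    # allow_same_symbol_in_index=True: the fresh "p" symbols may coincide with
    # the old ones, which is safe because indexing_from_args() substitutes all
    # of them simultaneously. (Hypothetical example: sizes ((s0, s1), ()) with
    # purely contiguous indexing would merge to ((s0 * s1,), ()).)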
    def expand_dimension_for_pointwise_node(
        self, dimension: int, new_range: int
    ) -> LoopBody:
        """
        Expand the node along `dimension` to `new_range` and rely on modular
        indexing to avoid out-of-bounds accesses.
        """

        old_body = self
        old_sizes = self.sizes

        iter_size, reduce_size = old_sizes
        original_range = iter_size[dimension]
        new_iter_size = list(iter_size)
        new_iter_size[dimension] = new_range
        new_sizes = (new_iter_size, reduce_size)

        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            *new_sizes,
            prefix="t",  # type: ignore[arg-type]
        )

        def new_body(*indices: Sequence[sympy.Expr]) -> Any:
            index = [*itertools.chain.from_iterable(indices)]
            assert len(index) == len(iter_size) + len(reduce_size)
            iter_idx = index[: len(iter_size)]
            reduce_idx = index[len(iter_size) :]

            new_iter_idx = list(iter_idx)
            new_iter_idx[dimension] = iter_idx[dimension] % original_range

            return old_body(new_iter_idx, reduce_idx)

        loop_body = LoopBody(
            new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars
        )

        # use the original symbol prefix so we can do multiple rounds of reordering
        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
            *new_sizes,
            prefix="p",  # type: ignore[arg-type]
        )
        new_body = LoopBody(
            loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
        )
        return new_body
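
    # expand_dimension_for_pointwise_node example: expanding dimension 0 from
    # range 4 to new_range 6 means the wrapped body sees indices
    # t0 % 4 -> 0, 1, 2, 3, 0, 1, so loads stay within the original bounds
    # while the loop itself iterates over the larger range.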
    def reorder_iter_loops(self, new_order) -> LoopBody:
        """
        Reorder iteration loops and return a new LoopBody.
        """
        from .ir import same_reorder

        old_body = self
        old_sizes = self.sizes
        assert len(old_sizes[0]) == len(new_order)
        reorder_fn = same_reorder(new_order)

        iter_size, reduce_size = old_sizes
        new_iter_size = reorder_fn(iter_size)

        new_sizes = (new_iter_size, reduce_size)

        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            *new_sizes,
            prefix="p",  # type: ignore[arg-type]
        )

        inverse_order = {b: a for a, b in enumerate(new_order)}
        inverse_order = [inverse_order[i] for i in range(len(new_order))]

        def new_body(*indices: Sequence[sympy.Expr]) -> Any:
            index = [*itertools.chain.from_iterable(indices)]
            assert len(index) == len(iter_size) + len(reduce_size)
            iter_idx = index[: len(iter_size)]
            reduce_idx = index[len(iter_size) :]
            iter_idx = [iter_idx[i] for i in inverse_order]
            return old_body(iter_idx, reduce_idx, allow_same_symbol_in_index=True)

        return LoopBody(
            new_body,
            (iter_vars, reduce_vars),
            var_ranges,
            iter_vars,
            reduce_vars,
        )
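
    # reorder_iter_loops example: with new_order == [2, 0, 1],
    # inverse_order == [1, 2, 0] ({b: a for a, b in enumerate([2, 0, 1])}
    # gives {2: 0, 0: 1, 1: 2}), so new_body permutes the incoming indices
    # back into old_body's original loop order. The old_body(...) call above
    # passes allow_same_symbol_in_index=True so the new body can reuse the
    # "p" symbol prefix (the optimization described in the commit message).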
    @property
    def vars(self):
        assert self.iter_vars is not None
        assert self.reduce_vars is not None
        return self.iter_vars, self.reduce_vars

    @cache_on_self
    def get_nodes(self):
        all_graphs = itertools.chain(
            (self.root_block.graph,),
            (block.graph for block in self.subblocks.values()),
        )
        return [node for graph in all_graphs for node in graph.nodes]

    @cache_on_self
    def bounds(self):
        # Doing a local import to avoid dumping all the code here
        from .bounds import BoundVars

        return BoundVars(self)

    def get_read_expr(self, buffer_name):
        # reversed to match old behavior
        for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_write_expr(self, buffer_name):
        for entry in itertools.chain(
            self.memory_usage[MemoryUsageType.STORE],
            self.memory_usage[MemoryUsageType.STORE_REDUCTION],
        ):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_read_exprs(self):
        return [
            self.indexing_exprs[entry.index_name]
            for entry in self.memory_usage[MemoryUsageType.LOAD]
        ]

    def get_all_read_expr(self, buffer_name):
        # reversed to match old behavior
        out = []
        for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]):
            if entry.buffer_name == buffer_name:
                out.append(self.indexing_exprs[entry.index_name])
        return out

    def get_write_exprs(self):
        return [
            self.indexing_exprs[entry.index_name]
            for entry in itertools.chain(
                self.memory_usage[MemoryUsageType.STORE],
                self.memory_usage[MemoryUsageType.STORE_REDUCTION],
            )
        ]

    def get_all_write_expr(self, buffer_name):
        out = []
        for entry in itertools.chain(
            self.memory_usage[MemoryUsageType.STORE],
            self.memory_usage[MemoryUsageType.STORE_REDUCTION],
        ):
            if entry.buffer_name == buffer_name:
                out.append(self.indexing_exprs[entry.index_name])
        return out

    def debug_str(self):
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
        lines.extend(
            [
                block.debug_str(name)
                for name, block in itertools.chain(
                    [("body", self.root_block)], self.subblocks.items()
                )
            ]
        )
        return "\n".join(lines)
    def is_memory_copy(self) -> bool:
        """
        True if this contains only a single load and a single store.
        Note, this could involve a layout change.
        """
        return (
            len(self.memory_usage[MemoryUsageType.LOAD]) == 1
            and len(self.memory_usage[MemoryUsageType.STORE]) == 1
            and len(self.submodules) == 1  # get_index
            and self.root_block.contains_only_ops(("load", "store"))
        )

    __repr__ = debug_str
    def add_index_expr(
        self,
        expr: sympy.Expr,
        mtype: MemoryUsageType,
        buffer_name: Optional[str] = None,
        mode: Optional[str] = None,
    ):
        name = self.indexing_exprs_name.get(expr)
        if not name:
            name = f"index{len(self.indexing_exprs)}"
            self.indexing_exprs_name[expr] = name
            self.indexing_exprs[name] = expr
        self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode))
        return name
    def add_submodule(self, block, prefix):
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
        if prefix[-1].isnumeric() and prefix not in self.submodules:
            name = prefix
        else:
            name = f"{prefix}{len(self.submodules)}"
        self.submodules[name] = block
        return name
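
    # add_submodule naming example: a prefix with no trailing digit, such as
    # "masked_subblock", becomes f"masked_subblock{len(self.submodules)}",
    # while a prefix that already ends in a digit (e.g. "set_indirect0") is
    # used verbatim when that name is still free.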
    def add_indirect(self, size):
        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
        assert var not in self.indirect_var_ranges
        self.indirect_vars.append(var)
        self.indirect_var_ranges[var] = size
        return var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        if str(old) == str(new):
            return
        assert self.indexing is not None
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

    def get_index(self, name):
        assert self.indexing is not None
        return self.indexing[name]
    def indexing_from_args(self, indices, allow_same_symbol_in_index=False):
        index = [*itertools.chain.from_iterable(indices)]
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
        assert allow_same_symbol_in_index or all(
            v not in self.var_ranges for v in index
        ), f"{self.var_ranges=}, {indices=}"

        replacements = dict(zip(self.var_ranges.keys(), index))
        return {
            name: sympy_subs(expr, replacements)
            for name, expr in self.indexing_exprs.items()
        }
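
    # indexing_from_args applies all replacements in one simultaneous pass
    # (see the xreplace examples in the commit message), so a symbol may
    # appear both as a key and inside a replacement value, e.g.
    # replacements == {p0: p1 * 4, p1: p0}. allow_same_symbol_in_index only
    # relaxes the assert above; the substitution itself needs no change.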
    def __call__(self, *indices, allow_same_symbol_in_index=False):
        self.indexing = self.indexing_from_args(indices, allow_same_symbol_in_index)
        result = self.root_block()
        self.indexing = None
        return result

    def bind_set_indirect_shim(self, var, size, check, wrap_neg):
        def set_indirect(new_var):
            self.replace_indirect(
                var, V.ops.indirect_indexing(new_var, size, check, wrap_neg)
            )

        set_indirect.clone = functools.partial(  # type: ignore[attr-defined]
            LoopBody.bind_set_indirect_shim,
            var=var,
            size=size,
            check=check,
            wrap_neg=wrap_neg,
        )
        return set_indirect

    def bind_scan_shim(self, combine_fn):
        def shim(dtypes, values):
            return V.ops.scan(dtypes, combine_fn, values)

        shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn)  # type: ignore[attr-defined]
        return shim

    def bind_masked_shim(self, name):
        def shim(mask, other):
            return V.ops.masked(mask, self.subblocks[name], other)

        shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name)  # type: ignore[attr-defined]
        return shim
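
    # The bind_*_shim helpers above attach a .clone callback so that
    # _init_with_copy() can rebind each submodule shim to the new LoopBody
    # (see the {k: v.clone(self) ...} comprehension there).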


class LoopBodyBlock:
    """
    Captures the body of a Loops subclass into an FX graph.
    In normal cases there will be a 1:1 mapping between LoopBody and
    LoopBodyBlock, however in the case of ops.masked() the masked out
    operations will manifest as an extra LoopBodyBlock.
    """

    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]):
        self.body = body

        tracer = LightTracer()
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

        from .index_propagation import IndexPropagation

        handler: Any = CountOps(
            CaptureIndexing(proxy_ops, body, tracer),
            body.op_counts,
        )
        if config.constant_and_index_propagation:
            handler = IndexPropagation(
                handler, self.body.var_ranges, self.body.indirect_var_ranges
            )

        with V.set_ops_handler(handler):
            # This indirection is just a cute way to get IndexPropagation to
            # unwrap the return value.
            ops.output(fn(*args))
        self.graph = tracer.graph

    def __call__(self):
        graph = self.graph
        submodules = self.body.submodules

        return InterpreterShim(graph, submodules).run(V.get_ops_handler())

    def debug_str(self, name="block"):
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
        return re.sub(
            # strip `; del var0` suffixes to make output prettier
            r";[^\n]*",
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )

    def contains_only_ops(self, allowed_ops) -> bool:
        return all(
            node.target in allowed_ops
            for node in self.graph.find_nodes(op="call_method")
        )

    def clone(self, body: LoopBody):
        """Shallow copy with a new parent LoopBody"""
        copy = LoopBodyBlock.__new__(LoopBodyBlock)
        copy.__dict__.update({**self.__dict__, "body": body})
        return copy


class CountOps(DefaultHandler):
    def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]):
        self._inner = inner
        self._counts = counts

    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        self._counts[name] += 1
        return getattr(self._inner, name)(*args, **kwargs)


class CaptureIndexing(WrapperHandler):
    name = "CaptureIndexing"

    def __init__(
        self,
        inner: OpsHandler[Any],
        body: LoopBody,
        tracer: LightTracer,
    ):
        super().__init__(inner)
        self.body = body
        self.tracer = tracer

    def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any):
        return self.tracer.create_proxy(
            "call_module",
            "get_index",
            (self.body.add_index_expr(expr, mtype, **kwargs),),
            {},
        )

    def _simplify(self, expr: sympy.Expr) -> sympy.Expr:
        return V.graph.sizevars.simplify_with_ranges(expr, self.body.var_ranges)

    def load(self, name: str, index: sympy.Expr):
        index = self._simplify(index)
        index = self._add_index(index, MemoryUsageType.LOAD, buffer_name=name)
        return self._inner.load(name, index)

    def load_seed(self, name: str, index: int):
        assert isinstance(index, int)
        self.body.add_index_expr(
            sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
        )
        return self._inner.load_seed(name, index)

    def store(self, name, index, value, mode=None):
        index = self._simplify(index)
        index = self._add_index(
            index, MemoryUsageType.STORE, buffer_name=name, mode=mode
        )
        return self._inner.store(name, index, value, mode)

    def store_reduction(self, name, index, value):
        index = self._simplify(index)
        index = self._add_index(
            index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
        )
        return self._inner.store_reduction(name, index, value)

    def reduction(self, dtype, src_dtype, reduction_type, value):
        result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
        num_outputs = reduction_num_outputs(reduction_type)
        if num_outputs > 1:
            return tuple(result[i] for i in range(num_outputs))
        return result

    def index_expr(self, index, dtype):
        index = self._simplify(index)
        if isinstance(index, (int, sympy.Integer)):
            return self._inner.constant(int(index), dtype)
        index = self._add_index(index, MemoryUsageType.INDEX_EXPR)
        return self._inner.index_expr(index, dtype)

    def check_bounds(self, index, size, lower, upper):
        index = self._simplify(index)
        index = self._add_index(index, MemoryUsageType.CHECK_BOUNDS)
        size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS)
        return self._inner.check_bounds(index, size, lower, upper)

    def bucketize(
        self,
        values: T,
        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
        boundary_indices: T,
        indexing_dtype: torch.dtype,
        right: bool,
        sorter: Optional[tuple[str, sympy.Expr]] = None,
        sorter_indices: Optional[T] = None,
    ) -> T:
        """
        See [Note: Inductor bucketize op]
        """
        boundaries = (
            boundaries[0],
            self._add_index(
                boundaries[1],
                MemoryUsageType.BUCKETIZE,
                buffer_name=boundaries[0],
            ),
            self._add_index(
                boundaries[2],
                MemoryUsageType.BUCKETIZE,
                buffer_name=boundaries[0],
            ),
            self._add_index(
                boundaries[3],
                MemoryUsageType.BUCKETIZE,
                buffer_name=boundaries[0],
            ),
        )
        if sorter is not None:
            sorter = (
                sorter[0],
                self._add_index(
                    sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0]
                ),
            )

        return self._inner.bucketize(
            values,
            boundaries,
            boundary_indices,
            indexing_dtype,
            right,
            sorter,
            sorter_indices,
        )

    def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy):
        """
        Recursively capture the masked out body in another LoopBodyBlock
        """
        name = self.body.add_submodule(None, "masked_subblock")
        self.body.submodules[name] = self.body.bind_masked_shim(name)
        self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
        return self.tracer.create_proxy(
            "call_module", name, (mask_proxy, other_proxy), {}
        )

    def scan(
        self,
        dtype_proxy,
        combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
        value_proxy,
    ):
        shim = self.body.bind_scan_shim(combine_fn)
        name = self.body.add_submodule(shim, "scan")
        result = self.tracer.create_proxy(
            "call_module",
            name,
            (dtype_proxy, value_proxy),
            {},
        )
        # Proxies are iterable, but some methods expect tuples/lists
        return tuple(result[i] for i in range(len(value_proxy)))

    def sort(self, dtypes, values, stable, descending):
        result = self._inner.sort(dtypes, values, stable, descending)
        # Proxies are iterable, but some methods expect tuples/lists
        return tuple(result[i] for i in range(len(values)))

    def frexp(self, value_proxy):
        result = self._inner.frexp(value_proxy)
        # Proxies are iterable, but some methods expect tuples/lists
        return (result[0], result[1])

    def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True):
        """
        Flow data from tensors into indexing formulas.
        Introduce a call_module to update the indexing.
        """

        var = self.body.add_indirect(size)
        set_indirect = self.body.bind_set_indirect_shim(var, size, check, wrap_neg)
        self.tracer.create_proxy(
            "call_module",
            self.body.add_submodule(set_indirect, f"set_{var}"),
            (index_proxy,),
            {},
        )
        return var

    def output(self, *result):
        self.tracer.create_proxy("output", "output", result, {})