[inductor] Move LoopBody to its own file (#135257)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135257
Approved by: https://github.com/oulgen

Committed by: PyTorch MergeBot
Parent: 18479c5f70
Commit: eac5e12548
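This refactor is mechanical: the InterpreterShim, LoopBody, and LoopBodyBlock classes move out of torch/_inductor/ir.py into a new torch/_inductor/loop_body.py, and every in-tree caller switches its import path, as the hunks below show. For out-of-tree code that reaches into these private classes, the migration is a one-line import change; a minimal sketch (the try/except fallback is our illustration, not part of this commit):

# Hypothetical downstream module that wants LoopBody on either side of this commit.
try:
    # New location, introduced by this commit.
    from torch._inductor.loop_body import InterpreterShim, LoopBody, LoopBodyBlock
except ImportError:
    # Old location, before this commit.
    from torch._inductor.ir import InterpreterShim, LoopBody, LoopBodyBlock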
@@ -9,7 +9,7 @@ from sympy import Expr
 
 import torch
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
 
-from .ir import InterpreterShim, LoopBody, LoopBodyBlock
+from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
 from .utils import cache_on_self, dominated_nodes
 from .virtualized import V
@@ -439,7 +439,7 @@ class DataTypePropagation:
 
     @classmethod
     def propagate_scheduler_node(cls, node):
-        from ..ir import LoopBody
+        from ..loop_body import LoopBody
         from ..scheduler import SchedulerNode
 
         assert isinstance(node, SchedulerNode)
@@ -22,6 +22,7 @@ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 
 from ..._dynamo.utils import counters
 from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics
+from ..loop_body import LoopBody
 from ..scheduler import (
     BaseSchedulerNode,
     BaseScheduling,
@@ -3186,7 +3187,7 @@ class CppTile2DKernel(CppVecKernel):
     )
 
 
-def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Tuple[Optional[torch.dtype], bool]:
+def get_loop_body_lowp_fp(_body: LoopBody) -> Tuple[Optional[torch.dtype], bool]:
     """
    Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes
    and if all the nodes can codegen with this data type without converting to float.
@@ -3471,7 +3472,7 @@ class CppKernelProxy(CppKernel):
 
     # Check if all the nodes of a given fx graph can support BF16/FP16
    def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode):
-        if not isinstance(scheduler_node._body, ir.LoopBody):
+        if not isinstance(scheduler_node._body, LoopBody):
            return True
        # Propagate the dtype to check if all the fx node is bf16/fp16
        DataTypePropagation.propagate_scheduler_node(scheduler_node)
@@ -3480,7 +3481,7 @@ class CppKernelProxy(CppKernel):
             and not get_loop_body_lowp_fp(scheduler_node._body)[1]
         )
 
-    def legalize_lowp_fp_dtype_loopbody(self, loop_body: ir.LoopBody):
+    def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody):
         def add_to_dtype(sub_graph: torch.fx.Graph):
             def is_lowp_fp_load(node: torch.fx.Node):
                 if node.target not in ["load"]:
@@ -3647,7 +3648,7 @@ class CppKernelProxy(CppKernel):
 
         for _node in nodes:
             assert isinstance(_node, SchedulerNode)
-            assert isinstance(_node._body, ir.LoopBody)
+            assert isinstance(_node._body, LoopBody)
             node: SchedulerNode = _node
 
         def is_memory_copy_scheduler_node(node: SchedulerNode):
@@ -3658,7 +3659,7 @@ class CppKernelProxy(CppKernel):
 
             should_legalize = not is_memory_copy_scheduler_node(node)
             if should_legalize:
-                body: ir.LoopBody = node._body
+                body: LoopBody = node._body
                 self.legalize_lowp_fp_dtype_loopbody(body)
 
     def codegen_functions(self, fn_list, var_sizes_list):
@@ -10,6 +10,7 @@ from torch.utils._sympy.symbol import SymT
 
 from .. import config, cpp_builder, ir, lowering as L
 from ..autotune_process import CppBenchmarkRequest
+from ..loop_body import LoopBody
 from ..select_algorithm import PartialRender
 from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
 from ..virtualized import V
@@ -239,7 +240,7 @@ class CppTemplateKernel(CppKernel):
             node.make_loader()(new_args).value,
         )
 
-        body = ir.LoopBody(
+        body = LoopBody(
            fn,
            (list(var_ranges.keys()), ()),
            var_ranges,
||||
@@ -16,6 +16,7 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import ValueRanges
 
 from .. import ir
+from ..loop_body import LoopBody
 from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
 from ..virtualized import ops, OpsValue, V
 from .common import (
@@ -883,20 +884,19 @@ def create_epilogue_with_attr(input_buffer, attr, **kwargs):
 
 
 def _get_loop_body(fn_list):
     loop_bodies = None
-    if all(isinstance(fn, ir.LoopBody) for fn in fn_list):
+    if all(isinstance(fn, LoopBody) for fn in fn_list):
         loop_bodies = fn_list
     else:
         if hasattr(fn_list[0], "original_fn"):
             # For the case of local buffer, we wrap the fn with localize_function
             assert all(hasattr(fn, "original_fn") for fn in fn_list)
             assert all(
-                isinstance(fn.original_fn.args[0]._body, ir.LoopBody) for fn in fn_list
+                isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list
             )
             loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list]
         else:
             assert all(isinstance(fn, functools.partial) for fn in fn_list)
-            assert all(isinstance(fn.args[0]._body, ir.LoopBody) for fn in fn_list)
+            assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list)
             loop_bodies = [fn.args[0]._body for fn in fn_list]
     assert loop_bodies is not None
     return loop_bodies
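The _get_loop_body hunk above unwraps three shapes of entries: bare LoopBody objects, localize_function wrappers exposing original_fn, and functools.partial objects whose first bound argument carries ._body. A minimal sketch of the partial case, with hypothetical stand-in names:

import functools

class FakeSchedulerNode:
    # Hypothetical stand-in for an object exposing a LoopBody via ._body.
    def __init__(self, body):
        self._body = body

def run_node(node, *indices):
    pass  # placeholder body

loop_body = "stand-in LoopBody"
fn = functools.partial(run_node, FakeSchedulerNode(loop_body))
# This is the shape _get_loop_body unwraps via fn.args[0]._body:
assert fn.args[0]._body == loop_body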
@@ -6,7 +6,6 @@ import dataclasses
 import functools
 import itertools
 import logging
-import re
 import textwrap
 import traceback
 from contextlib import nullcontext
@@ -72,6 +71,7 @@ from .dependencies import (
     extract_read_writes,
     var_builder,
 )
+from .loop_body import LoopBody
 from .ops_handler import OpCounterCSE
 from .runtime.benchmarking import benchmarker
 from .runtime.hints import ReductionHint
@@ -6770,505 +6770,6 @@ class TorchBindObject(IRNode):
         return self.name
 
 
-class InterpreterShim(torch.fx.Interpreter):
-    @staticmethod
-    @functools.lru_cache(None)
-    def _dummy_gm():
-        return torch.fx.symbolic_trace(identity)
-
-    def __init__(self, graph, submodules):
-        # call super() with a placeholder to avoid constructing a
-        # GraphModule which is very expensive (it does codegen).
-        super().__init__(self._dummy_gm(), garbage_collect_values=False)
-        self.module = self  # type: ignore[assignment]
-        self.graph = graph
-        self.submodules = submodules
-        self.extra_traceback = False
-        self.fetch_attr = submodules.__getitem__  # type: ignore[method-assign]
-        self.current_node = None
-
-    def run_node(self, n: torch.fx.Node) -> Any:
-        self.current_node = n
-        return super().run_node(n)
-
-    def run(self, *args, **kwargs):
-        with V.set_interpreter_handler(self):
-            return super().run(*args, **kwargs)
-
-
-class LoopBody:
-    """
-    Captures the body of a Loops subclass into an FX graph. Persists any
-    indexing simplifications and makes it easier to analyze loop bodies.
-    """
-
-    indexing_exprs: Dict[str, sympy.Expr]
-    indexing_exprs_name: Dict[sympy.Expr, str]
-    reads_name2expr: Dict[str, sympy.Expr]
-    writes_name2expr: Dict[str, sympy.Expr]
-    submodules: Dict[str, Any]
-    subblocks: Dict[str, LoopBodyBlock]
-    indirect_vars: List[str]
-    indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr]
-    root_block: LoopBodyBlock
-
-    def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars):
-        super().__init__()
-
-        _flat_sizes = tuple(var_ranges.values())
-        self.sizes = (
-            _flat_sizes[: len(iter_vars)],
-            _flat_sizes[len(iter_vars) :],
-        )
-
-        self.iter_vars = iter_vars
-        self.reduce_vars = reduce_vars
-        self.var_ranges = var_ranges
-
-        if isinstance(fn, LoopBody):
-            self._init_with_copy(fn, args)
-        else:
-            self._init_with_tracing(fn, args)
-
-        self.indexing = None
-
-    def _init_with_tracing(self, fn, args):
-        """Do an FX trace of an arbitrary callable to construct self"""
-        self.indexing_exprs = {}
-        self.indexing_exprs_name = {}
-        self.reads_name2expr = {}
-        self.writes_name2expr = {}
-        self.submodules = {"get_index": self.get_index}
-        self.subblocks = {}
-        self.indirect_vars = []
-        self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
-        self.root_block = LoopBodyBlock(self, fn, args)  # traces
-        del self.indexing_exprs_name  # not used after _init_with_tracing
-
-    def _init_with_copy(self, other: LoopBody, args):
-        """
-        _init_with_tracing() is slow, so this is a fast path in the case
-        where we are just reordering/merging/splitting the args of an
-        existing LoopBody.
-        """
-        indexing_exprs = other.indexing_from_args(args)
-        update_index = {
-            expr: indexing_exprs[name] for name, expr in other.indexing_exprs.items()
-        }
-        assert indexing_exprs.keys() == other.indexing_exprs.keys() and len(
-            update_index
-        ) == len(indexing_exprs)
-
-        self.indexing_exprs = indexing_exprs
-        self.reads_name2expr = {
-            k: update_index[v] for k, v in other.reads_name2expr.items()
-        }
-        self.writes_name2expr = {
-            k: update_index[v] for k, v in other.writes_name2expr.items()
-        }
-        self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()}
-        self.indirect_vars = [*other.indirect_vars]
-        self.indirect_var_ranges = {**other.indirect_var_ranges}
-        self.root_block = other.root_block.clone(self)
-
-        submodules = {**other.submodules}
-        submodules.pop("get_index")
-        self.submodules = {
-            "get_index": self.get_index,
-            **{k: v.clone(self) for k, v in submodules.items()},  # type: ignore[attr-defined]
-        }
-
-    def merge_loops(self) -> LoopBody:
-        """
-        Merge both iteration and reduction loops and return a new LoopBody.
-        """
-        old_body = self
-        old_sizes = self.sizes
-        old_iter_vars, old_reduce_vars = old_body.vars
-        old_iter_sizes, old_reduce_sizes = old_sizes
-
-        index_exprs = [*old_body.indexing_exprs.values()]
-
-        iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops(
-            old_iter_vars,
-            old_iter_sizes,
-            index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes),
-        )
-
-        reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops(
-            old_reduce_vars,
-            old_reduce_sizes,
-            index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes),
-        )
-
-        # if iter_sizes == old_iter_sizes:
-        #     # no dimensions get merged.
-        #     return old_sizes, old_body
-
-        # Note: if no dimension get merges, the symbol prefix will
-        # remain 'y'. But if we merge dimensions, we change prefix to
-        # 'z'. If this is an issue, we can always retrace the LoopBody
-        # to change symbol prefix to 'z'.
-        #
-        # There is indeed an issue due to symbol name conflicting.
-        # y0 maybe reused for the y dimension later.
-        (
-            iter_vars,
-            reduce_vars,
-        ), var_ranges = dependencies.index_vars_no_squeeze(
-            iter_sizes, reduce_sizes, prefix="t"
-        )
-        new_body = LoopBody(
-            old_body,
-            [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
-            var_ranges,
-            iter_vars,
-            reduce_vars,
-        )
-
-        # use the original symbol prefix
-        # Can try to optimize if this is a bottleneck for compilation time
-        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
-            iter_sizes, reduce_sizes, prefix="z"
-        )
-        new_body2 = LoopBody(
-            new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
-        )
-        return new_body2
-
-    def reorder_iter_loops(self, new_order) -> LoopBody:
-        """
-        Reorder iteration loops and return a new LoopBody.
-        """
-        old_body = self
-        old_sizes = self.sizes
-        assert len(old_sizes[0]) == len(new_order)
-        reorder_fn = same_reorder(new_order)
-
-        iter_size, reduce_size = old_sizes
-        new_iter_size = reorder_fn(iter_size)
-
-        new_sizes = (new_iter_size, reduce_size)
-
-        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
-            *new_sizes, prefix="t"  # type: ignore[arg-type]
-        )
-
-        inverse_order = {b: a for a, b in enumerate(new_order)}
-        inverse_order = [inverse_order[i] for i in range(len(new_order))]
-
-        def new_body(*indices: Sequence[sympy.Expr]) -> Any:
-            index = list(itertools.chain(*indices))
-            assert len(index) == len(iter_size) + len(reduce_size)
-            iter_idx = index[: len(iter_size)]
-            reduce_idx = index[len(iter_size) :]
-            iter_idx = [iter_idx[i] for i in inverse_order]
-            return old_body(iter_idx, reduce_idx)
-
-        loop_body = LoopBody(
-            new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars
-        )
-
-        # use the original symbol prefix so we can do multiple round of reordering
-        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
-            *new_sizes, prefix="z"  # type: ignore[arg-type]
-        )
-        new_body = LoopBody(
-            loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
-        )
-        return new_body
-
-    @property
-    def vars(self):
-        assert self.iter_vars is not None
-        assert self.reduce_vars is not None
-        return self.iter_vars, self.reduce_vars
-
-    @cache_on_self
-    def get_nodes(self):
-        all_graphs = itertools.chain(
-            (self.root_block.graph,),
-            (block.graph for block in self.subblocks.values()),
-        )
-        return [node for graph in all_graphs for node in graph.nodes]
-
-    @cache_on_self
-    def bounds(self):
-        # Doing a local import to avoid dumping all the code here
-        from .bounds import BoundVars
-
-        return BoundVars(self)
-
-    def debug_str(self):
-        lines = [f"var_ranges = {dict(self.var_ranges)}"]
-        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
-        lines.extend(
-            [
-                block.debug_str(name)
-                for name, block in itertools.chain(
-                    [("body", self.root_block)], self.subblocks.items()
-                )
-            ]
-        )
-        return "\n".join(lines)
-
-    __repr__ = debug_str
-
-    def add_index_expr(self, expr: sympy.Expr, category, buf_name):
-        if buf_name is not None:
-            getattr(self, f"{category}_name2expr")[buf_name] = expr
-        if expr not in self.indexing_exprs_name:
-            name = f"index{len(self.indexing_exprs)}"
-            self.indexing_exprs_name[expr] = name
-            self.indexing_exprs[name] = expr
-        return self.indexing_exprs_name[expr]
-
-    def add_submodule(self, block, prefix):
-        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
-        if prefix[-1].isnumeric() and prefix not in self.submodules:
-            name = prefix
-        else:
-            name = f"{prefix}{len(self.submodules)}"
-        self.submodules[name] = block
-        return name
-
-    def add_indirect(self, size):
-        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
-        assert var not in self.indirect_var_ranges
-        self.indirect_vars.append(var)
-        self.indirect_var_ranges[var] = size
-        return var
-
-    def replace_indirect(self, old, new):
-        """Swap in a variable used in indirect indexing"""
-        if str(old) == str(new):
-            return
-        assert self.indexing is not None
-        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
-
-    def get_index(self, name):
-        assert self.indexing is not None
-        return self.indexing[name]
-
-    def indexing_from_args(self, indices):
-        index = [*itertools.chain.from_iterable(indices)]
-        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
-        assert all(
-            v not in self.var_ranges for v in index
-        ), f"{self.var_ranges=}, {indices=}"
-        replacements = dict(zip(self.var_ranges.keys(), index))
-        return {
-            name: sympy_subs(expr, replacements)
-            for name, expr in self.indexing_exprs.items()
-        }
-
-    def __call__(self, *indices):
-        self.indexing = self.indexing_from_args(indices)
-        result = self.root_block()
-        self.indexing = None
-        return result
-
-    def bind_set_indirect_shim(self, var, size, check, wrap_neg):
-        def set_indirect(new_var):
-            self.replace_indirect(
-                var, V.ops.indirect_indexing(new_var, size, check, wrap_neg)
-            )
-
-        set_indirect.clone = functools.partial(  # type: ignore[attr-defined]
-            LoopBody.bind_set_indirect_shim,
-            var=var,
-            size=size,
-            check=check,
-            wrap_neg=wrap_neg,
-        )
-        return set_indirect
-
-    def bind_scan_shim(self, combine_fn):
-        def shim(dtypes, values):
-            return V.ops.scan(dtypes, combine_fn, values)
-
-        shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn)  # type: ignore[attr-defined]
-        return shim
-
-    def bind_masked_shim(self, name):
-        def shim(mask, other):
-            return V.ops.masked(mask, self.subblocks[name], other)
-
-        shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name)  # type: ignore[attr-defined]
-        return shim
-
-
-class LoopBodyBlock:
-    """
-    Captures the body of a Loops subclass into an FX graph.
-    In normal cases there will be a 1:1 mapping between LoopBody and
-    LoopBodyBlock, hower in the case of ops.masked() the masked out
-    operations will manifest as an extra LoopBodyBlock.
-    """
-
-    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
-        self.body = body
-
-        def add_index(expr, category, buf_name=None):
-            return tracer.create_proxy(
-                "call_module",
-                "get_index",
-                (self.body.add_index_expr(expr, category, buf_name),),
-                {},
-            )
-
-        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
-            self.name = "CaptureIndexing"
-
-            def load(self, name: str, index: sympy.Expr):
-                index = add_index(index, "reads", name)
-                return self._inner.load(name, index)
-
-            def store(self, name, index, value, mode=None):
-                index = add_index(index, "writes", name)
-                return self._inner.store(name, index, value, mode)
-
-            def store_reduction(self, name, index, value):
-                index = add_index(index, "writes", name)
-                return self._inner.store_reduction(name, index, value)
-
-            def reduction(self, dtype, src_dtype, reduction_type, value):
-                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
-                if "welford" in reduction_type:
-                    return tuple(result[i] for i in range(3))
-                return result
-
-            def index_expr(self, index, dtype):
-                if isinstance(index, (int, sympy.Integer)):
-                    return self._inner.constant(int(index), dtype)
-                index = add_index(index, "other")
-                return self._inner.index_expr(index, dtype)
-
-            def check_bounds(self, index, size, lower, upper):
-                index = add_index(index, "other")
-                size = add_index(size, "other")
-                return self._inner.check_bounds(index, size, lower, upper)
-
-            def bucketize(
-                self,
-                values,
-                offsets_name: str,
-                offsets_size: sympy.Expr,
-                indexing_dtype: torch.dtype,
-                right: bool,
-            ):
-                offsets_size = add_index(offsets_size, "other")
-                return self._inner.bucketize(
-                    values, offsets_name, offsets_size, indexing_dtype, right
-                )
-
-            @staticmethod
-            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
-                """
-                Recursively capture the masked out body in another LoopBodyBlock
-                """
-                name = self.body.add_submodule(None, "masked_subblock")
-                self.body.submodules[name] = self.body.bind_masked_shim(name)
-                self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
-                return tracer.create_proxy(
-                    "call_module", name, (mask_proxy, other_proxy), {}
-                )
-
-            @staticmethod
-            def scan(
-                dtype_proxy,
-                combine_fn: Callable[
-                    [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]
-                ],
-                value_proxy,
-            ):
-                shim = self.body.bind_scan_shim(combine_fn)
-                name = self.body.add_submodule(shim, "scan")
-                result = tracer.create_proxy(
-                    "call_module",
-                    name,
-                    (dtype_proxy, value_proxy),
-                    {},
-                )
-                # Proxies are iterable, but some methods expect tuples/lists
-                return tuple(result[i] for i in range(len(value_proxy)))
-
-            def sort(self, dtypes, values, stable, descending):
-                result = self._inner.sort(dtypes, values, stable, descending)
-                # Proxies are iterable, but some methods expect tuples/lists
-                return tuple(result[i] for i in range(len(values)))
-
-            def frexp(self, value_proxy):
-                result = self._inner.frexp(value_proxy)
-                # Proxies are iterable, but some methods expect tuples/lists
-                return (result[0], result[1])
-
-            @staticmethod
-            def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
-                """
-                Flow data from tensors into indexing formulas.
-                Introduce a call_module to update the indexing.
-                """
-
-                var = self.body.add_indirect(size)
-                set_indirect = self.body.bind_set_indirect_shim(
-                    var, size, check, wrap_neg
-                )
-                tracer.create_proxy(
-                    "call_module",
-                    self.body.add_submodule(set_indirect, f"set_{var}"),
-                    (index_proxy,),
-                    {},
-                )
-                return var
-
-            @staticmethod
-            def output(result):
-                tracer.create_proxy("output", "output", (result,), {})
-
-        tracer = torch.fx.Tracer()
-        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
-        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
-
-        from .index_propagation import IndexPropagation
-        from .sizevars import SimplifyIndexing
-
-        handler: Any = SimplifyIndexing(
-            CaptureIndexing(proxy_ops), self.body.var_ranges
-        )
-        if config.constant_and_index_propagation:
-            handler = IndexPropagation(
-                handler, self.body.var_ranges, self.body.indirect_var_ranges
-            )
-
-        with V.set_ops_handler(handler):
-            # This indirection is just a cute way to get IndexPropagation to
-            # unwrap the return value.
-            ops.output(fn(*args))
-        self.graph = tracer.graph
-
-    def __call__(self):
-        graph = self.graph
-        submodules = self.body.submodules
-
-        return InterpreterShim(graph, submodules).run(V.get_ops_handler())
-
-    def debug_str(self, name="block"):
-        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
-        return re.sub(
-            # strip `; del var0` suffixes to make output prettier
-            r";[^\n]*",
-            "",
-            code.strip().replace("def forward(", f"def {name}("),
-        )
-
-    def clone(self, body: LoopBody):
-        """Shallow copy with a new parent LoopBody"""
-        copy = LoopBodyBlock.__new__(LoopBodyBlock)
-        copy.__dict__.update({**self.__dict__, "body": body})
-        return copy
 
 
 class _CollectiveKernel(FallbackKernel):
     def should_allocate(self):
         return False
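Everything deleted above reappears unchanged in the new module below; only the import header and one local import differ. The InterpreterShim trick is worth noting: torch.fx.Interpreter expects a GraphModule, whose construction runs Python codegen, so the shim passes a cached dummy module and then points self.graph and fetch_attr at the real graph and a plain dict. A minimal standalone sketch of using it after this commit (the "double" submodule is a made-up example):

import torch
import torch.fx
from torch._inductor.loop_body import InterpreterShim

# Hand-build a tiny FX graph: one placeholder, one call_module, one output.
graph = torch.fx.Graph()
x = graph.placeholder("x")
y = graph.call_module("double", (x,))
graph.output(y)

# No GraphModule is ever constructed; submodules are just a dict.
shim = InterpreterShim(graph, {"double": lambda t: t * 2})
print(shim.run(torch.tensor([1.0, 2.0])))  # tensor([2., 4.])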
torch/_inductor/loop_body.py (new file, 519 lines)
@@ -0,0 +1,519 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import functools
+import itertools
+import re
+from typing import Any, Callable, Dict, List, Sequence, Tuple
+
+import sympy
+
+import torch.fx
+from torch._dynamo.utils import identity
+from torch.utils._sympy.symbol import SymT
+
+from . import config, dependencies
+from .codegen.common import index_prevent_reordering
+from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
+from .virtualized import ops, V
+
+
+class InterpreterShim(torch.fx.Interpreter):
+    @staticmethod
+    @functools.lru_cache(None)
+    def _dummy_gm():
+        return torch.fx.symbolic_trace(identity)
+
+    def __init__(self, graph, submodules):
+        # call super() with a placeholder to avoid constructing a
+        # GraphModule which is very expensive (it does codegen).
+        super().__init__(self._dummy_gm(), garbage_collect_values=False)
+        self.module = self  # type: ignore[assignment]
+        self.graph = graph
+        self.submodules = submodules
+        self.extra_traceback = False
+        self.fetch_attr = submodules.__getitem__  # type: ignore[method-assign]
+        self.current_node = None
+
+    def run_node(self, n: torch.fx.Node) -> Any:
+        self.current_node = n
+        return super().run_node(n)
+
+    def run(self, *args, **kwargs):
+        with V.set_interpreter_handler(self):
+            return super().run(*args, **kwargs)
+
+
+class LoopBody:
+    """
+    Captures the body of a Loops subclass into an FX graph. Persists any
+    indexing simplifications and makes it easier to analyze loop bodies.
+    """
+
+    indexing_exprs: Dict[str, sympy.Expr]
+    indexing_exprs_name: Dict[sympy.Expr, str]
+    reads_name2expr: Dict[str, sympy.Expr]
+    writes_name2expr: Dict[str, sympy.Expr]
+    submodules: Dict[str, Any]
+    subblocks: Dict[str, LoopBodyBlock]
+    indirect_vars: List[str]
+    indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr]
+    root_block: LoopBodyBlock
+
+    def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars):
+        super().__init__()
+
+        _flat_sizes = tuple(var_ranges.values())
+        self.sizes = (
+            _flat_sizes[: len(iter_vars)],
+            _flat_sizes[len(iter_vars) :],
+        )
+
+        self.iter_vars = iter_vars
+        self.reduce_vars = reduce_vars
+        self.var_ranges = var_ranges
+
+        if isinstance(fn, LoopBody):
+            self._init_with_copy(fn, args)
+        else:
+            self._init_with_tracing(fn, args)
+
+        self.indexing = None
+
+    def _init_with_tracing(self, fn, args):
+        """Do an FX trace of an arbitrary callable to construct self"""
+        self.indexing_exprs = {}
+        self.indexing_exprs_name = {}
+        self.reads_name2expr = {}
+        self.writes_name2expr = {}
+        self.submodules = {"get_index": self.get_index}
+        self.subblocks = {}
+        self.indirect_vars = []
+        self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
+        self.root_block = LoopBodyBlock(self, fn, args)  # traces
+        del self.indexing_exprs_name  # not used after _init_with_tracing
+
+    def _init_with_copy(self, other: LoopBody, args):
+        """
+        _init_with_tracing() is slow, so this is a fast path in the case
+        where we are just reordering/merging/splitting the args of an
+        existing LoopBody.
+        """
+        indexing_exprs = other.indexing_from_args(args)
+        update_index = {
+            expr: indexing_exprs[name] for name, expr in other.indexing_exprs.items()
+        }
+        assert indexing_exprs.keys() == other.indexing_exprs.keys() and len(
+            update_index
+        ) == len(indexing_exprs)
+
+        self.indexing_exprs = indexing_exprs
+        self.reads_name2expr = {
+            k: update_index[v] for k, v in other.reads_name2expr.items()
+        }
+        self.writes_name2expr = {
+            k: update_index[v] for k, v in other.writes_name2expr.items()
+        }
+        self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()}
+        self.indirect_vars = [*other.indirect_vars]
+        self.indirect_var_ranges = {**other.indirect_var_ranges}
+        self.root_block = other.root_block.clone(self)
+
+        submodules = {**other.submodules}
+        submodules.pop("get_index")
+        self.submodules = {
+            "get_index": self.get_index,
+            **{k: v.clone(self) for k, v in submodules.items()},  # type: ignore[attr-defined]
+        }
+
+    def merge_loops(self) -> LoopBody:
+        """
+        Merge both iteration and reduction loops and return a new LoopBody.
+        """
+        old_body = self
+        old_sizes = self.sizes
+        old_iter_vars, old_reduce_vars = old_body.vars
+        old_iter_sizes, old_reduce_sizes = old_sizes
+
+        index_exprs = [*old_body.indexing_exprs.values()]
+
+        iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops(
+            old_iter_vars,
+            old_iter_sizes,
+            index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes),
+        )
+
+        reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops(
+            old_reduce_vars,
+            old_reduce_sizes,
+            index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes),
+        )
+
+        # if iter_sizes == old_iter_sizes:
+        #     # no dimensions get merged.
+        #     return old_sizes, old_body
+
+        # Note: if no dimension get merges, the symbol prefix will
+        # remain 'y'. But if we merge dimensions, we change prefix to
+        # 'z'. If this is an issue, we can always retrace the LoopBody
+        # to change symbol prefix to 'z'.
+        #
+        # There is indeed an issue due to symbol name conflicting.
+        # y0 maybe reused for the y dimension later.
+        (
+            iter_vars,
+            reduce_vars,
+        ), var_ranges = dependencies.index_vars_no_squeeze(
+            iter_sizes, reduce_sizes, prefix="t"
+        )
+        new_body = LoopBody(
+            old_body,
+            [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
+            var_ranges,
+            iter_vars,
+            reduce_vars,
+        )
+
+        # use the original symbol prefix
+        # Can try to optimize if this is a bottleneck for compilation time
+        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
+            iter_sizes, reduce_sizes, prefix="z"
+        )
+        new_body2 = LoopBody(
+            new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
+        )
+        return new_body2
+
+    def reorder_iter_loops(self, new_order) -> LoopBody:
+        """
+        Reorder iteration loops and return a new LoopBody.
+        """
+        from .ir import same_reorder
+
+        old_body = self
+        old_sizes = self.sizes
+        assert len(old_sizes[0]) == len(new_order)
+        reorder_fn = same_reorder(new_order)
+
+        iter_size, reduce_size = old_sizes
+        new_iter_size = reorder_fn(iter_size)
+
+        new_sizes = (new_iter_size, reduce_size)
+
+        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
+            *new_sizes, prefix="t"  # type: ignore[arg-type]
+        )
+
+        inverse_order = {b: a for a, b in enumerate(new_order)}
+        inverse_order = [inverse_order[i] for i in range(len(new_order))]
+
+        def new_body(*indices: Sequence[sympy.Expr]) -> Any:
+            index = list(itertools.chain(*indices))
+            assert len(index) == len(iter_size) + len(reduce_size)
+            iter_idx = index[: len(iter_size)]
+            reduce_idx = index[len(iter_size) :]
+            iter_idx = [iter_idx[i] for i in inverse_order]
+            return old_body(iter_idx, reduce_idx)
+
+        loop_body = LoopBody(
+            new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars
+        )
+
+        # use the original symbol prefix so we can do multiple round of reordering
+        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
+            *new_sizes, prefix="z"  # type: ignore[arg-type]
+        )
+        new_body = LoopBody(
+            loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
+        )
+        return new_body
+
+    @property
+    def vars(self):
+        assert self.iter_vars is not None
+        assert self.reduce_vars is not None
+        return self.iter_vars, self.reduce_vars
+
+    @cache_on_self
+    def get_nodes(self):
+        all_graphs = itertools.chain(
+            (self.root_block.graph,),
+            (block.graph for block in self.subblocks.values()),
+        )
+        return [node for graph in all_graphs for node in graph.nodes]
+
+    @cache_on_self
+    def bounds(self):
+        # Doing a local import to avoid dumping all the code here
+        from .bounds import BoundVars
+
+        return BoundVars(self)
+
+    def debug_str(self):
+        lines = [f"var_ranges = {dict(self.var_ranges)}"]
+        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
+        lines.extend(
+            [
+                block.debug_str(name)
+                for name, block in itertools.chain(
+                    [("body", self.root_block)], self.subblocks.items()
+                )
+            ]
+        )
+        return "\n".join(lines)
+
+    __repr__ = debug_str
+
+    def add_index_expr(self, expr: sympy.Expr, category, buf_name):
+        if buf_name is not None:
+            getattr(self, f"{category}_name2expr")[buf_name] = expr
+        if expr not in self.indexing_exprs_name:
+            name = f"index{len(self.indexing_exprs)}"
+            self.indexing_exprs_name[expr] = name
+            self.indexing_exprs[name] = expr
+        return self.indexing_exprs_name[expr]
+
+    def add_submodule(self, block, prefix):
+        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
+        if prefix[-1].isnumeric() and prefix not in self.submodules:
+            name = prefix
+        else:
+            name = f"{prefix}{len(self.submodules)}"
+        self.submodules[name] = block
+        return name
+
+    def add_indirect(self, size):
+        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
+        assert var not in self.indirect_var_ranges
+        self.indirect_vars.append(var)
+        self.indirect_var_ranges[var] = size
+        return var
+
+    def replace_indirect(self, old, new):
+        """Swap in a variable used in indirect indexing"""
+        if str(old) == str(new):
+            return
+        assert self.indexing is not None
+        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
+
+    def get_index(self, name):
+        assert self.indexing is not None
+        return self.indexing[name]
+
+    def indexing_from_args(self, indices):
+        index = [*itertools.chain.from_iterable(indices)]
+        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
+        assert all(
+            v not in self.var_ranges for v in index
+        ), f"{self.var_ranges=}, {indices=}"
+        replacements = dict(zip(self.var_ranges.keys(), index))
+        return {
+            name: sympy_subs(expr, replacements)
+            for name, expr in self.indexing_exprs.items()
+        }
+
+    def __call__(self, *indices):
+        self.indexing = self.indexing_from_args(indices)
+        result = self.root_block()
+        self.indexing = None
+        return result
+
+    def bind_set_indirect_shim(self, var, size, check, wrap_neg):
+        def set_indirect(new_var):
+            self.replace_indirect(
+                var, V.ops.indirect_indexing(new_var, size, check, wrap_neg)
+            )
+
+        set_indirect.clone = functools.partial(  # type: ignore[attr-defined]
+            LoopBody.bind_set_indirect_shim,
+            var=var,
+            size=size,
+            check=check,
+            wrap_neg=wrap_neg,
+        )
+        return set_indirect
+
+    def bind_scan_shim(self, combine_fn):
+        def shim(dtypes, values):
+            return V.ops.scan(dtypes, combine_fn, values)
+
+        shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn)  # type: ignore[attr-defined]
+        return shim
+
+    def bind_masked_shim(self, name):
+        def shim(mask, other):
+            return V.ops.masked(mask, self.subblocks[name], other)
+
+        shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name)  # type: ignore[attr-defined]
+        return shim
+
+
+class LoopBodyBlock:
+    """
+    Captures the body of a Loops subclass into an FX graph.
+    In normal cases there will be a 1:1 mapping between LoopBody and
+    LoopBodyBlock, hower in the case of ops.masked() the masked out
+    operations will manifest as an extra LoopBodyBlock.
+    """
+
+    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
+        self.body = body
+
+        def add_index(expr, category, buf_name=None):
+            return tracer.create_proxy(
+                "call_module",
+                "get_index",
+                (self.body.add_index_expr(expr, category, buf_name),),
+                {},
+            )
+
+        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
+            self.name = "CaptureIndexing"
+
+            def load(self, name: str, index: sympy.Expr):
+                index = add_index(index, "reads", name)
+                return self._inner.load(name, index)
+
+            def store(self, name, index, value, mode=None):
+                index = add_index(index, "writes", name)
+                return self._inner.store(name, index, value, mode)
+
+            def store_reduction(self, name, index, value):
+                index = add_index(index, "writes", name)
+                return self._inner.store_reduction(name, index, value)
+
+            def reduction(self, dtype, src_dtype, reduction_type, value):
+                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
+                if "welford" in reduction_type:
+                    return tuple(result[i] for i in range(3))
+                return result
+
+            def index_expr(self, index, dtype):
+                if isinstance(index, (int, sympy.Integer)):
+                    return self._inner.constant(int(index), dtype)
+                index = add_index(index, "other")
+                return self._inner.index_expr(index, dtype)
+
+            def check_bounds(self, index, size, lower, upper):
+                index = add_index(index, "other")
+                size = add_index(size, "other")
+                return self._inner.check_bounds(index, size, lower, upper)
+
+            def bucketize(
+                self,
+                values,
+                offsets_name: str,
+                offsets_size: sympy.Expr,
+                indexing_dtype: torch.dtype,
+                right: bool,
+            ):
+                offsets_size = add_index(offsets_size, "other")
+                return self._inner.bucketize(
+                    values, offsets_name, offsets_size, indexing_dtype, right
+                )
+
+            @staticmethod
+            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
+                """
+                Recursively capture the masked out body in another LoopBodyBlock
+                """
+                name = self.body.add_submodule(None, "masked_subblock")
+                self.body.submodules[name] = self.body.bind_masked_shim(name)
+                self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
+                return tracer.create_proxy(
+                    "call_module", name, (mask_proxy, other_proxy), {}
+                )
+
+            @staticmethod
+            def scan(
+                dtype_proxy,
+                combine_fn: Callable[
+                    [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]
+                ],
+                value_proxy,
+            ):
+                shim = self.body.bind_scan_shim(combine_fn)
+                name = self.body.add_submodule(shim, "scan")
+                result = tracer.create_proxy(
+                    "call_module",
+                    name,
+                    (dtype_proxy, value_proxy),
+                    {},
+                )
+                # Proxies are iterable, but some methods expect tuples/lists
+                return tuple(result[i] for i in range(len(value_proxy)))
+
+            def sort(self, dtypes, values, stable, descending):
+                result = self._inner.sort(dtypes, values, stable, descending)
+                # Proxies are iterable, but some methods expect tuples/lists
+                return tuple(result[i] for i in range(len(values)))
+
+            def frexp(self, value_proxy):
+                result = self._inner.frexp(value_proxy)
+                # Proxies are iterable, but some methods expect tuples/lists
+                return (result[0], result[1])
+
+            @staticmethod
+            def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
+                """
+                Flow data from tensors into indexing formulas.
+                Introduce a call_module to update the indexing.
+                """
+
+                var = self.body.add_indirect(size)
+                set_indirect = self.body.bind_set_indirect_shim(
+                    var, size, check, wrap_neg
+                )
+                tracer.create_proxy(
+                    "call_module",
+                    self.body.add_submodule(set_indirect, f"set_{var}"),
+                    (index_proxy,),
+                    {},
+                )
+                return var
+
+            @staticmethod
+            def output(result):
+                tracer.create_proxy("output", "output", (result,), {})
+
+        tracer = torch.fx.Tracer()
+        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
+        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
+
+        from .index_propagation import IndexPropagation
+        from .sizevars import SimplifyIndexing
+
+        handler: Any = SimplifyIndexing(
+            CaptureIndexing(proxy_ops), self.body.var_ranges
+        )
+        if config.constant_and_index_propagation:
+            handler = IndexPropagation(
+                handler, self.body.var_ranges, self.body.indirect_var_ranges
+            )
+
+        with V.set_ops_handler(handler):
+            # This indirection is just a cute way to get IndexPropagation to
+            # unwrap the return value.
+            ops.output(fn(*args))
+        self.graph = tracer.graph
+
+    def __call__(self):
+        graph = self.graph
+        submodules = self.body.submodules
+
+        return InterpreterShim(graph, submodules).run(V.get_ops_handler())
+
+    def debug_str(self, name="block"):
+        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
+        return re.sub(
+            # strip `; del var0` suffixes to make output prettier
+            r";[^\n]*",
+            "",
+            code.strip().replace("def forward(", f"def {name}("),
+        )
+
+    def clone(self, body: LoopBody):
+        """Shallow copy with a new parent LoopBody"""
+        copy = LoopBodyBlock.__new__(LoopBodyBlock)
+        copy.__dict__.update({**self.__dict__, "body": body})
+        return copy
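LoopBodyBlock never executes the captured function against real data at construction time; it hands the function a Proxy named "ops", so every handler call is recorded as a call_method node in the FX graph (which is why _get_atomic_add_buffers below can scan get_nodes() for node.op == "call_method"). The capture pattern in isolation, stripped of the inductor handler stack, as a sketch:

import torch.fx

# The same proxy-tracing pattern LoopBodyBlock.__init__ uses.
tracer = torch.fx.Tracer()
tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
ops = tracer.create_proxy("placeholder", "ops", (), {})

# Method calls on the placeholder proxy are recorded, not executed:
val = ops.load("buf0", "index0")   # -> call_method node with target "load"
ops.store("buf1", "index0", val)   # -> call_method node with target "store"
tracer.create_proxy("output", "output", (val,), {})

print(tracer.graph)  # placeholder, two call_method nodes, output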
@@ -6,7 +6,7 @@ import sympy
 import torch
 from torch.utils._sympy.value_ranges import ValueRanges
 
-from .ir import LoopBody
+from .loop_body import LoopBody
 from .utils import dominated_nodes
 
@@ -46,6 +46,7 @@ from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
 from .comm_analysis import estimate_nccl_collective_runtime
 from .dependencies import Dep, MemoryDep, StarDep, WeakDep
 from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout
+from .loop_body import LoopBody
 from .runtime.runtime_utils import green_text, red_text
 from .sizevars import SimplifyIndexing
 from .utils import (
@@ -922,7 +923,7 @@ class SchedulerNode(BaseSchedulerNode):
             buf_name = dep.name
             buf = V.graph.get_buffer(buf_name)
             lines.append(f"{buf_name}_layout = {pformat(buf.layout)}")
-        if isinstance(self._body, ir.LoopBody):
+        if isinstance(self._body, LoopBody):
             lines.append(f"class {name}_loop_body:")
             lines.append(textwrap.indent(self._body.debug_str(), " "))
 
@@ -1011,7 +1012,7 @@ class SchedulerNode(BaseSchedulerNode):
     @cache_on_self
     def _get_atomic_add_buffers(self) -> OrderedSet[str]:
         buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
-        if isinstance(self._body, ir.LoopBody):
+        if isinstance(self._body, LoopBody):
            for node in self._body.get_nodes():
                if (
                    node.op == "call_method"
@@ -76,7 +76,7 @@ if TYPE_CHECKING:
     from torch._inductor.codegen.cpp_utils import LocalBufferContext
     from torch._inductor.debug import DebugContext
     from torch._inductor.graph import GraphLowering
-    from torch._inductor.ir import InterpreterShim
+    from torch._inductor.loop_body import InterpreterShim
     from torch._subclasses import FakeTensorMode
 
 threadlocal = local()
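A closing note on the moved code: the subtlest bookkeeping in reorder_iter_loops is the inverse permutation, which maps indices arriving in the new loop order back to the order the traced body expects. The arithmetic in isolation, with illustrative values:

# new_order[i] = which old loop now sits at position i
new_order = [2, 0, 1]
inverse = {b: a for a, b in enumerate(new_order)}            # old pos -> new pos
inverse_order = [inverse[i] for i in range(len(new_order))]  # [1, 2, 0]

new_idx = ["i0", "i1", "i2"]                   # indices in the new loop order
old_idx = [new_idx[i] for i in inverse_order]  # what old_body(...) receives
assert old_idx == ["i1", "i2", "i0"]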