Files
pytorch/torch/_inductor/optimize_indexing.py

577 lines
19 KiB
Python

import dataclasses
import functools
import itertools
import logging
import math
import operator
from typing import Dict, Iterable, Union
import sympy
import torch
from .ir import FloorDiv, InterpreterShim, LoopBody, ModularIndexing
from .utils import sympy_subs
from .virtualized import V
# Logger named after this module.
log = logging.getLogger(name=__name__)
@dataclasses.dataclass(frozen=True)
class ValueRanges:
    """An immutable closed interval ``[lower, upper]``.

    This is the abstract value propagated by ``ValueRangeAnalysis``. Bounds may
    be Python scalars or sympy numbers, and ``math.inf`` marks an unbounded side.
    The classmethods below lift scalar functions to functions on intervals,
    each under a different assumption about how ``fn`` behaves.
    """

    lower: Union[sympy.Expr, sympy.Number, int, float, bool]
    upper: Union[sympy.Expr, sympy.Number, int, float, bool]

    def __contains__(self, x):
        """Return whether the concrete (non-sympy) value ``x`` lies in the interval."""
        # TODO This needs to be generalised if lower/upper are sympy.Expr
        assert not isinstance(x, sympy.Expr)
        return self.lower <= x <= self.upper

    @classmethod
    def wrap(cls, arg):
        """Coerce ``arg`` to a ``ValueRanges``; a plain scalar becomes ``[arg, arg]``."""
        if not isinstance(arg, ValueRanges):
            assert isinstance(arg, (int, float, bool))
            arg = ValueRanges(arg, arg)
        return arg

    @classmethod
    def increasing_map(cls, x, fn):
        """Image of ``x`` under a non-decreasing ``fn``: each bound maps to itself."""
        src = cls.wrap(x)
        new_lower = fn(src.lower)
        new_upper = fn(src.upper)
        return ValueRanges(new_lower, new_upper)

    @classmethod
    def decreasing_map(cls, x, fn):
        """Image of ``x`` under a non-increasing ``fn``: the bounds trade places."""
        src = cls.wrap(x)
        new_lower = fn(src.upper)
        new_upper = fn(src.lower)
        return ValueRanges(new_lower, new_upper)

    @classmethod
    def monotone_map(cls, x, fn):
        """Image of ``x`` under a monotone ``fn`` whose direction is unknown.

        ``fn`` is evaluated at both ends and the results are ordered.
        """
        src = cls.wrap(x)
        at_lower = fn(src.lower)
        at_upper = fn(src.upper)
        return ValueRanges(min(at_lower, at_upper), max(at_lower, at_upper))

    @classmethod
    def convex_min_zero_map(cls, x, fn):
        """Image of ``x`` under a convex ``fn`` assumed to reach its minimum 0 at 0.

        Used for functions like ``abs`` and squaring. The maximum always sits at
        one of the endpoints. When 0 is inside ``x`` the minimum is 0; otherwise
        ``fn`` is monotone over ``x``.
        """
        src = ValueRanges.wrap(x)
        if 0 not in src:
            return cls.monotone_map(src, fn)
        return ValueRanges(0, max(fn(src.lower), fn(src.upper)))

    @classmethod
    def coordinatewise_increasing_map(cls, x, y, fn):
        """Image under a binary ``fn`` that is non-decreasing in both arguments."""
        lhs, rhs = cls.wrap(x), cls.wrap(y)
        new_lower = fn(lhs.lower, rhs.lower)
        new_upper = fn(lhs.upper, rhs.upper)
        return ValueRanges(new_lower, new_upper)

    @classmethod
    def coordinatewise_monotone_map(cls, x, y, fn):
        """Image under a binary ``fn`` that is monotone in each argument separately.

        ``fn`` is evaluated at all four corners of the box ``x`` x ``y``, and
        the smallest and largest results are taken.
        """
        lhs, rhs = cls.wrap(x), cls.wrap(y)
        corners = [
            fn(left, right)
            for left in (lhs.lower, lhs.upper)
            for right in (rhs.lower, rhs.upper)
        ]
        return ValueRanges(min(corners), max(corners))
class ValueRangeAnalysis:
    """Ops handler that evaluates a LoopBody graph over ``ValueRanges``.

    Each op receives ranges (plain scalars are wrapped into point ranges by the
    ``ValueRanges`` helpers). It returns a range meant to contain every value
    the op can produce. An op that cannot be bounded returns
    ``[-inf, inf]``, which callers treat as "give up". Ops without a handler
    on this class fall back to ``default_handler`` through ``__getattr__``.
    """

    def __init__(self):
        self.name = "ValueRangeAnalysis"
        boolean_operators = (
            "eq",
            "ne",
            "lt",
            "gt",
            "le",
            "ge",
            "and_",
            "or_",
            "xor",
            "logical_and",
            "logical_or",
            "logical_not",
        )
        # Comparisons and logical ops all yield a boolean of unknown value.
        for op in boolean_operators:
            setattr(self, op, self.bool_handler)

    @staticmethod
    def bool_handler(*args, **kwargs):
        # just assuming bools can have both values
        return ValueRanges(sympy.false, sympy.true)

    @staticmethod
    def default_handler(*args, **kwargs):
        # many ops are unlikely to show up in optimizable indexing compute,
        # so we dont have full coverage
        return ValueRanges(-math.inf, math.inf)

    def load(self, name: str, index: sympy.Expr):
        # Tensor contents are unknown to this analysis.
        return ValueRanges(-math.inf, math.inf)

    def store(self, name, index, value, mode=None):
        return

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        return ValueRanges(-math.inf, math.inf)

    def index_expr(self, index, dtype):
        # The index must already be a range (e.g. what OptimizeIndexing.get_index returns).
        assert isinstance(index, ValueRanges)
        return index

    @staticmethod
    def to_dtype(x, dtype: torch.dtype):
        """Cast a range to ``dtype``. Only boolean ranges change: they become 0..1."""

        def is_bool(val):
            return isinstance(val, bool) or (
                hasattr(val, "is_Boolean") and val.is_Boolean
            )

        x = ValueRanges.wrap(x)
        low, up = x.lower, x.upper
        if is_bool(low):
            assert is_bool(up)
            if dtype.is_floating_point:
                return ValueRanges(sympy.Float(0.0), sympy.Float(1.0))
            else:
                return ValueRanges(sympy.Integer(0), sympy.Integer(1))
        return ValueRanges.wrap(x)

    @staticmethod
    def constant(value, dtype):
        # using nan makes subsequent computation throw, and for the purposes of optimization
        # returning -math.inf - math.inf is equivalent to giving up
        if math.isnan(value):
            return ValueRanges(-math.inf, math.inf)
        if isinstance(value, int):
            return ValueRanges(sympy.Integer(value), sympy.Integer(value))
        else:
            return ValueRanges(sympy.Float(value), sympy.Float(value))

    @staticmethod
    def reciprocal(x):
        x = ValueRanges.wrap(x)
        if 0 in x:
            return ValueRanges(-math.inf, math.inf)
        else:
            # 1/y is decreasing on any interval that does not contain 0.
            return ValueRanges.decreasing_map(x, lambda y: 1 / y)

    @staticmethod
    def square(x):
        return ValueRanges.convex_min_zero_map(x, lambda y: y * y)

    @staticmethod
    def abs(x):
        # `abs` here resolves to the builtin, not this staticmethod.
        return ValueRanges.convex_min_zero_map(x, abs)

    @staticmethod
    def neg(x):
        return ValueRanges.decreasing_map(x, operator.neg)

    @staticmethod
    def truediv(a, b):
        b = ValueRanges.wrap(b)
        if 0 in b:
            return ValueRanges(-math.inf, math.inf)
        else:
            return ValueRangeAnalysis.mul(a, ValueRanges(1 / b.upper, 1 / b.lower))

    @staticmethod
    def div(a, b):
        # We think of this as floor(a / b)
        out = ValueRangeAnalysis.truediv(a, b)
        return ValueRangeAnalysis.floor(out)

    @staticmethod
    def add(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, operator.add)

    @staticmethod
    def mul(a, b):
        return ValueRanges.coordinatewise_monotone_map(a, b, operator.mul)

    @staticmethod
    def sub(a, b):
        b = ValueRanges.wrap(b)
        return ValueRangeAnalysis.add(a, ValueRanges(-b.upper, -b.lower))

    @staticmethod
    def exp(x):
        return ValueRanges.increasing_map(x, sympy.functions.elementary.exponential.exp)

    @staticmethod
    def log(x):
        return ValueRanges.increasing_map(
            x, lambda y: -math.inf if y <= 0 else sympy.log(y)
        )

    @staticmethod
    def sqrt(x):
        x = ValueRanges.wrap(x)
        if x.lower < 0:
            # sqrt is not real-valued below 0. sympy would produce imaginary
            # bounds that later range comparisons cannot order, so give up.
            return ValueRanges(-math.inf, math.inf)
        return ValueRanges.increasing_map(x, sympy.sqrt)

    @staticmethod
    def pow(a, b):
        """Range of ``a ** b``.

        Returns ``[-inf, inf]`` (give up) when:

        * ``0 ** (non-positive)`` is possible;
        * the base can be negative and the exponent is anything other than a
          single integral value.
        """

        def is_integer(val):
            # Integral-*value* check. `float.is_integer` is a method, so it must
            # be called; as a bare attribute it would always be truthy.
            if isinstance(val, int):
                return True
            if isinstance(val, float):
                return val.is_integer()  # False for inf and nan
            return getattr(val, "is_integer", None) is True

        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        unbounded = ValueRanges(-math.inf, math.inf)
        if 0 in a and b.lower <= 0:
            # 0 ** (non-positive) is undefined or infinite
            return unbounded
        if a.lower < 0:
            # A negative base is only defined for integral exponents. x ** n then
            # alternates sign with n, so the exponent's endpoints do not bound the
            # result. Only a single integral exponent is handled.
            if b.lower != b.upper or not is_integer(b.lower):
                return unbounded
            n = b.lower
            # For fixed integral n, x ** n is monotone on each side of 0. The
            # extremes are therefore at the ends of `a`, plus 0 itself (0 ** n == 0
            # for n > 0) when `a` straddles 0.
            candidates = [a.lower**n, a.upper**n]
            if 0 in a:
                candidates.append(0)
            return ValueRanges(min(candidates), max(candidates))
        # Non-negative base: x ** y is monotone in each argument separately, so
        # the four corners bound the result.
        return ValueRanges.coordinatewise_monotone_map(a, b, operator.pow)

    @staticmethod
    def minimum(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, min)

    @staticmethod
    def maximum(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, max)

    @staticmethod
    def where(a, b, c):
        # The condition is not inspected; the result may come from either branch.
        b = ValueRanges.wrap(b)
        c = ValueRanges.wrap(c)
        return ValueRanges(min(b.lower, c.lower), max(b.upper, c.upper))

    @staticmethod
    def floor(x):
        return ValueRangeAnalysis.floor_ceil(
            x, sympy.functions.elementary.integers.floor
        )

    @staticmethod
    def ceil(x):
        return ValueRangeAnalysis.floor_ceil(
            x, sympy.functions.elementary.integers.ceiling
        )

    @staticmethod
    def floor_ceil(x, fn_int):
        """Apply the monotone rounding function ``fn_int`` to both bounds of ``x``.

        Integer bounds stay integers. Any other bounds (floats, rationals,
        infinities) are tagged as ``sympy.Float``, because the result of
        rounding a float-typed value is still float-typed.
        """

        def is_integer(val):
            # Python ints (incl. bool) or sympy objects asserting is_integer.
            # For a Python float, `is_integer` is a method object, which is not
            # `True`, so floats correctly take the Float branch.
            return isinstance(val, int) or getattr(val, "is_integer", None) is True

        # The check must look at the bounds. The ValueRanges object itself is
        # never an integer.
        x = ValueRanges.wrap(x)
        if is_integer(x.lower) and is_integer(x.upper):
            fn = fn_int
        else:

            def fn(v):
                return sympy.Float(fn_int(v))

        return ValueRanges.increasing_map(x, fn)

    def __getattr__(self, name):
        # Only reached for attributes not found normally, i.e. ops with no
        # handler above. Dunder lookups (copy/pickle/introspection protocols)
        # must not be mistaken for ops.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        # Logged at INFO: actionable for compiler developers, not end users.
        logging.getLogger(__name__).info("unhandled ValueRange op %s", name)
        return self.default_handler
def dominated_nodes(
    initial_queue: Union[torch.fx.Node, Iterable[torch.fx.Node]], skip_filter=None
):
    """Return the set of nodes whose values (transitively) depend on ``initial_queue``.

    The starting nodes themselves are included. The traversal follows each
    node's ``users``. A user for which ``skip_filter(user)`` is truthy is never
    added through traversal, and is not traversed past.

    Args:
        initial_queue: a single ``torch.fx.Node`` or an iterable of nodes. The
            iterable is copied, so a caller-owned list/set is left untouched
            and one-shot iterators (e.g. generators) are accepted.
        skip_filter: optional predicate ``node -> bool`` selecting users to
            stop at.

    Returns:
        A ``set`` of nodes.
    """
    if isinstance(initial_queue, torch.fx.Node):
        initial_queue = [initial_queue]
    # Private worklist: the loop below consumes it, so it must not alias the
    # caller's container.
    worklist = list(initial_queue)
    dominated_set = set(worklist)
    while worklist:
        node = worklist.pop()
        for user in node.users:
            if skip_filter and skip_filter(user):
                continue
            if user not in dominated_set:
                dominated_set.add(user)
                worklist.append(user)
    return dominated_set
def val_expressable_in_32_bits(val):
    """Return whether the scalar ``val`` is safely representable in 32 bits.

    Boolean values always qualify. A constant sympy expression is first
    converted to a Python ``int`` (integers/booleans) or ``float`` (everything
    else). A float must lie within +/- 2**24, the range float32 represents
    exactly. An int must lie within the int32 range.

    Raises:
        Exception: if ``val`` is neither a float nor an int after conversion.
    """
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_constant()
        val = int(val) if (val.is_Integer or val.is_Boolean) else float(val)

    # bound within mantissa
    if isinstance(val, float):
        return -(2**24) <= val <= 2**24
    if isinstance(val, int):
        int32_info = torch.iinfo(torch.int32)
        return int32_info.min <= val <= int32_info.max
    raise Exception(f"Unexpected value {val}")
def range_expressable_in_32_bits(range):
    """Return whether both endpoints of ``range`` fit in 32 bits.

    The lower bound is checked first; the upper bound is only checked if the
    lower one passes.
    """
    return all(
        val_expressable_in_32_bits(bound) for bound in (range.lower, range.upper)
    )
class OptimizeIndexing:
    """
    Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
    intermediaries from int64 to int32. This is an important optimization for indexing
    kernels such as Upsample and Interpolate.

    ``run`` interprets the graph with ``V.get_ops_handler()``, so it must run
    while ``ValueRangeAnalysis`` is the installed ops handler (as
    ``indexing_dtype_strength_reduction`` arranges).
    """

    def __init__(
        self,
        loop_body: LoopBody,
        indices_ranges: Dict[sympy.Symbol, int],
        indexing_exprs: Dict[str, sympy.Expr],
    ):
        self.loop_body = loop_body
        self.indices_range = indices_ranges
        self.indexing_exprs = indexing_exprs
        # Known ranges. Keys are loop/indirect variable symbols, plus the name
        # of every indexing expression evaluated so far (see get_index).
        self.replacement_vals = {}
        # fx interpreter environment, passed to both the root-graph and the
        # masked-subblock interpreters. NOTE(review): assumes InterpreterShim
        # stores results in the dict given as `initial_env`; the lookups in
        # try_to_reduce_precision rely on that.
        self.interp_env = {}
        self.submodules = self.swap_submodules(dict(loop_body.submodules))

        indirect_var_set = set(loop_body.indirect_vars)
        # indexing-expression name -> indirect variables that appear in it
        self.index_indirect_dependecies = {
            index: expr.free_symbols & indirect_var_set
            for index, expr in indexing_exprs.items()
        }
        self.all_graphs = [loop_body.root_block.graph] + [
            block.graph for block in loop_body.subblocks.values()
        ]

        for var, extent in indices_ranges.items():
            # Loop variables iterate over 0 .. extent - 1; the extent itself is
            # never reached. (Relies on var_ranges holding exclusive sizes.)
            self.replace_indirect(var, ValueRanges(0, extent - 1))

        # avoid computing these values, pessimistically assume that they are unbounded
        self.tensor_values_set = dominated_nodes(
            [
                node
                for node in self.all_nodes
                if node.target in ["load", "reduction"]
                or "masked_subblock" in node.target
            ]
        )

    def run(self):
        """Compute value ranges and try to narrow int64 ``to_dtype`` nodes to int32."""
        int64_dtype_nodes = [
            node
            for node in self.all_nodes
            if (
                node.target == "to_dtype"
                and node.args[2] == torch.int64
                and node not in self.tensor_values_set
            )
        ]
        if not int64_dtype_nodes:
            return

        for node in self.tensor_values_set:
            # we need to evaluate masked_subblock to recurse, and we need to set indirect values
            if (
                "masked_subblock" not in node.target
                and "set_indirect" not in node.target
            ):
                self.interp_env[node] = ValueRanges(-math.inf, math.inf)

        interpreter = InterpreterShim(self.loop_body.root_block.graph, self.submodules)
        interpreter.run(V.get_ops_handler(), initial_env=self.interp_env)

        # TODO - if dominated node of one to_dtype is not expressible in int32,
        # we should short circuit another to_dtype node if that node also dominates
        for node in int64_dtype_nodes:
            self.try_to_reduce_precision(node)

    def try_to_reduce_precision(self, node):
        """Rewrite ``node`` (an int64 ``to_dtype``) to produce int32 when provably safe.

        ``node`` is rewritten only if both conditions hold:

        * every node whose value depends on ``node`` has a range that fits in
          32 bits. The traversal stops at explicit casts whose result does not
          depend on the source integer precision;
        * every indexing expression that uses an indirect variable fed by
          ``node`` has a finite range that fits in 32 bits.

        Otherwise ``node`` is left unchanged.
        """

        # if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
        # then its precision is set for that chain of uses, and we don't need to consider those
        # dominated values (the cast's input is still checked, since it is itself dominated)
        def skip_filter(node):
            return node.target == "to_dtype" and node.args[2] in (
                torch.int32,
                torch.float16,
                torch.float32,
                torch.float64,
            )

        # TODO - there are dominated uses whose dtype does not depend on whether
        # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
        # int32 without changing the output precision of the node. this case hasn't shown up
        for dominated in dominated_nodes(node, skip_filter):
            if dominated.target in ["store", "output"]:
                continue

            if "set_indirect" in dominated.target:
                idx = int(dominated.target[len("set_indirect") :])
                indirect_var = self.loop_body.indirect_vars[idx]

                for index, indirect_vals in self.index_indirect_dependecies.items():
                    if indirect_var not in indirect_vals:
                        continue
                    index_val = self.replacement_vals.get(index)
                    if index_val is None:
                        # Never evaluated, so nothing is known about its range.
                        return
                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                        return
                    # all indices are integers, so make sure that we
                    # use the bounds of integers instead of floats.
                    # TODO - not sure if we should be doing int/float casts while tracing,
                    # might interfere with sympy.
                    index_val_int = ValueRanges(
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

            value_range = self.interp_env.get(dominated)
            if not isinstance(value_range, ValueRanges):
                # No range was computed for this node; stay conservative.
                return
            if not range_expressable_in_32_bits(value_range):
                return

        args = list(node.args)
        args[2] = torch.int32
        node.args = tuple(args)

    @property
    def all_nodes(self):
        """Yield every node of the root graph followed by every subblock graph."""
        for graph in self.all_graphs:
            for node in graph.nodes:
                yield node

    def swap_submodules(self, submodules):
        """Replace LoopBody submodules with range-analysis implementations."""
        keys = list(submodules.keys())
        for key in keys:
            if key == "get_index":
                submodules[key] = self.get_index
            elif "masked_subblock" in key:
                subblock = self.loop_body.subblocks[key]
                submodules[key] = functools.partial(
                    self.masked_subblock, subblock, self.interp_env
                )
            else:
                assert "set_indirect" in key
                idx = int(key[len("set_indirect") :])
                var = self.loop_body.indirect_vars[idx]
                indirect = functools.partial(self.set_indirect, var)
                submodules[key] = indirect

        return submodules

    def masked_subblock(self, subblock, env, mask, value):
        """Interpret a masked subblock's graph and return the range of its output."""
        interp = InterpreterShim(subblock.graph, self.submodules)
        interp.run(V.get_ops_handler(), initial_env=env)
        output = [node for node in subblock.graph.nodes if node.target == "output"]
        assert len(output) == 1
        # dont bother unioning with value since the load from buffer will be
        # pessimistically assumed to be inf anyway
        return interp.env[output[0]]

    def set_indirect(self, var, new_var):
        """Record the range of indirect variable ``var`` and pass it through."""
        self.replace_indirect(var, new_var)
        return new_var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        assert isinstance(new, ValueRanges)
        self.replacement_vals[old] = new

    def get_index(self, name):
        """Return (and cache under ``name``) the range of indexing expression ``name``."""
        if name in self.replacement_vals:
            return self.replacement_vals[name]
        out = self._get_index_impl(name)
        self.replacement_vals[name] = out
        return out

    def _get_index_impl(self, name):
        """Bound indexing expression ``name`` by substituting extreme symbol values.

        A derivative sign is found for each free symbol. If every symbol is
        provably monotone, each symbol is replaced by the endpoint of its range
        that minimizes (resp. maximizes) the expression. Otherwise the result is
        ``[-inf, inf]``.
        """
        expr = self.indexing_exprs[name]
        free_symbols = expr.free_symbols
        if not free_symbols:
            return ValueRanges(expr, expr)

        if expr in self.replacement_vals:
            return self.replacement_vals[expr]

        unbounded = ValueRanges(-math.inf, math.inf)
        # Substituting only some symbols would leave symbolic bounds, which the
        # 32-bit checks cannot evaluate; give up instead.
        if any(symbol not in self.replacement_vals for symbol in free_symbols):
            return unbounded

        def replace_symbols_for_deriv(expr):
            # for the purposes of finding local, minimum, maximum, assume smoothness
            # NOTE(review): ModularIndexing(x, y, z) is periodic in x, so treating
            # it like x / y is an approximation; the substituted endpoints are not
            # guaranteed to bound the true range of a ModularIndexing term.
            def mod_indexing_rep(x, y, z):
                if z.is_constant():
                    return x / y
                # never really happens, we'll bail on optimizing
                return (x / y) % z

            def indexing_div_rep(x, y):
                return x / y

            return expr.replace(ModularIndexing, mod_indexing_rep).replace(
                FloorDiv, indexing_div_rep
            )

        expr_for_deriv = replace_symbols_for_deriv(expr)
        monotonic_increasing = set()
        for symbol in free_symbols:
            diff = sympy.diff(expr_for_deriv, symbol)
            if diff.is_positive:
                monotonic_increasing.add(symbol)
            elif diff.is_positive is False:  # can return None
                # treated as non-increasing in this symbol
                continue
            else:
                # bail on optimizing, have not run into this yet
                return unbounded

        # Only this expression's own symbols matter for the substitution.
        min_subs = {}
        max_subs = {}
        for symbol in free_symbols:
            rng = self.replacement_vals[symbol]
            if symbol in monotonic_increasing:
                min_subs[symbol], max_subs[symbol] = rng.lower, rng.upper
            else:
                min_subs[symbol], max_subs[symbol] = rng.upper, rng.lower

        return ValueRanges(sympy_subs(expr, min_subs), sympy_subs(expr, max_subs))
def indexing_dtype_strength_reduction(loop_body: LoopBody):
    """Narrow int64 indexing casts in ``loop_body`` to int32 when value ranges allow.

    Installs ``ValueRangeAnalysis`` as the ops handler and runs
    ``OptimizeIndexing`` over the loop body. Qualifying ``to_dtype`` nodes in
    ``loop_body``'s graphs are rewritten in place.
    """
    var_ranges = dict(loop_body.var_ranges)
    index_exprs = dict(loop_body.indexing_exprs)
    range_handler = ValueRangeAnalysis()
    with V.set_ops_handler(range_handler):
        optimizer = OptimizeIndexing(loop_body, var_ranges, index_exprs)
        optimizer.run()