Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-06 17:24:59 +08:00)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94413 Approved by: https://github.com/ezyang
577 lines · 19 KiB · Python
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
import math
|
|
import operator
|
|
from typing import Dict, Iterable, Union
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
from .ir import FloorDiv, InterpreterShim, LoopBody, ModularIndexing
|
|
from .utils import sympy_subs
|
|
from .virtualized import V
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ValueRanges:
    """Immutable closed interval ``[lower, upper]`` of values an expression may take.

    Bounds may be Python numbers, booleans, or sympy numbers/expressions. The
    class methods lift a scalar function ``fn`` to intervals; each one assumes
    a different monotonicity property of ``fn``.
    """

    lower: Union[sympy.Expr, sympy.Number, int, float, bool]
    upper: Union[sympy.Expr, sympy.Number, int, float, bool]

    def __contains__(self, x):
        # TODO This needs to be generalised if lower/upper are sympy.Expr
        assert not isinstance(x, sympy.Expr)
        return self.lower <= x and x <= self.upper

    @classmethod
    def wrap(cls, arg):
        """Return ``arg`` unchanged if it is a range, else the point range ``[arg, arg]``.

        Non-range arguments must be plain ``int``/``float``/``bool`` (asserted).
        """
        if not isinstance(arg, ValueRanges):
            assert isinstance(arg, (int, float, bool))
            arg = ValueRanges(arg, arg)
        return arg

    @classmethod
    def increasing_map(cls, x, fn):
        """Apply a non-decreasing ``fn``: bounds map straight across."""
        rng = cls.wrap(x)
        return ValueRanges(fn(rng.lower), fn(rng.upper))

    @classmethod
    def decreasing_map(cls, x, fn):
        """Apply a non-increasing ``fn``: the image of the upper bound becomes the new lower bound."""
        rng = cls.wrap(x)
        new_lower = fn(rng.upper)
        new_upper = fn(rng.lower)
        return ValueRanges(new_lower, new_upper)

    @classmethod
    def monotone_map(cls, x, fn):
        """Apply an ``fn`` that is monotone in an unknown direction: order the two images."""
        rng = cls.wrap(x)
        at_lower, at_upper = fn(rng.lower), fn(rng.upper)
        return ValueRanges(min(at_lower, at_upper), max(at_lower, at_upper))

    @classmethod
    def convex_min_zero_map(cls, x, fn):
        """Apply a convex ``fn`` whose minimum is ``0`` at ``0`` (used for square and abs).

        If the interval contains 0 the result is ``[0, max(fn(ends))]``;
        otherwise ``fn`` is monotone on the interval.
        """
        rng = ValueRanges.wrap(x)
        if 0 not in rng:
            return cls.monotone_map(rng, fn)
        return ValueRanges(0, max(fn(rng.lower), fn(rng.upper)))

    @classmethod
    def coordinatewise_increasing_map(cls, x, y, fn):
        """Apply a binary ``fn`` that is non-decreasing in both arguments."""
        lhs, rhs = cls.wrap(x), cls.wrap(y)
        lo = fn(lhs.lower, rhs.lower)
        hi = fn(lhs.upper, rhs.upper)
        return ValueRanges(lo, hi)

    @classmethod
    def coordinatewise_monotone_map(cls, x, y, fn):
        """Apply a binary ``fn`` monotone in each argument: evaluate all four corners."""
        lhs, rhs = cls.wrap(x), cls.wrap(y)
        corners = [
            fn(a, b)
            for a in (lhs.lower, lhs.upper)
            for b in (rhs.lower, rhs.upper)
        ]
        return ValueRanges(min(corners), max(corners))
|
|
|
|
|
|
class ValueRangeAnalysis:
    """Ops handler that evaluates each op on ``ValueRanges`` instead of values.

    Every op returns a ``ValueRanges`` intended to over-approximate the op's
    possible results. Comparison/logical ops return the full boolean range.
    Ops without a dedicated method fall back, via ``__getattr__``, to
    ``default_handler`` (the unbounded range, i.e. "give up").
    """

    def __init__(self):
        self.name = "ValueRangeAnalysis"
        boolean_operators = (
            "eq",
            "ne",
            "lt",
            "gt",
            "le",
            "ge",
            "and_",
            "or_",
            "xor",
            "logical_and",
            "logical_or",
            "logical_not",
        )
        for op in boolean_operators:
            setattr(self, op, self.bool_handler)

    @staticmethod
    def bool_handler(*args, **kwargs):
        # just assuming bools can have both values
        return ValueRanges(sympy.false, sympy.true)

    @staticmethod
    def default_handler(*args, **kwargs):
        # many ops are unlikely to show up in optimizable indexing compute,
        # so we dont have full coverage
        return ValueRanges(-math.inf, math.inf)

    def load(self, name: str, index: sympy.Expr):
        # tensor contents are unknown
        return ValueRanges(-math.inf, math.inf)

    def store(self, name, index, value, mode=None):
        return

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        return ValueRanges(-math.inf, math.inf)

    def index_expr(self, index, dtype):
        # ``index`` is already the range produced by the get_index submodule
        assert isinstance(index, ValueRanges)
        return index

    @staticmethod
    def to_dtype(x, dtype: torch.dtype):
        """Boolean ranges become ``[0, 1]`` (Float or Integer by ``dtype``); others pass through."""

        def is_bool(val):
            return isinstance(val, bool) or (
                hasattr(val, "is_Boolean") and val.is_Boolean
            )

        x = ValueRanges.wrap(x)
        low, up = x.lower, x.upper
        if is_bool(low):
            assert is_bool(up)
            if dtype.is_floating_point:
                return ValueRanges(sympy.Float(0.0), sympy.Float(1.0))
            else:
                return ValueRanges(sympy.Integer(0), sympy.Integer(1))
        return x

    @staticmethod
    def constant(value, dtype):
        # using nan makes subsequent computation throw, and for the purposes of optimization
        # returning -math.inf - math.inf is equivalent to giving up
        if math.isnan(value):
            return ValueRanges(-math.inf, math.inf)
        if isinstance(value, int):
            return ValueRanges(sympy.Integer(value), sympy.Integer(value))
        else:
            return ValueRanges(sympy.Float(value), sympy.Float(value))

    @staticmethod
    def reciprocal(x):
        x = ValueRanges.wrap(x)
        if 0 in x:
            # 1/x is unbounded around 0
            return ValueRanges(-math.inf, math.inf)
        else:
            return ValueRanges.decreasing_map(x, lambda y: 1 / y)

    @staticmethod
    def square(x):
        return ValueRanges.convex_min_zero_map(x, lambda y: y * y)

    @staticmethod
    def abs(x):
        # ``abs`` here is the builtin: class-scope names (such as this method)
        # are not visible inside method bodies.
        return ValueRanges.convex_min_zero_map(x, abs)

    @staticmethod
    def neg(x):
        return ValueRanges.decreasing_map(x, operator.neg)

    @staticmethod
    def truediv(a, b):
        b = ValueRanges.wrap(b)
        if 0 in b:
            return ValueRanges(-math.inf, math.inf)
        else:
            # a / b == a * (1 / b); 1/b is decreasing on an interval without 0
            return ValueRangeAnalysis.mul(a, ValueRanges(1 / b.upper, 1 / b.lower))

    @staticmethod
    def div(a, b):
        # We think of this as floor(a / b)
        out = ValueRangeAnalysis.truediv(a, b)
        return ValueRangeAnalysis.floor(out)

    @staticmethod
    def add(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, operator.add)

    @staticmethod
    def mul(a, b):
        return ValueRanges.coordinatewise_monotone_map(a, b, operator.mul)

    @staticmethod
    def sub(a, b):
        b = ValueRanges.wrap(b)
        return ValueRangeAnalysis.add(a, ValueRanges(-b.upper, -b.lower))

    @staticmethod
    def exp(x):
        return ValueRanges.increasing_map(x, sympy.functions.elementary.exponential.exp)

    @staticmethod
    def log(x):
        return ValueRanges.increasing_map(
            x, lambda y: -math.inf if y <= 0 else sympy.log(y)
        )

    @staticmethod
    def sqrt(x):
        x = ValueRanges.wrap(x)
        # sympy.sqrt of a negative bound is imaginary (e.g. sqrt(-4) == 2*I),
        # which later breaks ordering comparisons and float() conversion.
        # Give up instead, as pow does for undefined cases.
        if x.lower < 0:
            return ValueRanges(-math.inf, math.inf)
        return ValueRanges.increasing_map(x, sympy.sqrt)

    @staticmethod
    def pow(a, b):
        def is_integer(val):
            if isinstance(val, int):
                return True
            if isinstance(val, float):
                # float.is_integer is a *method* (so a bare attribute check is
                # always truthy); calling it returns False for inf/nan instead
                # of raising like int(inf) / int(nan) would.
                return val.is_integer()
            # sympy numbers expose is_integer as a (possibly None) property
            return bool(getattr(val, "is_integer", False))

        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)

        if 0 in a and b.lower <= 0:
            # 0 ** (non-positive exponent) is undefined or infinite
            return ValueRanges(-math.inf, math.inf)

        if a.lower >= 0:
            # For a non-negative base, x ** y is monotone in each argument
            # separately, so the extremes lie on the corners of the box.
            return ValueRanges.coordinatewise_monotone_map(a, b, operator.pow)

        # A negative base is only real-valued for integer exponents, and the
        # sign then flips with the exponent's parity, so corner evaluation is
        # not sound for an exponent range. Handle a single integer exponent.
        if b.lower != b.upper or not is_integer(b.lower):
            return ValueRanges(-math.inf, math.inf)

        exponent = b.lower
        if exponent % 2 == 0:
            # Even power: convex with minimum 0 at 0. If 0 is in ``a``, the
            # exponent is positive here because the first check excluded
            # non-positive exponents.
            return ValueRanges.convex_min_zero_map(a, lambda x: x**exponent)
        # Odd power: monotone over the reals for a positive exponent, and on
        # any interval without 0 for a negative one (0 in ``a`` excluded above).
        return ValueRanges.monotone_map(a, lambda x: x**exponent)

    @staticmethod
    def minimum(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, min)

    @staticmethod
    def maximum(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, max)

    @staticmethod
    def where(a, b, c):
        # the condition ``a`` is ignored: the result is the hull of both branches
        b = ValueRanges.wrap(b)
        c = ValueRanges.wrap(c)
        return ValueRanges(min(b.lower, c.lower), max(b.upper, c.upper))

    @staticmethod
    def floor(x):
        return ValueRangeAnalysis.floor_ceil(
            x, sympy.functions.elementary.integers.floor
        )

    @staticmethod
    def ceil(x):
        return ValueRangeAnalysis.floor_ceil(
            x, sympy.functions.elementary.integers.ceiling
        )

    @staticmethod
    def floor_ceil(x, fn_int):
        def is_integer(val):
            return isinstance(val, int) or (
                hasattr(val, "is_integer") and val.is_integer
            )

        # NOTE(review): when ``x`` is a ValueRanges (always the case when called
        # from ``div``), ``is_integer(x)`` is False because ValueRanges has no
        # ``is_integer`` attribute, so the sympy.Float branch is taken. Testing
        # ``x.lower``/``x.upper`` looks like the intent. That change would loosen
        # the later 32-bit checks (ints are checked against the int32 range,
        # floats against +-2**24), so it is left as-is pending confirmation.
        if is_integer(x):
            fn = fn_int
        else:

            def fn(x):
                return sympy.Float(fn_int(x))

        return ValueRanges.increasing_map(x, fn)

    def __getattr__(self, name):
        # Only reached for ops without a dedicated method. The original called an
        # unimported ``developer_warning`` helper, which turned every such lookup
        # into a NameError. ``log`` here resolves to the module logger, not the
        # ``log`` staticmethod above, because class scope is not visible in
        # method bodies. INFO keeps the message out of end users' warnings.
        log.info("unhandled ValueRange op %s", name)
        return self.default_handler
|
|
|
|
|
|
def dominated_nodes(
    initial_queue: Union[torch.fx.Node, Iterable[torch.fx.Node]], skip_filter=None
):
    """Returns the set of nodes whose values depend on those within initial_queue.

    The result contains the starting nodes plus every node transitively
    reachable through ``node.users``. A user for which ``skip_filter(user)``
    is truthy is neither added nor traversed further.

    Args:
        initial_queue: a single fx node, or any iterable of fx nodes (list,
            tuple, set, generator, ...). It is copied, so a caller's list is
            left untouched.
        skip_filter: optional predicate used to prune users.

    Returns:
        A ``set`` of ``torch.fx.Node``.
    """
    if isinstance(initial_queue, torch.fx.Node):
        initial_queue = [initial_queue]

    # Work on a private copy. Popping from the argument itself emptied the
    # caller's list and failed on tuples and one-shot iterables.
    queue = list(initial_queue)
    dominated_set = set(queue)

    while queue:
        node = queue.pop()
        for user in node.users:
            if skip_filter and skip_filter(user):
                continue
            if user not in dominated_set:
                dominated_set.add(user)
                queue.append(user)

    return dominated_set
|
|
|
|
|
|
def val_expressable_in_32_bits(val):
    """Report whether a single bound fits a 32-bit representation.

    Booleans (anything with a truthy ``is_Boolean``) always fit. A sympy
    expression must be constant; it is converted to ``int`` when it is an
    Integer/Boolean and to ``float`` otherwise. Floats must lie within
    ``[-2**24, 2**24]`` (exactly representable in a float32 mantissa); ints
    must lie within the int32 range.

    Raises:
        Exception: for values that end up neither ``int`` nor ``float``.
    """
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_constant()
        val = int(val) if (val.is_Integer or val.is_Boolean) else float(val)

    # bound within mantissa
    if isinstance(val, float):
        return -(2**24) <= val <= 2**24

    if isinstance(val, int):
        int32_info = torch.iinfo(torch.int32)
        return int32_info.min <= val <= int32_info.max

    raise Exception(f"Unexpected value {val}")
|
|
|
|
|
|
def range_expressable_in_32_bits(range):
    """Return True when both bounds of ``range`` pass ``val_expressable_in_32_bits``.

    The lower bound is checked first; the upper bound is only checked if the
    lower one fits.
    """
    bounds = (range.lower, range.upper)
    return all(val_expressable_in_32_bits(bound) for bound in bounds)
|
|
|
|
|
|
class OptimizeIndexing:
    """
    Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
    intermediaries from int64 to int32. This is an important optimization for indexing
    kernels such as Upsample and Interpolate.
    """

    def __init__(
        self,
        loop_body: LoopBody,
        indices_ranges: Dict[sympy.Symbol, int],
        indexing_exprs: Dict[str, sympy.Expr],
    ):
        """Prepare the analysis state for ``loop_body``.

        Args:
            loop_body: loop body whose fx graphs are analysed; ``run`` may
                rewrite its int64 ``to_dtype`` nodes to int32 in place.
            indices_ranges: loop variable -> bound; each variable is given the
                range ``[0, bound]``.
            indexing_exprs: index name -> sympy indexing expression.
        """
        self.loop_body = loop_body
        self.indices_range = indices_ranges
        self.indexing_exprs = indexing_exprs
        # Known ranges keyed by loop/indirect variable symbols, plus memoized
        # get_index results keyed by index name.
        self.replacement_vals = {}
        # fx node -> ValueRanges; passed as ``initial_env`` to the root
        # interpreter and to every masked-subblock interpreter.
        self.interp_env = {}
        # Must come after interp_env exists: masked_subblock partials capture it.
        self.submodules = self.swap_submodules(dict(loop_body.submodules))

        indirect_var_set = set(loop_body.indirect_vars)
        # index name -> indirect variables appearing in its expression
        self.index_indirect_dependecies = {
            index: expr.free_symbols & indirect_var_set
            for index, expr in indexing_exprs.items()
        }
        self.all_graphs = [loop_body.root_block.graph] + [
            block.graph for block in loop_body.subblocks.values()
        ]

        # NOTE(review): if ``v`` is the loop extent, the inclusive upper bound
        # over-approximates by one, which is conservative.
        for k, v in indices_ranges.items():
            self.replace_indirect(k, ValueRanges(0, v))

        # avoid computing these values, pessimistically assume that they are unbounded
        self.tensor_values_set = dominated_nodes(
            [
                node
                for node in self.all_nodes
                if node.target in ["load", "reduction"]
                or "masked_subblock" in node.target
            ]
        )

    def run(self):
        """Compute Value Ranges and try reduce precision of 'to_dtype' nodes to int32 where possible"""
        int64_dtype_nodes = [
            node
            for node in self.all_nodes
            if (
                node.target == "to_dtype"
                and node.args[2] == torch.int64
                and node not in self.tensor_values_set
            )
        ]
        if not int64_dtype_nodes:
            return

        for node in self.tensor_values_set:
            # we need to evaluate masked_subblock to recurse, and we need to set indirect values
            if (
                "masked_subblock" not in node.target
                and "set_indirect" not in node.target
            ):
                self.interp_env[node] = ValueRanges(-math.inf, math.inf)

        interpreter = InterpreterShim(self.loop_body.root_block.graph, self.submodules)
        interpreter.run(V.get_ops_handler(), initial_env=self.interp_env)

        # TODO - if dominated node of one to_dtype is not expressible in int32,
        # we should short circuit another to_dtype node if that node also dominates
        for node in int64_dtype_nodes:
            self.try_to_reduce_precision(node)

    def try_to_reduce_precision(self, node):
        """Rewrite int64 ``to_dtype`` ``node`` to int32 if all values it dominates fit.

        Returns without touching ``node`` as soon as a dominated value, or an
        index expression fed by a dominated ``set_indirect``, has a missing,
        non-finite, or too-wide range.
        """

        # If a downstream use of a node explicitly converts to int32, float32
        # or float64, then its precision is set for that chain of uses, and we
        # don't need to consider those dominated values.
        def skip_filter(node):
            return node.target == "to_dtype" and node.args[2] in (
                torch.int32,
                torch.float32,
                torch.float64,
            )

        # TODO - there are dominated uses whose dtype does not depend on whether
        # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
        # int32 without changing the output precision of the node. this case hasn't shown up
        for dominated in dominated_nodes(node, skip_filter):
            if dominated.target in ["store", "output"]:
                continue

            if "set_indirect" in dominated.target:
                idx = int(dominated.target[len("set_indirect") :])
                indirect_var = self.loop_body.indirect_vars[idx]

                for index, indirect_vals in self.index_indirect_dependecies.items():
                    if indirect_var not in indirect_vals:
                        continue

                    # An index that was never evaluated has no known range.
                    index_val = self.replacement_vals.get(index)
                    if index_val is None:
                        return
                    # isfinite also rejects nan, on which int() below would raise.
                    if not (
                        math.isfinite(index_val.lower)
                        and math.isfinite(index_val.upper)
                    ):
                        return

                    # all indices are integers, so make sure that we
                    # use the bounds of integers instead of floats.
                    # TODO - not sure if we should be doing int/float casts while tracing,
                    # might interfere with sympy.
                    index_val_int = ValueRanges(
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

            # A node the interpreters never evaluated has no known range.
            dominated_val = self.interp_env.get(dominated)
            if dominated_val is None or not range_expressable_in_32_bits(
                dominated_val
            ):
                return

        args = list(node.args)
        args[2] = torch.int32
        node.args = tuple(args)

    @property
    def all_nodes(self):
        """Yield every node of the root graph, then of each subblock graph."""
        for graph in self.all_graphs:
            for node in graph.nodes:
                yield node

    def swap_submodules(self, submodules):
        """Replace LoopBody submodules in ``submodules`` (mutated and returned) with range-analysis versions.

        ``get_index`` -> ``self.get_index``; ``masked_subblock*`` ->
        ``self.masked_subblock`` bound to that subblock and ``self.interp_env``;
        ``set_indirect<N>`` -> ``self.set_indirect`` bound to indirect var N.
        """
        keys = list(submodules.keys())
        for key in keys:
            if key == "get_index":
                submodules[key] = self.get_index
            elif "masked_subblock" in key:
                subblock = self.loop_body.subblocks[key]
                submodules[key] = functools.partial(
                    self.masked_subblock, subblock, self.interp_env
                )
            else:
                assert "set_indirect" in key
                idx = int(key[len("set_indirect") :])
                var = self.loop_body.indirect_vars[idx]
                indirect = functools.partial(self.set_indirect, var)
                submodules[key] = indirect

        return submodules

    def masked_subblock(self, subblock, env, mask, value):
        """Interpret ``subblock``'s graph with ``env`` and return its output node's range."""
        interp = InterpreterShim(subblock.graph, self.submodules)
        interp.run(V.get_ops_handler(), initial_env=env)
        output = [node for node in subblock.graph.nodes if node.target == "output"]
        assert len(output) == 1
        # dont bother unioning with value since the load from buffer will be
        # pessimistically assumed to be inf anyway
        return interp.env[output[0]]

    def set_indirect(self, var, new_var):
        """Record ``new_var`` as the range of indirect variable ``var`` and pass it through."""
        self.replace_indirect(var, new_var)
        return new_var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        assert isinstance(new, ValueRanges)
        self.replacement_vals[old] = new

    def get_index(self, name):
        """Return the (memoized) range of indexing expression ``name``."""
        if name in self.replacement_vals:
            return self.replacement_vals[name]

        out = self._get_index_impl(name)
        self.replacement_vals[name] = out
        return out

    def _get_index_impl(self, name):
        """Bound ``self.indexing_exprs[name]`` by evaluating it at per-symbol extremes.

        Each free symbol is classified as increasing/decreasing from the sign
        of a smoothed derivative. Any symbol without a known range or a
        determinable direction makes the result unbounded.
        """
        expr = self.indexing_exprs[name]
        symbols = expr.free_symbols

        if not symbols:
            return ValueRanges(expr, expr)

        if expr in self.replacement_vals:
            return self.replacement_vals[expr]

        # Without a range for every symbol, substitution would leave symbolic
        # bounds that later comparisons and float() conversions cannot handle.
        if any(symbol not in self.replacement_vals for symbol in symbols):
            return ValueRanges(-math.inf, math.inf)

        def replace_symbols_for_deriv(expr, ignore_mod=False):
            # for the purposes of finding local, minimum, maximum, assume smoothness
            def mod_indexing_rep(x, y, z):
                if z.is_constant():
                    return x / y

                # never really happens, we'll bail on optimizing
                return (x / y) % z

            def indexing_div_rep(x, y):
                return x / y

            return expr.replace(ModularIndexing, mod_indexing_rep).replace(
                FloorDiv, indexing_div_rep
            )

        # NOTE(review): directions come from the smoothed expression, which
        # treats ModularIndexing(x, y, z) like x / y, but the extremes are then
        # evaluated on the original ``expr``. Assuming ModularIndexing means
        # (x // y) % z, it wraps around, so corner values can underestimate its
        # true range (x in [0, 5]: corners give 0 and 1, yet 3 is reachable).
        # Confirm this cannot affect the int32 decision. ``ignore_mod`` is unused.
        expr_for_deriv = replace_symbols_for_deriv(expr, True)

        monotonic_increasing = []
        monotonic_decreasing = []
        other_symbols = []
        for symbol in symbols:
            diff = sympy.diff(expr_for_deriv, symbol)
            if diff.is_positive:
                monotonic_increasing.append(symbol)
            elif diff.is_positive is False:  # can return None
                monotonic_decreasing.append(symbol)
            else:
                other_symbols.append(symbol)

        if other_symbols:
            # bail on optimizing, have not run into this yet
            return ValueRanges(-math.inf, math.inf)

        # Substitute only symbols occurring in ``expr``; replacement_vals also
        # holds index-name (str) keys, which can never be free symbols of expr.
        known = {k: v for k, v in self.replacement_vals.items() if k in symbols}
        max_val = sympy_subs(
            expr,
            {
                k: (v.upper if k in monotonic_increasing else v.lower)
                for k, v in known.items()
            },
        )
        min_val = sympy_subs(
            expr,
            {
                k: (v.lower if k in monotonic_increasing else v.upper)
                for k, v in known.items()
            },
        )
        return ValueRanges(min_val, max_val)
|
|
|
|
|
|
def indexing_dtype_strength_reduction(loop_body: LoopBody):
    """
    Narrow int64 ``to_dtype`` conversions in ``loop_body`` to int32 when value
    range analysis shows every dependent intermediate fits in 32 bits.

    The analysis runs with ``ValueRangeAnalysis`` installed as the ops handler.
    ``loop_body``'s graph nodes are rewritten in place; nothing is returned.
    """
    var_ranges = dict(loop_body.var_ranges)
    index_exprs = dict(loop_body.indexing_exprs)
    range_handler = ValueRangeAnalysis()
    with V.set_ops_handler(range_handler):
        optimizer = OptimizeIndexing(loop_body, var_ranges, index_exprs)
        optimizer.run()
|