mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
Move ValueRanges into its own module (#94528)
I am going to use it in ShapeEnv shortly. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/94528 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
bae397ec63
commit
50bc25baa0
@ -1,14 +1,12 @@
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
from typing import Dict, Iterable, Union
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from .ir import FloorDiv, InterpreterShim, LoopBody, ModularIndexing
|
||||
from .utils import sympy_subs
|
||||
from .virtualized import V
|
||||
@ -16,270 +14,6 @@ from .virtualized import V
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ValueRanges:
|
||||
lower: Union[sympy.Expr, sympy.Number, int, float, bool]
|
||||
upper: Union[sympy.Expr, sympy.Number, int, float, bool]
|
||||
|
||||
def __contains__(self, x):
|
||||
# TODO This needs to be generalised if lower/upper are sympy.Expr
|
||||
assert not isinstance(x, sympy.Expr)
|
||||
return self.lower <= x <= self.upper
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, arg):
|
||||
if isinstance(arg, ValueRanges):
|
||||
return arg
|
||||
assert isinstance(arg, (int, float, bool))
|
||||
return ValueRanges(arg, arg)
|
||||
|
||||
@classmethod
|
||||
def increasing_map(cls, x, fn):
|
||||
"""map lower and upper bound with fn"""
|
||||
x = cls.wrap(x)
|
||||
return ValueRanges(fn(x.lower), fn(x.upper))
|
||||
|
||||
@classmethod
|
||||
def decreasing_map(cls, x, fn):
|
||||
"""map lower bound to upper bound and upper bound to lower bound"""
|
||||
x = cls.wrap(x)
|
||||
return ValueRanges(fn(x.upper), fn(x.lower))
|
||||
|
||||
@classmethod
|
||||
def monotone_map(cls, x, fn):
|
||||
"""check the max and min of computed upper and lower bound for the output"""
|
||||
x = cls.wrap(x)
|
||||
l = fn(x.lower)
|
||||
u = fn(x.upper)
|
||||
return ValueRanges(min(l, u), max(l, u))
|
||||
|
||||
@classmethod
|
||||
def convex_min_zero_map(cls, x, fn):
|
||||
"""the max is at one of the ends"""
|
||||
x = ValueRanges.wrap(x)
|
||||
if 0 in x:
|
||||
return ValueRanges(0, max(fn(x.lower), fn(x.upper)))
|
||||
else:
|
||||
return cls.monotone_map(x, fn)
|
||||
|
||||
@classmethod
|
||||
def coordinatewise_increasing_map(cls, x, y, fn):
|
||||
"""map upper and lower bounds accessing corresponding values of inputs"""
|
||||
x, y = cls.wrap(x), cls.wrap(y)
|
||||
return ValueRanges(
|
||||
fn(x.lower, y.lower),
|
||||
fn(x.upper, y.upper),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def coordinatewise_monotone_map(cls, x, y, fn):
|
||||
"""compute the product of all lower and upper bounds and take min and max"""
|
||||
x, y = cls.wrap(x), cls.wrap(y)
|
||||
products = [
|
||||
fn(a, b)
|
||||
for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper])
|
||||
]
|
||||
return ValueRanges(min(products), max(products))
|
||||
|
||||
|
||||
class ValueRangeAnalysis:
|
||||
def __init__(self):
|
||||
self.name = "ValueRangeAnalysis"
|
||||
boolean_operators = (
|
||||
"eq",
|
||||
"ne",
|
||||
"lt",
|
||||
"gt",
|
||||
"le",
|
||||
"ge",
|
||||
"and_",
|
||||
"or_",
|
||||
"xor",
|
||||
"logical_and",
|
||||
"logical_or",
|
||||
"logical_not",
|
||||
)
|
||||
for op in boolean_operators:
|
||||
setattr(self, op, self.bool_handler)
|
||||
|
||||
@staticmethod
|
||||
def bool_handler(*args, **kwargs):
|
||||
# just assuming bools can have both values
|
||||
return ValueRanges(sympy.false, sympy.true)
|
||||
|
||||
@staticmethod
|
||||
def default_handler(*args, **kwargs):
|
||||
# many ops are unlikely to show up in optimizable indexing compute,
|
||||
# so we dont have full coverage
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
return
|
||||
|
||||
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
|
||||
def index_expr(self, index, dtype):
|
||||
assert isinstance(index, ValueRanges)
|
||||
return index
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(x, dtype: torch.dtype):
|
||||
def is_bool(val):
|
||||
return isinstance(val, bool) or (
|
||||
hasattr(val, "is_Boolean") and val.is_Boolean
|
||||
)
|
||||
|
||||
x = ValueRanges.wrap(x)
|
||||
low, up = x.lower, x.upper
|
||||
if is_bool(low):
|
||||
assert is_bool(up)
|
||||
if dtype.is_floating_point:
|
||||
return ValueRanges(sympy.Float(0.0), sympy.Float(1.0))
|
||||
else:
|
||||
return ValueRanges(sympy.Integer(0), sympy.Integer(1))
|
||||
return ValueRanges.wrap(x)
|
||||
|
||||
@staticmethod
|
||||
def constant(value, dtype):
|
||||
# using nan makes subsequent computation throw, and for the purposes of optimization
|
||||
# returning -math.inf - math.inf is equivalent to giving up
|
||||
if math.isnan(value):
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
if isinstance(value, int):
|
||||
return ValueRanges(sympy.Integer(value), sympy.Integer(value))
|
||||
else:
|
||||
return ValueRanges(sympy.Float(value), sympy.Float(value))
|
||||
|
||||
@staticmethod
|
||||
def reciprocal(x):
|
||||
x = ValueRanges.wrap(x)
|
||||
if 0 in x:
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
else:
|
||||
return ValueRanges.decreasing_map(x, lambda y: 1 / y)
|
||||
|
||||
@staticmethod
|
||||
def square(x):
|
||||
return ValueRanges.convex_min_zero_map(x, lambda y: y * y)
|
||||
|
||||
@staticmethod
|
||||
def abs(x):
|
||||
return ValueRanges.convex_min_zero_map(x, abs)
|
||||
|
||||
@staticmethod
|
||||
def neg(x):
|
||||
return ValueRanges.decreasing_map(x, operator.neg)
|
||||
|
||||
@staticmethod
|
||||
def truediv(a, b):
|
||||
b = ValueRanges.wrap(b)
|
||||
if 0 in b:
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
else:
|
||||
return ValueRangeAnalysis.mul(a, ValueRanges(1 / b.upper, 1 / b.lower))
|
||||
|
||||
@staticmethod
|
||||
def div(a, b):
|
||||
# We think of this as floor(a / b)
|
||||
out = ValueRangeAnalysis.truediv(a, b)
|
||||
return ValueRangeAnalysis.floor(out)
|
||||
|
||||
@staticmethod
|
||||
def add(a, b):
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, operator.add)
|
||||
|
||||
@staticmethod
|
||||
def mul(a, b):
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, operator.mul)
|
||||
|
||||
@staticmethod
|
||||
def sub(a, b):
|
||||
b = ValueRanges.wrap(b)
|
||||
return ValueRangeAnalysis.add(a, ValueRanges(-b.upper, -b.lower))
|
||||
|
||||
@staticmethod
|
||||
def exp(x):
|
||||
return ValueRanges.increasing_map(x, sympy.functions.elementary.exponential.exp)
|
||||
|
||||
@staticmethod
|
||||
def log(x):
|
||||
return ValueRanges.increasing_map(
|
||||
x, lambda y: -math.inf if y <= 0 else sympy.log(y)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sqrt(x):
|
||||
return ValueRanges.increasing_map(x, sympy.sqrt)
|
||||
|
||||
@staticmethod
|
||||
def pow(a, b):
|
||||
def is_integer(val):
|
||||
return (
|
||||
isinstance(val, int)
|
||||
or (isinstance(val, float) and val == int(val))
|
||||
or (hasattr(val, "is_integer") and val.is_integer)
|
||||
)
|
||||
|
||||
a = ValueRanges.wrap(a)
|
||||
b = ValueRanges.wrap(b)
|
||||
if a.lower < 0 and not is_integer(b.lower):
|
||||
# The function is not defined
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
elif 0 in a and b.lower <= 0:
|
||||
return ValueRanges(-math.inf, math.inf)
|
||||
return ValueRanges.coordinatewise_monotone_map(a, b, operator.pow)
|
||||
|
||||
@staticmethod
|
||||
def minimum(a, b):
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, min)
|
||||
|
||||
@staticmethod
|
||||
def maximum(a, b):
|
||||
return ValueRanges.coordinatewise_increasing_map(a, b, max)
|
||||
|
||||
@staticmethod
|
||||
def where(a, b, c):
|
||||
b = ValueRanges.wrap(b)
|
||||
c = ValueRanges.wrap(c)
|
||||
return ValueRanges(min(b.lower, c.lower), max(b.upper, c.upper))
|
||||
|
||||
@staticmethod
|
||||
def floor(x):
|
||||
return ValueRangeAnalysis.floor_ceil(
|
||||
x, sympy.functions.elementary.integers.floor
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def ceil(x):
|
||||
return ValueRangeAnalysis.floor_ceil(
|
||||
x, sympy.functions.elementary.integers.ceiling
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def floor_ceil(x, fn_int):
|
||||
def is_integer(val):
|
||||
return isinstance(val, int) or (
|
||||
hasattr(val, "is_integer") and val.is_integer
|
||||
)
|
||||
|
||||
if is_integer(x):
|
||||
fn = fn_int
|
||||
else:
|
||||
|
||||
def fn(x):
|
||||
return sympy.Float(fn_int(x))
|
||||
|
||||
return ValueRanges.increasing_map(x, fn)
|
||||
|
||||
def __getattr__(self, name):
|
||||
developer_warning(f"unhandled ValueRange op {name}")
|
||||
return self.default_handler
|
||||
|
||||
|
||||
def dominated_nodes(
|
||||
initial_queue: Union[torch.fx.Node, Iterable[torch.fx.Node]], skip_filter=None
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user