diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 7eedf77f6e09..e8c422952e85 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -46,13 +46,18 @@ if [[ "\$python_nodot" = *310* ]]; then PROTOBUF_PACKAGE="protobuf>=3.19.0" fi -if [[ "\$python_nodot" = *39* ]]; then +if [[ "\$python_nodot" = *39* ]]; then # There's an issue with conda channel priority where it'll randomly pick 1.19 over 1.20 # we set a lower boundary here just to be safe NUMPY_PIN=">=1.20" fi - +if [[ "\$python_nodot" = *38* ]]; then + # sympy 1.12.1 is the last version that supports Python 3.8 + SYMPY_PIN="==1.12.1" +else + SYMPY_PIN=">=1.13.0" +fi # Move debug wheels out of the package dir so they don't get installed mkdir -p /tmp/debug_final_pkgs @@ -83,7 +88,7 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then "numpy\${NUMPY_PIN}" \ mkl>=2018 \ ninja \ - sympy \ + "sympy\${SYMPY_PIN}" \ typing-extensions \ ${PROTOBUF_PACKAGE} if [[ "$DESIRED_CUDA" == 'cpu' ]]; then diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index f0e4890328b3..beac4838231e 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -17,11 +17,11 @@ pytest-xdist==3.3.1 pytest-rerunfailures==10.3 pytest-flakefinder==1.1.0 scipy==1.10.1 -sympy==1.11.1 +sympy==1.12.1 ; python_version == "3.8" +sympy>=1.13.0 ; python_version >= "3.9" unittest-xml-reporting<=3.2.0,>=2.0.0 xdoctest==1.1.0 filelock==3.6.0 -sympy==1.11.1 pytest-cpp==2.3.0 rockset==1.0.3 z3-solver==4.12.2.0 diff --git a/.lintrunner.toml b/.lintrunner.toml index 5a661d383542..b1ecad45fdf8 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -139,7 +139,8 @@ init_command = [ 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.1.6', 'mypy==1.10.0', - 'sympy==1.11.1', + 'sympy==1.12.1 ; python_version == "3.8"', + 'sympy==1.13.0 ; python_version >= "3.9"', 'types-requests==2.27.25', 'types-PyYAML==6.0.7', 'types-tabulate==0.8.8', diff --git a/requirements.txt b/requirements.txt index 95a30dd6e599..26f6305236ad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,8 @@ requests setuptools types-dataclasses typing-extensions>=4.8.0 -sympy +sympy==1.12.1 ; python_version == "3.8" +sympy>=1.13.0 ; python_version >= "3.9" filelock networkx jinja2 diff --git a/setup.py b/setup.py index dec89e0a813b..927596da47a5 100644 --- a/setup.py +++ b/setup.py @@ -1138,7 +1138,8 @@ def main(): install_requires = [ "filelock", "typing-extensions>=4.8.0", - "sympy", + 'sympy==1.12.1 ; python_version == "3.8"', + 'sympy>=1.13.0 ; python_version >= "3.9"', "networkx", "jinja2", "fsspec", diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 723c240c40bf..a11ce168631c 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -29,6 +29,7 @@ import torch import torch.fx from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.utils import _pytree as pytree +from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges @@ -1855,13 +1856,13 @@ class Kernel(CodeGen): # Take the negative part of the bound and add size to it # Then take union of that and the positive part # This is a tighter bound than that of a generic ops.where, as we have info on the cond - neg_bounds = var.bounds & ValueRanges(-sympy.oo, -1) + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) new_bounds = ValueRanges( neg_bounds.lower + size, neg_bounds.upper + size ) # We don't have a good way of representing the empty range if var.bounds.upper >= 0: # type: ignore[operator] - pos = var.bounds & ValueRanges(0, sympy.oo) + pos = var.bounds & ValueRanges(0, int_oo) new_bounds = new_bounds | pos var = self.cse.generate(self.compute, stm, bounds=new_bounds) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index adb4208f07fa..756e37b84607 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -659,7 +659,7 @@ def pformat(obj: Any) -> str: obj = sorted(obj, key=str) result = pprint.pformat(obj, indent=4) if "\n" in result: - return f"\n{textwrap.indent(result, ' '*4)}" + return f"\n{textwrap.indent(result, ' ' * 4)}" return result @@ -1675,7 +1675,7 @@ class Scheduler: NodeUser(user_node, can_inplace, is_weak) ) - unbacked_symbol_to_origin_node = {} + unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {} # NB: None means that the dependency is on an input. Don't actually # generate a dependency because if we do, Inductor will start trying diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 0d4f9cd6203c..3d04b35b0b80 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -6,6 +6,7 @@ import sys import sympy from sympy import S +from sympy.core.numbers import equal_valued from .numbers import int_oo @@ -117,9 +118,9 @@ class FloorDiv(sympy.Function): if base.is_zero: return sympy.S.Zero - if base.is_integer and divisor == 1: + if base.is_integer and equal_valued(divisor, 1): return base - if base.is_integer and divisor == -1: + if base.is_integer and equal_valued(divisor, -1): return sympy.Mul(base, -1) if ( isinstance(base, sympy.Number) @@ -155,7 +156,7 @@ class FloorDiv(sympy.Function): try: gcd = sympy.gcd(base, divisor) - if gcd != 1: + if not equal_valued(gcd, 1): return FloorDiv( sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) ) diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py index 6a93255df852..d02b9879cad2 100644 --- a/torch/utils/_sympy/numbers.py +++ b/torch/utils/_sympy/numbers.py @@ -60,8 +60,8 @@ class IntInfinity(Number, metaclass=Singleton): @_sympifyit("other", NotImplemented) def __add__(self, other): if isinstance(other, Number) and global_parameters.evaluate: - if other is S.NegativeInfinity: - return S.NegativeInfinity + if other in (S.Infinity, S.NegativeInfinity): + return other if other in (S.NegativeIntInfinity, S.NaN): return S.NaN return self @@ -74,6 +74,8 @@ class IntInfinity(Number, metaclass=Singleton): if isinstance(other, Number) and global_parameters.evaluate: if other is S.Infinity: return S.NegativeInfinity + if other is S.NegativeInfinity: + return S.Infinity if other in (S.IntInfinity, S.NaN): return S.NaN return self diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 759f4e108337..4a01d8e53b91 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -136,11 +136,19 @@ class ValueRanges(Generic[_T]): return f"VR[{self.lower}, {self.upper}]" @overload - def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: + def __init__( + self: ValueRanges[sympy.Expr], + lower: ExprIn, + upper: ExprIn, + ) -> None: ... @overload - def __init__(self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn) -> None: + def __init__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + lower: BoolIn, + upper: BoolIn, + ) -> None: ... def __init__(self, lower: AllIn, upper: AllIn) -> None: @@ -153,13 +161,10 @@ class ValueRanges(Generic[_T]): raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]") except TypeError as e: raise TypeError(f"Could not compare {lower} <= {upper}") from e - # Because this is a frozen class - object.__setattr__(self, "lower", lower) - object.__setattr__(self, "upper", upper) - # Unlike bool/int in Python, we don't report bools are ints - object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - if self.is_bool: - assert isinstance(upper, SympyBoolean), (lower, upper) + + is_bool_lower = isinstance(lower, SympyBoolean) + is_bool_upper = isinstance(upper, SympyBoolean) + assert is_bool_lower == is_bool_upper, (lower, upper) # Warning: is_int/is_float is best effort. We do pretty well in # Dynamo, but in Inductor these attributes are often wrong because we @@ -167,15 +172,26 @@ class ValueRanges(Generic[_T]): # the flexible analysis for is_int: sometimes a sympy.oo pops in for # an integer bound. I would /like/ for us not to do this, but it's # too hard to push the invariant through right now. + if isinstance(lower, sympy.Integer) and upper == sympy.oo: + upper = int_oo + if isinstance(upper, sympy.Integer) and lower == -sympy.oo: + lower = -int_oo + # NB: [-int_oo, -int_oo] and [int_oo, int_oo] are allowed + integer_types = (sympy.Integer, NegativeIntInfinity, IntInfinity) + is_int_lower = isinstance(lower, integer_types) + is_int_upper = isinstance(upper, integer_types) + # Because this is a frozen class + object.__setattr__(self, "lower", lower) + object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints + # + # NB: is_bool_lower == is_bool_upper, so we only need to check one + object.__setattr__(self, "is_bool", is_bool_lower) object.__setattr__( self, "is_int", - not self.is_bool - and ( - isinstance(lower, (sympy.Integer, NegativeIntInfinity)) - or isinstance(upper, (sympy.Integer, IntInfinity)) - ), + not self.is_bool and is_int_lower and is_int_upper, ) """ # This assert is just impossible right now, too many sympy bugs @@ -216,13 +232,15 @@ class ValueRanges(Generic[_T]): # Intersection @overload def __and__( - self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr] + self: ValueRanges[sympy.Expr], + other: ValueRanges[sympy.Expr], ) -> ValueRanges[sympy.Expr]: ... @overload - def __and__( - self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean] + def __and__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + other: ValueRanges[SympyBoolean], ) -> ValueRanges[SympyBoolean]: ... @@ -246,13 +264,15 @@ class ValueRanges(Generic[_T]): # Union @overload def __or__( - self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr] + self: ValueRanges[sympy.Expr], + other: ValueRanges[sympy.Expr], ) -> ValueRanges[sympy.Expr]: ... @overload - def __or__( - self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean] + def __or__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + other: ValueRanges[SympyBoolean], ) -> ValueRanges[SympyBoolean]: ... @@ -260,6 +280,8 @@ class ValueRanges(Generic[_T]): if ValueRanges.unknown() in (self, other): return ValueRanges.unknown() assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper) @@ -292,7 +314,7 @@ class ValueRanges(Generic[_T]): @overload @staticmethod - def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR: + def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR: # type: ignore[misc] ... @staticmethod @@ -317,7 +339,7 @@ class ValueRanges(Generic[_T]): @overload @staticmethod - def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR: + def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR: # type: ignore[misc] ... @staticmethod @@ -340,27 +362,36 @@ class ValueRanges(Generic[_T]): """Fn is convex and has a minimum at 0.""" x = ValueRanges.wrap(x) if 0 in x: - return ValueRanges(0, max(fn(x.lower), fn(x.upper))) - else: - return ValueRanges.monotone_map(x, fn) + upper = max(fn(x.lower), fn(x.upper)) + upper = simple_sympify(upper) + if isinstance(upper, sympy.Float) or upper == sympy.oo: + return ValueRanges(0.0, upper) + return ValueRanges(0, upper) + return ValueRanges.monotone_map(x, fn) @overload @staticmethod def coordinatewise_increasing_map( - x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2 + x: Union[ExprIn, ExprVR], + y: Union[ExprIn, ExprVR], + fn: ExprFn2, ) -> ExprVR: ... @overload @staticmethod - def coordinatewise_increasing_map( - x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2 + def coordinatewise_increasing_map( # type: ignore[misc] + x: Union[BoolIn, BoolVR], + y: Union[BoolIn, BoolVR], + fn: BoolFn2, ) -> BoolVR: ... @staticmethod def coordinatewise_increasing_map( - x: Union[AllIn, AllVR], y: Union[AllIn, AllVR], fn: AllFn2 + x: Union[AllIn, AllVR], + y: Union[AllIn, AllVR], + fn: AllFn2, ) -> AllVR: """ It's increasing on each coordinate. @@ -1035,20 +1066,20 @@ def bound_sympy( if unbounded_vars: # Give some bounds to the free variables via their SymPy assumptions # TODO A better way of doing this would be to assign them a range upon creation, as - # size variables can come with a lower bound of 2, as we specialise on 0 and 1 + # size variables can come with a lower bound of 2, as we specialize on 0 and 1 unbounded_ranges: Dict[sympy.Symbol, ValueRanges] = {} for s in unbounded_vars: if s.is_integer: # type: ignore[attr-defined] if s.is_positive: # type: ignore[attr-defined] - lower = 1 + vr = ValueRanges(1, int_oo) elif s.is_nonnegative: # type: ignore[attr-defined] - lower = 0 + vr = ValueRanges(0, int_oo) else: - lower = -math.inf # type: ignore[assignment] + vr = ValueRanges.unknown_int() else: # Don't bother trying very hard here - lower = -math.inf # type: ignore[assignment] - unbounded_ranges[s] = ValueRanges(lower, math.inf) # type: ignore[index] + vr = ValueRanges.unknown() + unbounded_ranges[s] = vr # type: ignore[index] ranges = {**ranges, **unbounded_ranges} return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)