diff --git a/mypy.ini b/mypy.ini index c13b026e2ba1..c306acbd944a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ # test_run_mypy in test/test_type_hints.py uses this string) [mypy] -plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin +plugins = mypy_plugins/check_mypy_version.py, mypy_plugins/sympy_mypy_plugin.py, numpy.typing.mypy_plugin cache_dir = .mypy_cache/normal allow_redefinition = True diff --git a/mypy_plugins/sympy_mypy_plugin.py b/mypy_plugins/sympy_mypy_plugin.py new file mode 100644 index 000000000000..b2ffce0f29d1 --- /dev/null +++ b/mypy_plugins/sympy_mypy_plugin.py @@ -0,0 +1,59 @@ +from mypy.plugin import Plugin +from mypy.plugins.common import add_attribute_to_class +from mypy.types import NoneType, UnionType + + +class SympyPlugin(Plugin): + def get_base_class_hook(self, fullname: str): + if fullname == "sympy.core.basic.Basic": + return add_assumptions + return None + + +def add_assumptions(ctx) -> None: + # Generated by list(sys.modules['sympy.core.assumptions']._assume_defined) + # (do not import sympy to speedup mypy plugin load time) + assumptions = [ + "hermitian", + "prime", + "noninteger", + "negative", + "antihermitian", + "infinite", + "finite", + "irrational", + "extended_positive", + "nonpositive", + "odd", + "algebraic", + "integer", + "rational", + "extended_real", + "nonnegative", + "transcendental", + "extended_nonzero", + "extended_negative", + "composite", + "complex", + "imaginary", + "nonzero", + "zero", + "even", + "positive", + "polar", + "extended_nonpositive", + "extended_nonnegative", + "real", + "commutative", + ] + for a in assumptions: + add_attribute_to_class( + ctx.api, + ctx.cls, + f"is_{a}", + UnionType([ctx.api.named_type("builtins.bool"), NoneType()]), + ) + + +def plugin(version: str): + return SympyPlugin diff --git a/test/test_typing.py b/test/test_typing.py index ba542e18f0d5..3793700a5c79 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -67,7 +67,7 @@ def _run_mypy() -> Dict[str, List[str]]: directory, ] ) - assert not stderr, directory + assert not stderr, stderr stdout = stdout.replace("*", "") # Parse the output diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index abbbdf3b6947..e37a5799ca0e 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -8,7 +8,7 @@ import operator import math import logging import torch -from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, cast, Union +from typing import Dict, Optional, SupportsFloat, TypeVar, Generic, Union, overload, Callable, TYPE_CHECKING from typing_extensions import TypeGuard from torch._prims_common import dtype_to_type @@ -70,8 +70,25 @@ def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]: return not vr.is_bool +ExprIn = Union[int, float, sympy.Expr] +BoolIn = Union[bool, SympyBoolean] +AllIn = Union[ExprIn, BoolIn] +ExprFn = Callable[[sympy.Expr], sympy.Expr] +ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr] +BoolFn = Callable[[SympyBoolean], SympyBoolean] +BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean] +AllFn = Union[ExprFn, BoolFn] +AllFn2 = Union[ExprFn2, BoolFn2] + + @dataclasses.dataclass(frozen=True) class ValueRanges(Generic[_T]): + if TYPE_CHECKING: + # ruff doesn't understand circular references but mypy does + ExprVR = ValueRanges[sympy.Expr] # noqa: F821 + BoolVR = ValueRanges[SympyBoolean] # noqa: F821 + AllVR = Union[ExprVR, BoolVR] + # Although the type signature here suggests you can pass any # sympy expression, in practice the analysis here only works # with constant sympy expressions @@ -79,7 +96,15 @@ class ValueRanges(Generic[_T]): upper: _T is_bool: bool - def __init__(self, lower: Union[_T, bool, int, float], upper: Union[_T, bool, int, float]) -> None: + @overload + def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: + ... + + @overload + def __init__(self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn) -> None: + ... + + def __init__(self, lower: AllIn, upper: AllIn) -> None: lower = simple_sympify(lower) upper = simple_sympify(upper) # TODO: when the bounds have free variables, this may be @@ -92,15 +117,15 @@ class ValueRanges(Generic[_T]): object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) assert isinstance(upper, SympyBoolean) == self.is_bool - def boolify(self): - if self.is_bool: + def boolify(self) -> ValueRanges[SympyBoolean]: + if vr_is_bool(self): return self elif self == ValueRanges.unknown(): return ValueRanges.unknown_bool() else: raise AssertionError(f"not bool like {self}") - def __contains__(self, x): + def __contains__(self, x: AllIn) -> bool: x = simple_sympify(x) return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper) @@ -109,30 +134,42 @@ class ValueRanges(Generic[_T]): return self & other # Intersection - def __and__(self: ValueRanges[_T], other: ValueRanges[_T]) -> ValueRanges[_T]: + @overload + def __and__(self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]) -> ValueRanges[sympy.Expr]: + ... + + @overload + def __and__(self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]) -> ValueRanges[SympyBoolean]: + ... + + def __and__(self: AllVR, other: AllVR) -> AllVR: if other == ValueRanges.unknown(): return self if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) - if vr_is_bool(self): - return cast(ValueRanges[_T], ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper))) - elif vr_is_expr(self): - return cast(ValueRanges[_T], ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper))) + if self.is_bool: + return ValueRanges(sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)) else: - raise AssertionError("impossible") + return ValueRanges(sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)) # Union - def __or__(self, other) -> ValueRanges: + @overload + def __or__(self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr]) -> ValueRanges[sympy.Expr]: + ... + + @overload + def __or__(self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean]) -> ValueRanges[SympyBoolean]: + ... + + def __or__(self: AllVR, other: AllVR) -> AllVR: if ValueRanges.unknown() in (self, other): return ValueRanges.unknown() assert self.is_bool == other.is_bool, (self, other) - if vr_is_bool(self): - return cast(ValueRanges[_T], ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper))) - elif vr_is_expr(self): - return cast(ValueRanges[_T], ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper))) + if self.is_bool: + return ValueRanges(sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)) else: - raise AssertionError("impossible") + return ValueRanges(sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)) def is_singleton(self) -> bool: return self.lower == self.upper @@ -146,43 +183,76 @@ class ValueRanges(Generic[_T]): def unknown_bool() -> ValueRanges[SympyBoolean]: return ValueRanges(sympy.false, sympy.true) - @classmethod - def wrap(cls, arg): + @overload + @staticmethod + # work around the fact that bool and int overlap + def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap] + ... + + @overload + @staticmethod + def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR: + ... + + @staticmethod + def wrap(arg: Union[AllIn, AllVR]) -> AllVR: if isinstance(arg, ValueRanges): return arg - return ValueRanges(arg, arg) + # arg is either ExprIn or BoolIn, but we don't know it here + return ValueRanges(arg, arg) # type: ignore[arg-type] - @classmethod - def increasing_map(cls, x, fn): + @staticmethod + def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: """Increasing: x <= y => f(x) <= f(y).""" - x = cls.wrap(x) + x = ValueRanges.wrap(x) return ValueRanges(fn(x.lower), fn(x.upper)) - @classmethod - def decreasing_map(cls, x, fn): - """Decreasing: x <= y => f(x) >= f(y).""" - x = cls.wrap(x) - return ValueRanges(fn(x.upper), fn(x.lower)) + @overload + @staticmethod + def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + ... - @classmethod - def monotone_map(cls, x, fn): + @overload + @staticmethod + def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR: + ... + + @staticmethod + def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR: + """Decreasing: x <= y => f(x) >= f(y).""" + x = ValueRanges.wrap(x) + # consistently either Expr or Bool, but we don't know it here + return ValueRanges(fn(x.upper), fn(x.lower)) # type: ignore[arg-type] + + @staticmethod + def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: """It's increasing or decreasing.""" - x = cls.wrap(x) + x = ValueRanges.wrap(x) l = fn(x.lower) u = fn(x.upper) return ValueRanges(min(l, u), max(l, u)) - @classmethod - def convex_min_zero_map(cls, x, fn): + @staticmethod + def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: """Fn is convex and has a minimum at 0.""" x = ValueRanges.wrap(x) if 0 in x: return ValueRanges(0, max(fn(x.lower), fn(x.upper))) else: - return cls.monotone_map(x, fn) + return ValueRanges.monotone_map(x, fn) - @classmethod - def coordinatewise_increasing_map(cls, x, y, fn): + @overload + @staticmethod + def coordinatewise_increasing_map(x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2) -> ExprVR: + ... + + @overload + @staticmethod + def coordinatewise_increasing_map(x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2) -> BoolVR: + ... + + @staticmethod + def coordinatewise_increasing_map(x: Union[AllIn, AllVR], y: Union[AllIn, AllVR], fn: AllFn2) -> AllVR: """ It's increasing on each coordinate. @@ -190,10 +260,10 @@ class ValueRanges(Generic[_T]): For every 1 <= i <= n and x_i <= y_i we have that f(x1, .., xn) <= f(x1, , yi, ..., xn) """ - x, y = cls.wrap(x), cls.wrap(y) + x, y = ValueRanges.wrap(x), ValueRanges.wrap(y) return ValueRanges( - fn(x.lower, y.lower), - fn(x.upper, y.upper), + fn(x.lower, y.lower), # type: ignore[arg-type] + fn(x.upper, y.upper), # type: ignore[arg-type] ) @classmethod @@ -450,7 +520,7 @@ class SymPyValueRangeAnalysis: b = ValueRanges.wrap(b) # Performs upcasting first - def fn_(x, y): + def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: # Poorman's version of upcasting in Sympy # Inf is not a float... if x.is_Integer and y.is_Integer: