From cc8f1cddd4bb7c8e59d6bb11f9777652547dc32f Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 30 Sep 2024 18:21:41 -0700 Subject: [PATCH] Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972 Approved by: https://github.com/Skylion007 ghstack dependencies: #136934, #136935 --- mypy_plugins/sympy_mypy_plugin.py | 8 + torch/_dynamo/output_graph.py | 3 +- torch/_guards.py | 4 +- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/graph.py | 6 +- torch/_inductor/ir.py | 6 +- torch/_inductor/lowering.py | 1 + torch/fx/experimental/symbolic_shapes.py | 453 ++++++++++++++--------- torch/fx/passes/runtime_assert.py | 9 +- torch/utils/_python_dispatch.py | 11 + torch/utils/_sympy/solve.py | 2 +- 11 files changed, 318 insertions(+), 189 deletions(-) diff --git a/mypy_plugins/sympy_mypy_plugin.py b/mypy_plugins/sympy_mypy_plugin.py index b2ffce0f29d1..9432963ad8f1 100644 --- a/mypy_plugins/sympy_mypy_plugin.py +++ b/mypy_plugins/sympy_mypy_plugin.py @@ -5,10 +5,18 @@ from mypy.types import NoneType, UnionType class SympyPlugin(Plugin): def get_base_class_hook(self, fullname: str): + # TODO: This apparently never worked if fullname == "sympy.core.basic.Basic": return add_assumptions return None + def get_attribute_hook(self, fullname: str): + if fullname == "sympy.core.basic.Basic.free_symbols": + return lambda ctx: ctx.api.named_generic_type( + "builtins.set", [ctx.api.named_type("sympy.Symbol")] + ) + return None + def add_assumptions(ctx) -> None: # Generated by list(sys.modules['sympy.core.assumptions']._assume_defined) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ed429f7cab00..33fb49192fcb 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -660,7 +660,7 @@ class OutputGraph: assert arg.fake_tensor is not None - def bind_symint(s, prop): + def bind_symint(s: torch.SymInt, 
prop): if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)): return s0 = s.node.expr @@ -677,6 +677,7 @@ class OutputGraph: source=prop, ) set_example_value(proxy.node, s) + assert isinstance(s, torch.SymInt) proxy.node.meta["grapharg"] = GraphArg( prop, s, diff --git a/torch/_guards.py b/torch/_guards.py index 3dcb3b13e0a1..f6bd852d26d4 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -154,7 +154,7 @@ class GuardBuilderBase: @dataclasses.dataclass(frozen=True) class SLoc: - framework_loc: Union[traceback.FrameSummary, str] + framework_loc: Optional[Union[traceback.FrameSummary, str]] maybe_user_loc: Optional[str] def __str__(self): @@ -170,7 +170,7 @@ class SLoc: class ShapeGuard(NamedTuple): - expr: sympy.Expr + expr: sympy.logic.boolalg.Boolean sloc: SLoc diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index b25875f56932..b63594f202c5 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2044,7 +2044,9 @@ class PythonWrapperCodegen(CodeGen): if isinstance(x, int): return x val = V.graph._shape_env._maybe_evaluate_static(x) - return int(val) + if val is None: + return val + return int(val) # type: ignore[call-overload] except Exception: return None diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 75159ef2c6a6..9e3a5d383c6d 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -45,6 +45,7 @@ from torch.fx.experimental.symbolic_shapes import ( resolve_unbacked_bindings, RuntimeAssert, ShapeEnv, + SympyBoolean, SymTypes, ) from torch.fx.graph import Graph @@ -352,7 +353,7 @@ class GraphLowering(torch.fx.Interpreter): shape_env.freeze_runtime_asserts() # We're going to mutate ras_by_symbol as we finish generating them self.ras_by_symbol: Dict[ - sympy.Symbol, List[RuntimeAssert] + Optional[sympy.Symbol], List[RuntimeAssert] ] = shape_env.deferred_runtime_asserts.copy() self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = 
OrderedSet() self.sizevars = SizeVarAllocator(shape_env) @@ -1595,7 +1596,7 @@ class GraphLowering(torch.fx.Interpreter): # This is all doable, it just hasn't been done yet. shape_env = V.graph.sizevars.shape_env - def make_assert(expr: Expr, msg: str) -> None: + def make_assert(expr: SympyBoolean, msg: str) -> None: assert_op = ir.AssertScalar(expr, msg) self.register_buffer(assert_op, set_name=True) self.register_operation(assert_op) @@ -1634,6 +1635,7 @@ class GraphLowering(torch.fx.Interpreter): unbacked_bindings = resolve_unbacked_bindings( V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {}) ) + assert unbacked_bindings is not None # When we do lowering, it is possible we reallocate unbacked SymInts. # So we need to line up the unbacked SymInts when performing the test # here diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0c704092ed3d..9d9bea7a6fe9 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5913,9 +5913,11 @@ class FallbackKernel(ExternKernelAlloc): def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: if unbacked_bindings := getattr(self, "unbacked_bindings", None): - return resolve_unbacked_bindings( + resolved = resolve_unbacked_bindings( V.graph.sizevars.shape_env, unbacked_bindings - ).keys() + ) + assert resolved is not None + return resolved.keys() # type: ignore[return-value] else: return OrderedSet() diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index b0d02b6b0f65..1c20373a5cf7 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2676,6 +2676,7 @@ def _local_scalar_dense(data): unbacked_bindings = resolve_unbacked_bindings( V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] ) + assert unbacked_bindings is not None assert len(unbacked_bindings) == 1, unbacked_bindings # NB: Have to be very careful here. 
V.graph.current_node.meta["val"] # seemingly also contains a symbol which you want to do binding for, diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 007cf72ce0d7..61d400e15c5a 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1,4 +1,7 @@ -# mypy: ignore-errors +# mypy: allow-untyped-defs + +from __future__ import annotations + """ ``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with @@ -8,7 +11,6 @@ as well as extensions to PyTorch (e.g., in custom operator implementations), you need to make use of these APIs to setup dynamic shapes support appropriately. """ - import atexit import builtins import collections @@ -31,8 +33,10 @@ from typing import ( Any, Callable, cast, + Counter, + DefaultDict, Dict, - Iterable, + Iterator, List, Optional, Sequence, @@ -40,9 +44,10 @@ from typing import ( Tuple, Type, TYPE_CHECKING, + TypeVar, Union, ) -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, TypeGuard import torch import torch.fx @@ -89,7 +94,8 @@ from torch.utils._traceback import CapturedTraceback, format_frame if TYPE_CHECKING: - from torch._dynamo.source import TensorPropertySource + import types + InputList = List DimList = List @@ -97,14 +103,15 @@ DimList = List log = logging.getLogger(__name__) import sympy +from sympy import S from sympy.printing.precedence import PRECEDENCE, precedence from sympy.printing.str import StrPrinter class GuardOnDataDependentSymNode(RuntimeError): - cond: sympy.Expr + cond: sympy.Basic - def __init__(self, cond, *args): + def __init__(self, cond: sympy.Basic, *args: Any) -> None: super().__init__(*args) self.cond = cond @@ -154,14 +161,19 @@ SHAPEENV_EVENT_KEY = "shapeenv_event" CURRENT_NODE_KEY = "current_node" -def log_lru_cache_stats(wrapped_f): +def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None: log.debug( - "lru_cache_stats %s: 
%s", wrapped_f.__name__, wrapped_f.cumulative_cache_info() + "lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info() # type: ignore[attr-defined] ) +_T = TypeVar("_T") + + # Wrapper on lru_cache that reports statistics at process end -def lru_cache(maxsize): +def lru_cache( + maxsize: Optional[int], +) -> Callable[[Callable[..., _T]], functools._lru_cache_wrapper[_T]]: def inner(f): wrapped_f = functools.lru_cache(maxsize)(f) old_cache_clear = wrapped_f.cache_clear @@ -172,7 +184,7 @@ def lru_cache(maxsize): # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not # weakref'able on some versions of Python - def cumulative_cache_info(): + def cumulative_cache_info() -> functools._CacheInfo: cur = wrapped_f.cache_info() return functools._CacheInfo( prev_hits + cur.hits, @@ -181,15 +193,15 @@ def lru_cache(maxsize): cur.currsize, ) - def new_cache_clear(): + def new_cache_clear() -> None: nonlocal prev_hits, prev_misses cur = wrapped_f.cache_info() prev_hits += cur.hits prev_misses += cur.misses old_cache_clear() - wrapped_f.cache_clear = new_cache_clear - wrapped_f.cumulative_cache_info = cumulative_cache_info + wrapped_f.cache_clear = new_cache_clear # type: ignore[attr-defined, method-assign] + wrapped_f.cumulative_cache_info = cumulative_cache_info # type: ignore[attr-defined, method-assign] if log.isEnabledFor(logging.DEBUG): atexit.register(log_lru_cache_stats, wrapped_f) return wrapped_f @@ -231,17 +243,17 @@ class ConstraintViolationError(RuntimeError): pass -def has_symbolic_sizes_strides(elem) -> bool: +def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool: return elem._has_symbolic_sizes_strides -Int = Union[torch.SymInt, int] +Int: TypeAlias = Union[torch.SymInt, int] def create_contiguous(shape: Sequence[Int]) -> List[Int]: strides: List[Int] = [1] for dim in reversed(shape[:-1]): - strides.append(dim * strides[-1]) + strides.append(dim * strides[-1]) # type: ignore[operator] return list(reversed(strides)) @@ 
-257,7 +269,7 @@ def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int return a -Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] +Scalar: TypeAlias = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] def has_hint(a: Scalar) -> bool: @@ -285,6 +297,17 @@ def is_concrete_int(a: Union[int, SymInt]) -> bool: return False +# Note about Sympy Expr/SympyBoolean/Basic typing: the Sympy hierarchy is +# +# Basic +# Expr +# SympyBoolean +# Relational +# +# Notably, Expr and SympyBoolean are not related. So use Basic when the +# expression could denote int, float OR bool, and otherwise use the more +# specific Expr for int/float and SympyBoolean for bool. +# # In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime. # So make sure only type checker evaluates this alias. # Xref: https://www.internalfb.com/diff/D53324783 @@ -336,13 +359,22 @@ def check_consistent(new, old) -> None: torch._check(old == new, lambda: f"{old} != {new} (old != new)") -def resolve_unbacked_bindings(shape_env, bindings): +def resolve_unbacked_bindings( + shape_env: Optional[ShapeEnv], + bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]], +) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]: if bindings is None: return None + assert shape_env is not None return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()} -def rebind_unbacked(shape_env, n: torch.fx.Node, result): +Result: TypeAlias = Union[torch.Tensor, Tuple[torch.Tensor, ...]] + + +def rebind_unbacked( + shape_env: Optional[ShapeEnv], n: torch.fx.Node, result: Result +) -> None: """ Suppose we are retracing a pre-existing FX graph that previously had fake tensor propagation (and therefore unbacked SymInts). 
When we retrace, @@ -361,6 +393,7 @@ def rebind_unbacked(shape_env, n: torch.fx.Node, result): if bindings := resolve_unbacked_bindings( shape_env, n.meta.get("unbacked_bindings") ): + assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) # tensor_version ops get specialized after AOTAutograd, it's OK, @@ -379,12 +412,17 @@ def rebind_unbacked(shape_env, n: torch.fx.Node, result): if ( isinstance(raw_u1, sympy.Piecewise) and len(raw_u1.args) == 2 - and raw_u1.args[0][0] == 1 - and isinstance(eq := raw_u1.args[0][1], sympy.Eq) + and ( + raw_u1_args0 := cast( + Tuple[sympy.Basic, sympy.Basic], raw_u1.args[0] + ) + ) + and raw_u1_args0[0] == 1 + and isinstance(eq := raw_u1_args0[1], sympy.Eq) and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) and eq.rhs == 1 - and raw_u1.args[1] == (0, True) + and cast(Tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True) ): # This is what the pattern match above is testing repacked = _sympy_cast_symbool_to_symint_guardless( @@ -410,6 +448,7 @@ def is_accessor_node(node: torch.fx.Node) -> bool: # Dynamo only exercised condition if ( node.op == "call_method" + and isinstance(node.args[0], torch.fx.Node) and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) and node.target in ["size", "stride", "storage_offset", "item"] ): @@ -429,7 +468,7 @@ def is_accessor_node(node: torch.fx.Node) -> bool: return False -def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: +def canonicalize_bool_expr(expr: _T) -> _T: r"""Canonicalize a boolean expression by transforming it into a lt / le inequality and moving all the non-constant terms to the rhs. 
We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr @@ -452,17 +491,17 @@ def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): expr = sympy.logic.boolalg.to_cnf(expr) - return _canonicalize_bool_expr_impl(expr) + return _canonicalize_bool_expr_impl(expr) # type: ignore[arg-type, return-value] def _sympy_from_args( - cls: type, + cls: Union[Type[sympy.Add], Type[sympy.Mul]], args: List[sympy.Expr], sort: bool = True, is_commutative: Optional[bool] = None, ) -> sympy.Expr: if not args: - return cls.identity + return cls.identity # type: ignore[union-attr] # These args are already in canonical form, so we avoid calling # Add(*args) to avoid expensive Add.flatten operation if sort: @@ -478,14 +517,14 @@ def _sympy_from_args( if args[0].is_Number: rest = args[1:] sort_fn(rest) - return cls._from_args([args[0]] + rest, is_commutative=is_commutative) + return cls._from_args([args[0]] + rest, is_commutative=is_commutative) # type: ignore[attr-defined] else: args = args.copy() sort_fn(args) - return cls._from_args(args, is_commutative=is_commutative) + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] else: # if the args are already sorted, we create directly - return cls._from_args(args, is_commutative=is_commutative) + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: @@ -497,9 +536,10 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: return type(expr)(*map(canonicalize_bool_expr, expr.args)) opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} + t: Union[Type[Any]] if isinstance(expr, tuple(opposite.keys())): - rhs = expr.lhs - expr.rhs - t = opposite[type(expr)] + rhs = expr.lhs - expr.rhs # type: ignore[attr-defined] + t = opposite[type(expr)] # type: ignore[index] else: assert isinstance(expr, (sympy.Lt, sympy.Le, 
sympy.Eq, sympy.Ne)) rhs = expr.rhs - expr.lhs @@ -510,7 +550,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: isinstance(t, sympy.Mul) and t.args[0].is_Number and t.args[0].is_negative ) - lhs = 0 + lhs = S.Zero rhs = _reduce_to_lowest_terms(rhs) if isinstance(rhs, sympy.Add): pos = [] @@ -526,7 +566,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: lhs = _sympy_from_args(sympy.Add, neg, sort=True, is_commutative=True) elif is_neg(rhs): # lhs == 0 - lhs, rhs = -rhs, 0 + lhs, rhs = -rhs, S.Zero # We don't have to evaluate here because lhs, rhs came from a Boolean # and it was already simplified return t(lhs, rhs, evaluate=False) @@ -572,7 +612,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative ) elif expr.is_Integer: - return sympy.One + return S.One elif expr.is_Mul: return div_by_factor(expr, integer_coefficient(expr)) return expr @@ -602,7 +642,15 @@ def is_nested_int(s): return isinstance(s, torch.SymInt) and s.node.is_nested_int() -def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]: +IterateExprsAtom: TypeAlias = Union[ + SymInt, SymFloat, SymBool, int, float, bool, sympy.Basic, torch.Tensor +] +IterateExprs: TypeAlias = Union[ + IterateExprsAtom, Tuple[IterateExprsAtom, ...], List[IterateExprsAtom] +] + + +def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: if isinstance(val, SymTypes): # This allow applies to the jagged layout NestedTensor case as # nested ints are not symbolic @@ -686,7 +734,7 @@ class ConvertIntKey: def __str__(self) -> str: return ".cast_symbool_to_symint_guardless()" - def get(self, b: bool) -> int: + def get(self, b: bool) -> Union[int, SymInt]: """Get the int value from bool""" return cast_symbool_to_symint_guardless(b) @@ -776,6 +824,9 @@ def compute_unbacked_bindings( ) ) elif isinstance(a, torch.Tensor): + from torch._subclasses.fake_tensor import FakeTensor + + 
assert isinstance(a, FakeTensor) r.update( free_unbacked_symbols_with_path( a.size(), @@ -1181,8 +1232,12 @@ def expect_true(a, skip: int = 0): # TODO: check perf implications of this frame = inspect.currentframe() for _ in range(skip + 1): # always run this loop at least once + if frame is None: + break frame = frame.f_back - return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno) + return a.node.expect_true( + frame.f_code.co_filename if frame else "", frame.f_lineno if frame else 0 + ) assert type(a) is bool, a return a @@ -1372,6 +1427,9 @@ class EqualityConstraint(Constraint): ] phantom_symbols: List[sympy.Symbol] + _parents: Dict[Source, Source] = field(init=False) + _defs: Dict[Source, sympy.Expr] = field(init=False) + def __post_init__(self): """Pre-processing to answer queries `is_equal` and `is_derived` below. @@ -1500,9 +1558,9 @@ class StatelessSymbolicContext(SymbolicContext): """ dynamic_sizes: DimList[DimDynamic] - dynamic_strides: DimList[DimDynamic] = None - constraint_sizes: DimList[DimConstraint] = None - constraint_strides: DimList[DimConstraint] = None + dynamic_strides: DimList[DimDynamic] = None # type: ignore[assignment] + constraint_sizes: DimList[DimConstraint] = None # type: ignore[assignment] + constraint_strides: DimList[DimConstraint] = None # type: ignore[assignment] # If the tensor is a view, this should be populated for the base. It contains # information on how to allocate symbols when recursively fakeifying the base # during view fake-ification. @@ -1567,7 +1625,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext): w/r/t different shape_envs, clearing, etc. """ - tensor_source: Source = None + tensor_source: Source = None # type: ignore[assignment] # Why is this keyd on int first? # That integer is actually the id of the shape_env. This cache short-circuits symbol # creation, and we must store it per shape env. 
Now, while tracing invariants are a single @@ -1577,9 +1635,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext): # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never # get recorded in var_to_val, etc. # TODO(voz): consider a weakref to the shape_env here - shape_env_to_source_to_symbol_cache: Dict[ - int, Dict["TensorPropertySource", "sympy.Expr"] - ] = None + shape_env_to_source_to_symbol_cache: Dict[int, Dict[str, sympy.Expr]] = None # type: ignore[assignment] def __post_init__(self): super().__post_init__() @@ -1597,7 +1653,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext): flexibility, with inner symbolic contexts mapped via attr -> symbolic context. """ - inner_contexts: Dict[str, SymbolicContext] = None + inner_contexts: Dict[str, SymbolicContext] = None # type: ignore[assignment] def __post_init__(self): super().__post_init__() @@ -1605,7 +1661,9 @@ class SubclassSymbolicContext(StatefulSymbolicContext): self.inner_contexts = {} -def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: +def is_symbolic( + val: Union[int, SymInt, float, SymFloat, bool, SymBool] +) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]: if isinstance(val, (int, float, bool)): return False return val.node.is_symbolic() @@ -1637,24 +1695,27 @@ def _fast_expand(expr: sympy.Expr) -> sympy.Expr: # such features here to avoid expensive checks. We also make sure that we # only re-create the objects if any of the args changed to avoid expensive # checks when re-creating objects. 
- new_args = [_fast_expand(arg) for arg in expr.args] + new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type] if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): return _fast_expand(expr.func(*new_args)) if expr.is_Pow: - base, exp = expr.args + base: sympy.Expr + exp: sympy.Expr + base, exp = expr.args # type: ignore[assignment] if exp.is_Integer and base.is_Add: if exp > 1: return sympy.expand_multinomial(expr, deep=False) elif exp < 0: return 1 / sympy.expand_multinomial(1 / expr, deep=False) elif expr.is_Mul: - num, den = [], [] + num: List[sympy.Expr] = [] + den: List[sympy.Expr] = [] for arg in expr.args: if arg.is_Pow and arg.args[1] == -1: - den.append(1 / arg) + den.append(1 / arg) # type: ignore[operator, arg-type] else: - num.append(arg) + num.append(arg) # type: ignore[arg-type] num, num_changed = _expandsums(num) den, den_changed = _expandsums(den) @@ -1748,6 +1809,7 @@ def _maybe_evaluate_static_worker( new_shape_env[k] = s + offset new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset) + # TODO: remove this try catch (esp for unbacked_only) try: new_expr = expr.xreplace(new_shape_env) except RecursionError: @@ -1806,11 +1868,13 @@ def _eval_is_non_overlapping_and_dense(sizes, strides): return True -def _sympy_cast_symbool_to_symint_guardless(x: sympy.Expr) -> sympy.Expr: +def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr: return sympy.Piecewise((1, x), (0, True)) -def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: +def cast_symbool_to_symint_guardless( + symbool: Union[bool, torch.SymBool] +) -> Union[int, torch.SymInt]: if isinstance(symbool, bool): return 1 if symbool else 0 int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr) @@ -1850,7 +1914,9 @@ SYMPY_INTERP = { } -def _lru_cache(fn, maxsize=None): +def _lru_cache( + fn: Callable[..., _T], maxsize: Optional[int] = None +) -> functools._lru_cache_wrapper[_T]: """ Wrapper around 
lru_cache that clears when new info about shapes has been updated. @@ -1898,9 +1964,9 @@ def _lru_cache(fn, maxsize=None): return fn_cache(self, *args, **kwargs) - wrapper.cache_clear = fn_cache.cache_clear + wrapper.cache_clear = fn_cache.cache_clear # type: ignore[attr-defined] wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] - return wrapper + return wrapper # type: ignore[return-value] # This is pretty similar to ShapeGuard but it also comes with a message, @@ -1909,9 +1975,9 @@ def _lru_cache(fn, maxsize=None): # a particular specialization) @dataclass(frozen=True) class RuntimeAssert: - expr: sympy.Expr + expr: SympyBoolean msg: str = field(repr=False) - stack: str = field(repr=False) + stack: CapturedTraceback = field(repr=False) # Used for printing SymExprs in compile_fx @@ -2022,7 +2088,7 @@ class DimConstraints: # We do so by using the values of variables as hints to evaluate %. # For soundness we record additional congruence guards and solve them separately. self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val - self._congruences: Set[sympy.Expr] = defaultdict(set) + self._congruences: DefaultDict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) # We do not try to (directly) solve inequalities with > 1 free variables. # NOTE: free variables in these inequalities cannot also be in _substitutions. @@ -2280,7 +2346,7 @@ class DimConstraints: f"{self._dcp.symbol_to_source[s][0].name()} == {val}" ) # add this as a substitution to simplify other constraints - self._substitutions[s] = val + self._substitutions[s] = val # type: ignore[assignment] # simplify multivariate inequalities: some of them will now become univariate! 
multivariate_inequalities = self._multivariate_inequalities @@ -2306,6 +2372,7 @@ class DimConstraints: self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)] r = try_solve(sympy.Eq(base, divisor * tmp), s) + assert r is not None self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1]))) # remaining symbols have only pure inequalities (no equalities) @@ -2506,19 +2573,19 @@ class DimConstraints: # this is now either 1) unchanged, 2) refined with a new range, # or 3) specialized to a concrete value modified_root_values: Dict[str, Dict[str, Any]] = {} - for root in modified_roots: + for mroot in modified_roots: swapped_root = True - if root in results: - c = results[root] + if mroot in results: + c = results[mroot] if ("min" in c or "max" in c) or isinstance( # range c["eq"], int ): # specialized # here, the original root is a root Dim or concrete value in results. # if it is a derived dim, it is swapped, and we handle that below. if not _check_same_range( - c, name_to_dim[root] + c, name_to_dim[mroot] ): # ignore if unchanged - modified_root_values[root] = c + modified_root_values[mroot] = c swapped_root = False if swapped_root: @@ -2531,20 +2598,20 @@ class DimConstraints: dim = name_to_dim[k] if ( dim.__class__.__name__ == "_DerivedDim" - and dim.root.__name__ == root + and dim.root.__name__ == mroot ): # only look for min/max root, otherwise root would have specialized if "min" in c or "max" in c: expr = sympy.sympify(k) s = next(iter(expr.free_symbols)) result = { - "min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type] - "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type] + "min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type, index] + "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index] } if not _check_same_range( - result, name_to_dim[root] + result, name_to_dim[mroot] # type: ignore[index] ): # ignore if unchanged - modified_root_values[root] = result + 
modified_root_values[mroot] = result # type: ignore[index] break # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4}) @@ -2582,7 +2649,7 @@ class DimConstraints: s = s.replace(k, v) if not inverse else s.replace(v, k) return s - results = defaultdict(dict) + results: DefaultDict[str, Dict[str, Any]] = defaultdict(dict) if dynamic_shapes is None: dynamic_shapes = {} @@ -2662,14 +2729,14 @@ class DimConstraints: others = [] # order results by source name - results = { + results2 = { k: results[k] for k in sorted( results.keys(), key=lambda x: transform(x, inverse=True), ) } - for k, c in results.items(): + for k, c in results2.items(): if "eq" in c: other = c["eq"] if isinstance(other, int): @@ -2690,7 +2757,7 @@ class DimConstraints: else: dims.append(f"{k} = Dim('{k}')") - # results will get filtered out if no new suggestions, + # results2 will get filtered out if no new suggestions, # this can happen if guards are too complex. # in that case don't suggest fix if dims or others: @@ -2890,7 +2957,7 @@ class ShapeEnv: self.size_like: Set[sympy.Symbol] = set() # Duck-shaping says that if two input tensors have the same size, # they get assigned the same symbolic variable - self.val_to_var: Dict[int, sympy.Expr] = {} + self.val_to_var: Dict[int, sympy.Symbol] = {} if specialize_zero_one: self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)} self.unbacked_symfloat_counter = itertools.count() @@ -2923,7 +2990,9 @@ class ShapeEnv: # to the next unbacked symbol to wait on, but if we choose the # latest key, an assert will only show up at the moment when # we can actually codegen it. 
- self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {} + self.deferred_runtime_asserts: Dict[ + Optional[sympy.Symbol], List[RuntimeAssert] + ] = {} # This exists so we can efficiently invalidate the cache (it's used as # part of the cache key); otherwise we'd have to iterate through # deferred_runtime_asserts to compute its length @@ -2933,10 +3002,10 @@ class ShapeEnv: self.frozen = False self.runtime_asserts_frozen = False self.dim_constraints: Optional[DimConstraints] = None - self.counter = collections.Counter() + self.counter: Counter[str] = collections.Counter() # Mapping from sympy.Symbol to the number of guards which mention this # symbol - self.symbol_guard_counter = collections.Counter() + self.symbol_guard_counter: Counter[sympy.Symbol] = collections.Counter() # A selection of important fields on co_field; solely used for # signpost_event self.co_fields = co_fields if co_fields else {} @@ -3052,7 +3121,7 @@ class ShapeEnv: def allow_complex_guards_as_runtime_asserts(self): return self.settings.allow_complex_guards_as_runtime_asserts - def check_equal(self, other: "ShapeEnv") -> None: + def check_equal(self, other: ShapeEnv) -> None: """Compare another ShapeEnv for equivalence""" # ShapeEnv fields that are not relevant for the outcome of # ShapeEnv.produce_guards call: @@ -3440,7 +3509,7 @@ class ShapeEnv: def _produce_dyn_sizes_from_int_tuple( self, - tensor_size: Tuple[int], + tensor_size: Sequence[int], source: Source, symbolic_context: SymbolicContext, ) -> List[sympy.Expr]: @@ -3450,8 +3519,8 @@ class ShapeEnv: from torch._dynamo.source import TensorProperty, TensorPropertySource _assert_symbol_context(symbolic_context) - dynamic_dims = symbolic_context.dynamic_sizes - constraint_dims = symbolic_context.constraint_sizes + dynamic_dims = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + constraint_dims = symbolic_context.constraint_sizes # type: ignore[attr-defined] size = [] for i, val in enumerate(tensor_size): 
size.append( @@ -3531,7 +3600,7 @@ class ShapeEnv: # The order of checking the guards matters. In this specific example: # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, # we may have an unnessary shape speciliazation for y. - def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> int: + def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> Union[int, SymInt]: assert isinstance(maybe_sym, (int, torch.SymInt)) if is_symbolic(maybe_sym): assert ( @@ -3555,8 +3624,8 @@ class ShapeEnv: # Reimplement the legacy behavior if symbolic_context is None: - constraint_sizes = [None] * dim - constraint_strides = [None] * dim + constraint_sizes: List[DimConstraint] = [None] * dim + constraint_strides: List[DimConstraint] = [None] * dim dynamic_dims = [] dynamic_strides = [] for i in range(dim): @@ -3581,10 +3650,10 @@ class ShapeEnv: ) # We got a StatelessSymbolicContext _assert_symbol_context(symbolic_context) - constraint_sizes = symbolic_context.constraint_sizes - constraint_strides = symbolic_context.constraint_strides - dynamic_sizes = symbolic_context.dynamic_sizes - dynamic_strides = symbolic_context.dynamic_strides + constraint_sizes = symbolic_context.constraint_sizes # type: ignore[attr-defined] + constraint_strides = symbolic_context.constraint_strides # type: ignore[attr-defined] + dynamic_sizes = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + dynamic_strides = symbolic_context.dynamic_strides # type: ignore[attr-defined] # TODO: make this configurable from outside symbolic_context; we made a symbolic_context # decision here where if all sizes are static, we are going to @@ -3705,11 +3774,11 @@ class ShapeEnv: @record_shapeenv_event() def create_symintnode( self, - sym: "sympy.Expr", + sym: sympy.Expr, *, hint: Optional[int], source: Optional[Source] = None, - ): + ) -> Union[int, SymInt]: """Create a SymInt value from a symbolic expression If you know what the current hint value of 
the SymInt to be created @@ -3732,6 +3801,7 @@ class ShapeEnv: else: fx_node = None + out: Union[int, SymInt] if isinstance(sym, sympy.Integer): if hint is not None: assert int(sym) == hint @@ -3750,11 +3820,11 @@ class ShapeEnv: @record_shapeenv_event() def create_symfloatnode( self, - sym: "sympy.Expr", + sym: sympy.Expr, *, hint: Optional[int], source: Optional[Source] = None, - ): + ) -> Union[float, SymFloat]: """Create a SymFloat value from a symbolic expression""" source_name = source.name() if source else None @@ -3771,6 +3841,7 @@ class ShapeEnv: else: fx_node = None + out: Union[float, SymFloat] if isinstance(sym, sympy.Float): if hint is not None: assert float(sym) == hint @@ -3798,7 +3869,7 @@ class ShapeEnv: source=source, ) - def create_symboolnode(self, sym: "sympy.Expr"): + def create_symboolnode(self, sym: sympy.Expr): """Create a SymBool object from a sympy boolean expression""" # This function is only being used in serialization, so we do not track it # for validation. 
@@ -3897,7 +3968,7 @@ class ShapeEnv: source: Source, dynamic_dim: DimDynamic = DimDynamic.DUCK, constraint_dim: DimConstraint = None, # NB: includes None - ) -> "sympy.Expr": + ) -> sympy.Expr: """Create a symbol with an unspecified value Compared to standard symbols we do not assume the value is positive, @@ -3928,7 +3999,7 @@ class ShapeEnv: positive: Optional[bool] = True, do_not_specialize_zero_one: bool = False, symbolic_context=None, - ) -> "sympy.Expr": + ) -> sympy.Expr: """Create a new symbol which is tracked by this ShapeEnv""" # check if constraint_dim is actually static integer if ( @@ -3942,6 +4013,9 @@ class ShapeEnv: f"for {source.name()}" ) if symbolic_context: + from torch._dynamo.source import TensorPropertySource + + assert isinstance(source, TensorPropertySource) symbolic_context.dynamic_sizes[source.idx] = dynamic_dim symbolic_context.constraint_sizes[source.idx] = None constraint_dim = None @@ -4186,7 +4260,7 @@ class ShapeEnv: sources, source_ref=lambda n: n.name(), *, - guards: List[ShapeGuard] = None, + guards: Optional[List[ShapeGuard]] = None, input_contexts: Optional[DimList[SymbolicContext]] = None, # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). # (See docs on EqualityConstraint for details of the encoding.) @@ -4324,7 +4398,9 @@ class ShapeEnv: input_guards = [] symbol_to_source = collections.defaultdict(list) - symbol_to_constraints = collections.defaultdict(set) + symbol_to_constraints: DefaultDict[ + sympy.Symbol, Set[Constraint] + ] = collections.defaultdict(set) constraint_violations: List[Tuple[bool, str, Callable[[], str]]] = [] def record_constraint_violation(warn_only, debug_name, msg, hint=None): @@ -4522,7 +4598,8 @@ class ShapeEnv: # For subclasses, we need to track symints on BOTH the outer # and inner tensors. 
- sources_tensors_constraints = [ + # TODO: type this better + sources_tensors_constraints: List[Tuple[Source, Any, Any, Any]] = [ (source, t, context.constraint_sizes, context.constraint_strides) ] attrs, _ = t.__tensor_flatten__() @@ -4533,13 +4610,13 @@ class ShapeEnv: ( AttrSource(source, attr), inner_t, - inner_context.constraint_sizes, - inner_context.constraint_strides, + inner_context.constraint_sizes, # type: ignore[attr-defined] + inner_context.constraint_strides, # type: ignore[attr-defined] ) ) else: sources_tensors_constraints = [ - (source, t, context.constraint_sizes, context.constraint_strides) + (source, t, context.constraint_sizes, context.constraint_strides) # type: ignore[attr-defined] ] for ( @@ -4705,6 +4782,7 @@ class ShapeEnv: for s in expr.free_symbols for source in symbol_to_source[s] ): + assert self.dim_constraints is not None is_trivial = self.dim_constraints.add(expr) guard_expr = ShapeGuardPrinter( symbol_to_source, source_ref, self.var_to_sources @@ -4827,22 +4905,22 @@ class ShapeEnv: ) if constraint_violations: - warn_msgs = [] - error_msgs = [] + warn_msgs: List[str] = [] + error_msgs: List[str] = [] debug_names = set() - for warn_only, debug_name, msg in constraint_violations: + for warn_only, debug_name, msg_cb in constraint_violations: if warn_only: - msg = f" {len(warn_msgs) + 1}. {msg()}" - warn_msgs.append(msg) + str_msg = f" {len(warn_msgs) + 1}. {msg_cb()}" + warn_msgs.append(str_msg) else: - msg = f" - {msg()}" - error_msgs.append(msg) + str_msg = f" - {msg_cb()}" + error_msgs.append(str_msg) debug_names.add(debug_name) if len(error_msgs) > 0: - debug_names = ", ".join(sorted(debug_names)) + debug_names_str = ", ".join(sorted(debug_names)) err = "\n".join(error_msgs) raise ConstraintViolationError( - f"Constraints violated ({debug_names})! " + f"Constraints violated ({debug_names_str})!
" 'For more information, run with TORCH_LOGS="+dynamic".\n' f"{err}" ) @@ -5029,6 +5107,7 @@ class ShapeEnv: self, expr: sympy.Expr, size_oblivious: bool = False ) -> ValueRanges: """Given a sympy expression, computes a ValueRanges bound for what values it can be""" + # TODO: maybe it's guaranteed that x is in var_to_range? var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} if size_oblivious: # Clamp values of size-like variables @@ -5040,15 +5119,14 @@ class ShapeEnv: # to determine if we can do size-like replacement, the # upper bound is irrelevant here var_to_range[x] = ValueRanges(2, int_oo) - assert var_to_range[x].is_int - return bound_sympy(expr, var_to_range) + return bound_sympy(expr, var_to_range) # type: ignore[arg-type] @_lru_cache def get_axioms( self, - symbols: Optional[Tuple["sympy.Symbol"]] = None, + symbols: Optional[Tuple[sympy.Symbol]] = None, compute_hint: bool = False, - ) -> Tuple["sympy.Expr"]: + ) -> Tuple[SympyBoolean, ...]: """ Given the symbols in an expression, it returns all the runtime asserts that have those symbols concatenated with all the guards.
@@ -5065,8 +5143,8 @@ class ShapeEnv: if s not in self.var_to_val for r in self.deferred_runtime_asserts.get(s, ()) ) - guards = (g.expr for g in self.guards) - axioms = itertools.chain(guards, runtime_asserts) + guards: Iterator[SympyBoolean] = (g.expr for g in self.guards) + axioms: Iterator[SympyBoolean] = itertools.chain(guards, runtime_asserts) if compute_hint: axioms = ( canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms @@ -5075,10 +5153,10 @@ class ShapeEnv: @lru_cache(None) def get_implications( - self, e: "sympy.Expr" - ) -> Tuple[Tuple["sympy.Expr", "sympy.logic.boolalg.BooleanAtom"]]: + self, e: SympyBoolean + ) -> Tuple[Tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]: """Given a expression, it returns a list of predicates that follow from it""" - equiv = {} + equiv: Dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {} def add_expr(expr): expr = canonicalize_bool_expr(expr) @@ -5105,7 +5183,7 @@ class ShapeEnv: elif isinstance(e, sympy.Lt): add_expr(sympy.Le(e.lhs, e.rhs)) add_expr(sympy.Ne(e.lhs, e.rhs)) - if e.lhs.is_integer and e.rhs.is_integer: + if e.lhs.is_integer and e.rhs.is_integer: # type: ignore[attr-defined] add_expr(sympy.Le(e.lhs, e.rhs - 1)) elif isinstance(e, sympy.Le): add_expr(sympy.Lt(e.lhs, e.rhs + 1)) @@ -5114,14 +5192,14 @@ class ShapeEnv: @_lru_cache def _maybe_evaluate_static( self, - expr: "sympy.Expr", + expr: sympy.Basic, *, unbacked_only: bool = False, compute_hint: bool = False, size_oblivious: bool = False, - axioms: Optional[Tuple[sympy.Expr]] = None, + axioms: Optional[Tuple[SympyBoolean]] = None, var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None, - ) -> "Optional[sympy.Expr]": + ) -> Optional[sympy.Basic]: """ Tries to evaluate expr without introducing guards @@ -5177,11 +5255,11 @@ class ShapeEnv: return r @_lru_cache - def replace(self, expr: "sympy.Expr") -> "sympy.Expr": + def replace(self, expr: sympy.Expr) -> sympy.Expr: """Apply symbol replacements to any 
symbols in the given expression""" replacements = {} for s in expr.free_symbols: - r = self._find(cast(sympy.Symbol, s)) + r = self._find(s) # Micro-optimization: only do replacements if r and s are different # Otherwise, xreplace is not a no-op and will trigger expensive # assumption queries if expr has a relational node. @@ -5204,7 +5282,7 @@ class ShapeEnv: self._update_version_counter() @_lru_cache - def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": + def simplify(self, expr: sympy.Expr) -> sympy.Expr: """Use known constraints and replacements to simplify the given expr""" expr = self.replace(expr) # TODO it would seem that this pass is not necessary given the @@ -5249,8 +5327,11 @@ class ShapeEnv: expr = new_expr return expr + # TODO: overload for allow_none literal @lru_cache(256) - def size_hint(self, expr: "sympy.Expr", *, allow_none=False): + def size_hint( + self, expr: sympy.Basic, *, allow_none: bool = False + ) -> Optional[sympy.Basic]: """ Gets a size hint for a given expression from the underlying shapes we had. 
Does not introduce a guard, so only use this when you can guarantee that @@ -5295,7 +5376,7 @@ class ShapeEnv: # NB: keep in sync with size_hint @lru_cache(256) - def has_hint(self, expr: "sympy.Expr"): + def has_hint(self, expr: sympy.Expr) -> bool: result_expr = safe_expand(expr).xreplace(self.var_to_val) return ( result_expr.is_number @@ -5303,8 +5384,12 @@ class ShapeEnv: ) def _make_data_dependent_error( - self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None - ): + self, + expr: sympy.Basic, + unhinted_expr: sympy.Basic, + *, + size_oblivious_result: Optional[sympy.Basic] = None, + ) -> GuardOnDataDependentSymNode: # TODO: in a Dynamo context, having user code, and having the # name of the local, will be much better size_like_symbols = [] @@ -5322,7 +5407,7 @@ class ShapeEnv: "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n" ) sloc, maybe_extra_debug = self._get_stack_summary(True) - if expr.is_integer: + if expr.is_integer: # type: ignore[attr-defined] desc = ( "Could not extract specialized integer from data-dependent expression" ) @@ -5347,12 +5432,12 @@ class ShapeEnv: def _update_var_to_range( self, - symbol, - vr, + symbol: sympy.Symbol, + vr: ValueRanges, vr_sloc: Optional[ValueRangesSLoc] = None, *, - is_constraint=False, - ): + is_constraint: bool = False, + ) -> None: lower, upper = vr.lower, vr.upper # If we have a size-like unbacked SymInt, refuse to refine the range to be @@ -5397,7 +5482,7 @@ class ShapeEnv: if not is_constraint: assert v in r, f"{v} not in {r}" - def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None: + def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: """ Adds or updates a replacement for a symbol. Use this instead of `self.replacements[a] = tgt`. @@ -5558,13 +5643,13 @@ class ShapeEnv: # Z3, in case an expression uses 'a'. 
self._add_target_expr(sympy.Eq(a, tgt, evaluate=False)) - def _add_divisible(self, expr: "sympy.Expr"): + def _add_divisible(self, expr: sympy.Expr) -> None: self.divisible.add(expr) self._update_version_counter() @_lru_cache @record_shapeenv_event() - def _find(self, a: "sympy.Symbol") -> "sympy.Expr": + def _find(self, a: sympy.Symbol) -> sympy.Expr: """ Implements a DSU-like algorithm to find the variable that represents a Also handles transitive non-identity replacements. @@ -5582,7 +5667,7 @@ class ShapeEnv: return self.replacements[a] @lru_cache(256) - def _maybe_guard_rel(self, expr: "sympy.Rel") -> None: + def _maybe_guard_rel(self, expr: sympy.Rel) -> None: """ The relational guard is guarded to be true. Use this information to simplify shapes (i.e. a == b or a % 5 == 0) @@ -5675,9 +5760,7 @@ class ShapeEnv: new_var = self._find(r[1]) ok = len(free_unbacked_symbols(new_var)) == 0 if ok: - self._set_replacement( - cast(sympy.Symbol, free[0]), new_var, "solve" - ) + self._set_replacement(free[0], new_var, "solve") except NotImplementedError: pass if expr.has(Mod): @@ -5732,7 +5815,7 @@ class ShapeEnv: return ValueRanges(-int_oo, int_oo) @_lru_cache - def _simplify_floor_div(self, expr): + def _simplify_floor_div(self, expr: sympy.Expr) -> sympy.Expr: floor_divs = tuple(expr.atoms(FloorDiv)) # we expect floor_divs to be exact, # and thus add the guards for the exact floordivs, @@ -5747,7 +5830,7 @@ class ShapeEnv: # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen # and if so issue a warning - def _check_frozen(self, expr, concrete_val): + def _check_frozen(self, expr: sympy.Basic, concrete_val: sympy.Basic) -> None: if self.frozen: self.counter["ignored_backward_guard"] += 1 signpost_event( @@ -5771,12 +5854,13 @@ class ShapeEnv: def _get_stack_summary( self, is_debug: bool = False, framework_loc: Optional[str] = None ) -> Tuple[SLoc, str]: - if framework_loc is None: + floc: Optional[Union[str, traceback.FrameSummary]] = 
framework_loc + if floc is None: frame = inspect.currentframe() try: while frame is not None: if frame.f_code.co_filename not in uninteresting_files(): - framework_loc = traceback.FrameSummary( + floc = traceback.FrameSummary( frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name, @@ -5808,14 +5892,14 @@ class ShapeEnv: "\nFor C++ stack trace, run with " "TORCHDYNAMO_EXTENDED_DEBUG_CPP=1" ) - return SLoc(framework_loc, maybe_user_loc), maybe_extra_debug + return SLoc(floc, maybe_user_loc), maybe_extra_debug # Pass in framework_loc to override the framework location info def _get_sloc(self, framework_loc: Optional[str] = None) -> SLoc: sloc, _ = self._get_stack_summary(framework_loc=framework_loc) return sloc - def _log_guard(self, prefix: str, g, forcing_spec: bool): + def _log_guard(self, prefix: str, g, forcing_spec: bool) -> None: if self.log.isEnabledFor(logging.INFO): str_g = str(g) is_debug = ( @@ -5843,13 +5927,13 @@ class ShapeEnv: @record_shapeenv_event(save_tracked_fakes=True) def evaluate_expr( self, - orig_expr: "sympy.Expr", - hint=None, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]] = None, fx_node=None, size_oblivious: bool = False, *, forcing_spec: bool = False, - ): + ) -> sympy.Basic: try: return self._evaluate_expr( orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec @@ -5866,13 +5950,13 @@ class ShapeEnv: def _evaluate_expr( self, - orig_expr: "sympy.Expr", - hint=None, - fx_node=None, + orig_expr: sympy.Basic, + hint: Optional[Union[bool, int, float]] = None, + fx_node: Optional[torch.fx.Node] = None, size_oblivious: bool = False, *, forcing_spec: bool = False, - ): + ) -> sympy.Basic: """ Given an expression, evaluates it, adding guards if necessary """ @@ -5881,12 +5965,18 @@ class ShapeEnv: # Don't track this one @functools.lru_cache(None) - def compute_concrete_val(): + def compute_concrete_val() -> sympy.Basic: if hint is None: - return self.size_hint(orig_expr) + # This is only ever called for 
expressions WITHOUT unbacked + # symbols + r = self.size_hint(orig_expr) + assert r is not None + return r else: return sympy.sympify(hint) + concrete_val: Optional[sympy.Basic] + # Check if: # 1. 'translation_validation' is set # 2. the corresponding 'fx_node' is not 'None' @@ -5902,6 +5992,7 @@ class ShapeEnv: and not self._suppress_guards_tls() and not size_oblivious ): + # TODO: does this even work with unbacked :think: concrete_val = compute_concrete_val() if concrete_val is sympy.true: node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) @@ -5959,6 +6050,7 @@ class ShapeEnv: # TODO: dedupe this with _maybe_evaluate_static # Attempt to eliminate the unbacked SymInt new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None if not (new_expr.free_symbols <= self.var_to_val.keys()): size_oblivious_result = None if not size_oblivious: @@ -6015,7 +6107,7 @@ class ShapeEnv: # Turn this into a boolean expression, no longer need to consult # concrete_val if concrete_val is sympy.true: - g = expr + g = cast(SympyBoolean, expr) elif concrete_val is sympy.false: g = sympy.Not(expr) else: @@ -6023,7 +6115,7 @@ class ShapeEnv: if transmute_into_runtime_assert: self.defer_runtime_assert( - g, f"propagate_real_tensors: {orig_expr} == {unsound_result}" + g, f"propagate_real_tensors: {orig_expr} == {concrete_val}" ) return concrete_val @@ -6080,7 +6172,7 @@ class ShapeEnv: return concrete_val - def cleanup(self): + def cleanup(self) -> None: """ Break reference cycles.
@@ -6094,7 +6186,9 @@ class ShapeEnv: ra.stack.cleanup() @record_shapeenv_event(save_tracked_fakes=True) - def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None): + def defer_runtime_assert( + self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None + ) -> bool: """Create an assert that is checked at runtime Args: @@ -6113,10 +6207,12 @@ class ShapeEnv: self.log.debug( "runtime_assert %s == %s [statically known]", orig_expr, static_expr ) - return static_expr + # TODO: assert bool(static_expr) + return bool(static_expr) # Attempt to eliminate the unbacked SymInt new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None if ( not self.prefer_deferred_runtime_asserts_over_guards and new_expr.free_symbols <= self.var_to_val.keys() @@ -6182,7 +6278,7 @@ class ShapeEnv: # 1. Tries to isolate a variable in the left-hand side # 2. Compute the value range of the right-hand side # 3. Update the value range of the variable, if better - def _refine_ranges(self, expr: sympy.Expr) -> None: + def _refine_ranges(self, expr: SympyBoolean) -> None: expr = self.simplify(expr) for symbol in expr.free_symbols: @@ -6248,7 +6344,7 @@ class ShapeEnv: @record_shapeenv_event() def constrain_symbol_range( self, s: sympy.Symbol, compiler_min: int, compiler_max: int - ): + ) -> None: upd_vr = ValueRanges(compiler_min, compiler_max) old_vr = self.var_to_range.get(s, ValueRanges.unknown()) self._update_var_to_range(s, upd_vr) @@ -6258,17 +6354,17 @@ class ShapeEnv: ) -def _is_int(expr): +def _is_int(expr: object) -> bool: return isinstance(expr, SymInt) and expr.node.expr.is_number # WARNING: This is legacy, DO NOT USE -def _is_dim_dynamic(t, d): +def _is_dim_dynamic(t: torch.Tensor, d: int) -> bool: return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices class PropagateUnbackedSymInts(torch.fx.Interpreter): - def run_node(self, n: torch.fx.Node): + def run_node(self, n: torch.fx.Node) -> 
Result: """ Run an FX node, propagating unbacked Symbol bindings to the new fake tensor """ @@ -6279,7 +6375,7 @@ class PropagateUnbackedSymInts(torch.fx.Interpreter): return result -def _find_user_code_frame(): +def _find_user_code_frame() -> Optional[types.FrameType]: frame = inspect.currentframe() while frame is not None: if not frame.f_code.co_filename.startswith( @@ -6290,7 +6386,7 @@ def _find_user_code_frame(): return frame -def _blame_user_code(e, frame): +def _blame_user_code(e: Exception, frame: types.FrameType) -> None: frame_summary = traceback.FrameSummary( frame.f_code.co_filename, frame.f_lineno, @@ -6310,21 +6406,24 @@ class _PythonPrinter(sympy.printing.str.StrPrinter): (i.e., as ==, !=, >, <). """ - def __init__(self, src_map): + def __init__(self, src_map: Dict[str, List[str]]) -> None: super().__init__() self.src_map = src_map - def _print_Symbol(self, sym): + def _print_Symbol(self, sym: sympy.Symbol) -> str: return self.src_map[sym.name][0] - def _print_Relational(self, expr): + def _print_Relational(self, expr: sympy.core.relational.Relational) -> str: lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr)) + assert hasattr(expr, "rel_op") rel_op = expr.rel_op rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr)) return f"{lhs} {rel_op} {rhs}" -def _suggest_torch_checks(e, src_map): +def _suggest_torch_checks( + e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]] +) -> None: # extract the unresolved condition on unbacked symints in the error cond = e.cond diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map) @@ -6350,7 +6449,9 @@ def _suggest_torch_checks(e, src_map): e.args = (msg,) -def _suggest_fixes_for_data_dependent_error_non_strict(e): +def _suggest_fixes_for_data_dependent_error_non_strict( + e: GuardOnDataDependentSymNode, +) -> None: """ Given a raised data-dependent error, add the following to the error message: 1. 
the closest user code location that raised the error; diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 01803600d021..0b337165ba94 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -346,11 +346,12 @@ def insert_deferred_runtime_asserts( # this guards against deleting calls that produce unbacked bindings we haven't yet seen. # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint # (is backed), but produces an unbacked symbol. In this case keep the node alive. + resolved_unbacked_bindings = resolve_unbacked_bindings( + shape_env, node.meta.get("unbacked_bindings", {}) + ) + assert resolved_unbacked_bindings is not None new_unbacked_bindings = ( - resolve_unbacked_bindings( - shape_env, node.meta.get("unbacked_bindings", {}) - ).keys() - - expr_to_proxy.keys() + resolved_unbacked_bindings.keys() - expr_to_proxy.keys() ) # maybe re-reify expression, replace current node diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 70c65a690717..5460bffb809f 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -314,6 +314,17 @@ class TensorWithFlatten(Protocol): def stride(self, dim: int) -> int: ... + @overload + def size(self, dim: None = None) -> Tuple[int, ...]: + ... + + @overload + def size(self, dim: int) -> int: + ... + + def storage_offset(self) -> int: + ... + def dim(self) -> int: ... diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index e122d6cd0b5f..cdc722d20fa5 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -43,7 +43,7 @@ def try_solve( thing: sympy.Basic, trials: int = 5, floordiv_inequality: bool = True, -) -> Optional[Tuple[sympy.Rel, sympy.Basic]]: +) -> Optional[Tuple[sympy.Rel, sympy.Expr]]: mirror = mirror_rel_op(type(expr)) # Ignore unsupported expressions: