Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972
Approved by: https://github.com/Skylion007
ghstack dependencies: #136934, #136935
This commit is contained in:
Edward Z. Yang
2024-09-30 18:21:41 -07:00
committed by PyTorch MergeBot
parent b85f21fc1d
commit cc8f1cddd4
11 changed files with 318 additions and 189 deletions

View File

@ -5,10 +5,18 @@ from mypy.types import NoneType, UnionType
class SympyPlugin(Plugin): class SympyPlugin(Plugin):
def get_base_class_hook(self, fullname: str): def get_base_class_hook(self, fullname: str):
# TODO: This apparently never worked
if fullname == "sympy.core.basic.Basic": if fullname == "sympy.core.basic.Basic":
return add_assumptions return add_assumptions
return None return None
def get_attribute_hook(self, fullname: str):
if fullname == "sympy.core.basic.Basic.free_symbols":
return lambda ctx: ctx.api.named_generic_type(
"builtins.set", [ctx.api.named_type("sympy.Symbol")]
)
return None
def add_assumptions(ctx) -> None: def add_assumptions(ctx) -> None:
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined) # Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)

View File

@ -660,7 +660,7 @@ class OutputGraph:
assert arg.fake_tensor is not None assert arg.fake_tensor is not None
def bind_symint(s, prop): def bind_symint(s: torch.SymInt, prop):
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)): if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
return return
s0 = s.node.expr s0 = s.node.expr
@ -677,6 +677,7 @@ class OutputGraph:
source=prop, source=prop,
) )
set_example_value(proxy.node, s) set_example_value(proxy.node, s)
assert isinstance(s, torch.SymInt)
proxy.node.meta["grapharg"] = GraphArg( proxy.node.meta["grapharg"] = GraphArg(
prop, prop,
s, s,

View File

@ -154,7 +154,7 @@ class GuardBuilderBase:
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class SLoc: class SLoc:
framework_loc: Union[traceback.FrameSummary, str] framework_loc: Optional[Union[traceback.FrameSummary, str]]
maybe_user_loc: Optional[str] maybe_user_loc: Optional[str]
def __str__(self): def __str__(self):
@ -170,7 +170,7 @@ class SLoc:
class ShapeGuard(NamedTuple): class ShapeGuard(NamedTuple):
expr: sympy.Expr expr: sympy.logic.boolalg.Boolean
sloc: SLoc sloc: SLoc

View File

@ -2044,7 +2044,9 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(x, int): if isinstance(x, int):
return x return x
val = V.graph._shape_env._maybe_evaluate_static(x) val = V.graph._shape_env._maybe_evaluate_static(x)
return int(val) if val is None:
return val
return int(val) # type: ignore[call-overload]
except Exception: except Exception:
return None return None

View File

@ -45,6 +45,7 @@ from torch.fx.experimental.symbolic_shapes import (
resolve_unbacked_bindings, resolve_unbacked_bindings,
RuntimeAssert, RuntimeAssert,
ShapeEnv, ShapeEnv,
SympyBoolean,
SymTypes, SymTypes,
) )
from torch.fx.graph import Graph from torch.fx.graph import Graph
@ -352,7 +353,7 @@ class GraphLowering(torch.fx.Interpreter):
shape_env.freeze_runtime_asserts() shape_env.freeze_runtime_asserts()
# We're going to mutate ras_by_symbol as we finish generating them # We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: Dict[ self.ras_by_symbol: Dict[
sympy.Symbol, List[RuntimeAssert] Optional[sympy.Symbol], List[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy() ] = shape_env.deferred_runtime_asserts.copy()
self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet() self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
self.sizevars = SizeVarAllocator(shape_env) self.sizevars = SizeVarAllocator(shape_env)
@ -1595,7 +1596,7 @@ class GraphLowering(torch.fx.Interpreter):
# This is all doable, it just hasn't been done yet. # This is all doable, it just hasn't been done yet.
shape_env = V.graph.sizevars.shape_env shape_env = V.graph.sizevars.shape_env
def make_assert(expr: Expr, msg: str) -> None: def make_assert(expr: SympyBoolean, msg: str) -> None:
assert_op = ir.AssertScalar(expr, msg) assert_op = ir.AssertScalar(expr, msg)
self.register_buffer(assert_op, set_name=True) self.register_buffer(assert_op, set_name=True)
self.register_operation(assert_op) self.register_operation(assert_op)
@ -1634,6 +1635,7 @@ class GraphLowering(torch.fx.Interpreter):
unbacked_bindings = resolve_unbacked_bindings( unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {}) V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
) )
assert unbacked_bindings is not None
# When we do lowering, it is possible we reallocate unbacked SymInts. # When we do lowering, it is possible we reallocate unbacked SymInts.
# So we need to line up the unbacked SymInts when performing the test # So we need to line up the unbacked SymInts when performing the test
# here # here

View File

@ -5913,9 +5913,11 @@ class FallbackKernel(ExternKernelAlloc):
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
if unbacked_bindings := getattr(self, "unbacked_bindings", None): if unbacked_bindings := getattr(self, "unbacked_bindings", None):
return resolve_unbacked_bindings( resolved = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, unbacked_bindings V.graph.sizevars.shape_env, unbacked_bindings
).keys() )
assert resolved is not None
return resolved.keys() # type: ignore[return-value]
else: else:
return OrderedSet() return OrderedSet()

View File

@ -2676,6 +2676,7 @@ def _local_scalar_dense(data):
unbacked_bindings = resolve_unbacked_bindings( unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
) )
assert unbacked_bindings is not None
assert len(unbacked_bindings) == 1, unbacked_bindings assert len(unbacked_bindings) == 1, unbacked_bindings
# NB: Have to be very careful here. V.graph.current_node.meta["val"] # NB: Have to be very careful here. V.graph.current_node.meta["val"]
# seemingly also contains a symbol which you want to do binding for, # seemingly also contains a symbol which you want to do binding for,

File diff suppressed because it is too large Load Diff

View File

@ -346,11 +346,12 @@ def insert_deferred_runtime_asserts(
# this guards against deleting calls that produce unbacked bindings we haven't yet seen. # this guards against deleting calls that produce unbacked bindings we haven't yet seen.
# in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
# (is backed), but produces an unbacked symbol. In this case keep the node alive. # (is backed), but produces an unbacked symbol. In this case keep the node alive.
resolved_unbacked_bindings = resolve_unbacked_bindings(
shape_env, node.meta.get("unbacked_bindings", {})
)
assert resolved_unbacked_bindings is not None
new_unbacked_bindings = ( new_unbacked_bindings = (
resolve_unbacked_bindings( resolved_unbacked_bindings.keys() - expr_to_proxy.keys()
shape_env, node.meta.get("unbacked_bindings", {})
).keys()
- expr_to_proxy.keys()
) )
# maybe re-reify expression, replace current node # maybe re-reify expression, replace current node

View File

@ -314,6 +314,17 @@ class TensorWithFlatten(Protocol):
def stride(self, dim: int) -> int: def stride(self, dim: int) -> int:
... ...
@overload
def size(self, dim: None = None) -> Tuple[int, ...]:
...
@overload
def size(self, dim: int) -> int:
...
def storage_offset(self) -> int:
...
def dim(self) -> int: def dim(self) -> int:
... ...

View File

@ -43,7 +43,7 @@ def try_solve(
thing: sympy.Basic, thing: sympy.Basic,
trials: int = 5, trials: int = 5,
floordiv_inequality: bool = True, floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]: ) -> Optional[Tuple[sympy.Rel, sympy.Expr]]:
mirror = mirror_rel_op(type(expr)) mirror = mirror_rel_op(type(expr))
# Ignore unsupported expressions: # Ignore unsupported expressions: