Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972
Approved by: https://github.com/Skylion007
ghstack dependencies: #136934, #136935
This commit is contained in:
Edward Z. Yang
2024-09-30 18:21:41 -07:00
committed by PyTorch MergeBot
parent b85f21fc1d
commit cc8f1cddd4
11 changed files with 318 additions and 189 deletions

View File

@ -5,10 +5,18 @@ from mypy.types import NoneType, UnionType
class SympyPlugin(Plugin):
def get_base_class_hook(self, fullname: str):
# TODO: This apparently never worked
if fullname == "sympy.core.basic.Basic":
return add_assumptions
return None
def get_attribute_hook(self, fullname: str):
if fullname == "sympy.core.basic.Basic.free_symbols":
return lambda ctx: ctx.api.named_generic_type(
"builtins.set", [ctx.api.named_type("sympy.Symbol")]
)
return None
def add_assumptions(ctx) -> None:
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)

View File

@ -660,7 +660,7 @@ class OutputGraph:
assert arg.fake_tensor is not None
def bind_symint(s, prop):
def bind_symint(s: torch.SymInt, prop):
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
return
s0 = s.node.expr
@ -677,6 +677,7 @@ class OutputGraph:
source=prop,
)
set_example_value(proxy.node, s)
assert isinstance(s, torch.SymInt)
proxy.node.meta["grapharg"] = GraphArg(
prop,
s,

View File

@ -154,7 +154,7 @@ class GuardBuilderBase:
@dataclasses.dataclass(frozen=True)
class SLoc:
framework_loc: Union[traceback.FrameSummary, str]
framework_loc: Optional[Union[traceback.FrameSummary, str]]
maybe_user_loc: Optional[str]
def __str__(self):
@ -170,7 +170,7 @@ class SLoc:
class ShapeGuard(NamedTuple):
expr: sympy.Expr
expr: sympy.logic.boolalg.Boolean
sloc: SLoc

View File

@ -2044,7 +2044,9 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(x, int):
return x
val = V.graph._shape_env._maybe_evaluate_static(x)
return int(val)
if val is None:
return val
return int(val) # type: ignore[call-overload]
except Exception:
return None

View File

@ -45,6 +45,7 @@ from torch.fx.experimental.symbolic_shapes import (
resolve_unbacked_bindings,
RuntimeAssert,
ShapeEnv,
SympyBoolean,
SymTypes,
)
from torch.fx.graph import Graph
@ -352,7 +353,7 @@ class GraphLowering(torch.fx.Interpreter):
shape_env.freeze_runtime_asserts()
# We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: Dict[
sympy.Symbol, List[RuntimeAssert]
Optional[sympy.Symbol], List[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy()
self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
self.sizevars = SizeVarAllocator(shape_env)
@ -1595,7 +1596,7 @@ class GraphLowering(torch.fx.Interpreter):
# This is all doable, it just hasn't been done yet.
shape_env = V.graph.sizevars.shape_env
def make_assert(expr: Expr, msg: str) -> None:
def make_assert(expr: SympyBoolean, msg: str) -> None:
assert_op = ir.AssertScalar(expr, msg)
self.register_buffer(assert_op, set_name=True)
self.register_operation(assert_op)
@ -1634,6 +1635,7 @@ class GraphLowering(torch.fx.Interpreter):
unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
)
assert unbacked_bindings is not None
# When we do lowering, it is possible we reallocate unbacked SymInts.
# So we need to line up the unbacked SymInts when performing the test
# here

View File

@ -5913,9 +5913,11 @@ class FallbackKernel(ExternKernelAlloc):
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
if unbacked_bindings := getattr(self, "unbacked_bindings", None):
return resolve_unbacked_bindings(
resolved = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, unbacked_bindings
).keys()
)
assert resolved is not None
return resolved.keys() # type: ignore[return-value]
else:
return OrderedSet()

View File

@ -2676,6 +2676,7 @@ def _local_scalar_dense(data):
unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
)
assert unbacked_bindings is not None
assert len(unbacked_bindings) == 1, unbacked_bindings
# NB: Have to be very careful here. V.graph.current_node.meta["val"]
# seemingly also contains a symbol which you want to do binding for,

File diff suppressed because it is too large Load Diff

View File

@ -346,11 +346,12 @@ def insert_deferred_runtime_asserts(
# this guards against deleting calls that produce unbacked bindings we haven't yet seen.
# in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
# (is backed), but produces an unbacked symbol. In this case keep the node alive.
resolved_unbacked_bindings = resolve_unbacked_bindings(
shape_env, node.meta.get("unbacked_bindings", {})
)
assert resolved_unbacked_bindings is not None
new_unbacked_bindings = (
resolve_unbacked_bindings(
shape_env, node.meta.get("unbacked_bindings", {})
).keys()
- expr_to_proxy.keys()
resolved_unbacked_bindings.keys() - expr_to_proxy.keys()
)
# maybe re-reify expression, replace current node

View File

@ -314,6 +314,17 @@ class TensorWithFlatten(Protocol):
def stride(self, dim: int) -> int:
...
@overload
def size(self, dim: None = None) -> Tuple[int, ...]:
...
@overload
def size(self, dim: int) -> int:
...
def storage_offset(self) -> int:
...
def dim(self) -> int:
...

View File

@ -43,7 +43,7 @@ def try_solve(
thing: sympy.Basic,
trials: int = 5,
floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
) -> Optional[Tuple[sympy.Rel, sympy.Expr]]:
mirror = mirror_rel_op(type(expr))
# Ignore unsupported expressions: