mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972 Approved by: https://github.com/Skylion007 ghstack dependencies: #136917, #136934, #136935
This commit is contained in:
committed by
PyTorch MergeBot
parent
475a8a4e0c
commit
3ff2d93d9f
@ -5,10 +5,18 @@ from mypy.types import NoneType, UnionType
|
||||
|
||||
class SympyPlugin(Plugin):
|
||||
def get_base_class_hook(self, fullname: str):
|
||||
# TODO: This apparently never worked
|
||||
if fullname == "sympy.core.basic.Basic":
|
||||
return add_assumptions
|
||||
return None
|
||||
|
||||
def get_attribute_hook(self, fullname: str):
|
||||
if fullname == "sympy.core.basic.Basic.free_symbols":
|
||||
return lambda ctx: ctx.api.named_generic_type(
|
||||
"builtins.set", [ctx.api.named_type("sympy.Symbol")]
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def add_assumptions(ctx) -> None:
|
||||
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
|
||||
|
@ -660,7 +660,7 @@ class OutputGraph:
|
||||
|
||||
assert arg.fake_tensor is not None
|
||||
|
||||
def bind_symint(s, prop):
|
||||
def bind_symint(s: torch.SymInt, prop):
|
||||
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
|
||||
return
|
||||
s0 = s.node.expr
|
||||
@ -677,6 +677,7 @@ class OutputGraph:
|
||||
source=prop,
|
||||
)
|
||||
set_example_value(proxy.node, s)
|
||||
assert isinstance(s, torch.SymInt)
|
||||
proxy.node.meta["grapharg"] = GraphArg(
|
||||
prop,
|
||||
s,
|
||||
|
@ -154,7 +154,7 @@ class GuardBuilderBase:
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SLoc:
|
||||
framework_loc: Union[traceback.FrameSummary, str]
|
||||
framework_loc: Optional[Union[traceback.FrameSummary, str]]
|
||||
maybe_user_loc: Optional[str]
|
||||
|
||||
def __str__(self):
|
||||
@ -170,7 +170,7 @@ class SLoc:
|
||||
|
||||
|
||||
class ShapeGuard(NamedTuple):
|
||||
expr: sympy.Expr
|
||||
expr: sympy.logic.boolalg.Boolean
|
||||
sloc: SLoc
|
||||
|
||||
|
||||
|
@ -2038,7 +2038,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||
if isinstance(x, int):
|
||||
return x
|
||||
val = V.graph._shape_env._maybe_evaluate_static(x)
|
||||
return int(val)
|
||||
if val is None:
|
||||
return val
|
||||
return int(val) # type: ignore[call-overload]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
@ -45,6 +45,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
resolve_unbacked_bindings,
|
||||
RuntimeAssert,
|
||||
ShapeEnv,
|
||||
SympyBoolean,
|
||||
SymTypes,
|
||||
)
|
||||
from torch.fx.graph import Graph
|
||||
@ -351,7 +352,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
shape_env.freeze_runtime_asserts()
|
||||
# We're going to mutate ras_by_symbol as we finish generating them
|
||||
self.ras_by_symbol: Dict[
|
||||
sympy.Symbol, List[RuntimeAssert]
|
||||
Optional[sympy.Symbol], List[RuntimeAssert]
|
||||
] = shape_env.deferred_runtime_asserts.copy()
|
||||
self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
self.sizevars = SizeVarAllocator(shape_env)
|
||||
@ -1594,7 +1595,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# This is all doable, it just hasn't been done yet.
|
||||
shape_env = V.graph.sizevars.shape_env
|
||||
|
||||
def make_assert(expr: Expr, msg: str) -> None:
|
||||
def make_assert(expr: SympyBoolean, msg: str) -> None:
|
||||
assert_op = ir.AssertScalar(expr, msg)
|
||||
self.register_buffer(assert_op, set_name=True)
|
||||
self.register_operation(assert_op)
|
||||
@ -1633,6 +1634,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
unbacked_bindings = resolve_unbacked_bindings(
|
||||
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
|
||||
)
|
||||
assert unbacked_bindings is not None
|
||||
# When we do lowering, it is possible we reallocate unbacked SymInts.
|
||||
# So we need to line up the unbacked SymInts when performing the test
|
||||
# here
|
||||
|
@ -5909,9 +5909,11 @@ class FallbackKernel(ExternKernelAlloc):
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
if unbacked_bindings := getattr(self, "unbacked_bindings", None):
|
||||
return resolve_unbacked_bindings(
|
||||
resolved = resolve_unbacked_bindings(
|
||||
V.graph.sizevars.shape_env, unbacked_bindings
|
||||
).keys()
|
||||
)
|
||||
assert resolved is not None
|
||||
return resolved.keys() # type: ignore[return-value]
|
||||
else:
|
||||
return OrderedSet()
|
||||
|
||||
|
@ -2676,6 +2676,7 @@ def _local_scalar_dense(data):
|
||||
unbacked_bindings = resolve_unbacked_bindings(
|
||||
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
|
||||
)
|
||||
assert unbacked_bindings is not None
|
||||
assert len(unbacked_bindings) == 1, unbacked_bindings
|
||||
# NB: Have to be very careful here. V.graph.current_node.meta["val"]
|
||||
# seemingly also contains a symbol which you want to do binding for,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -346,11 +346,12 @@ def insert_deferred_runtime_asserts(
|
||||
# this guards against deleting calls that produce unbacked bindings we haven't yet seen.
|
||||
# in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
|
||||
# (is backed), but produces an unbacked symbol. In this case keep the node alive.
|
||||
resolved_unbacked_bindings = resolve_unbacked_bindings(
|
||||
shape_env, node.meta.get("unbacked_bindings", {})
|
||||
)
|
||||
assert resolved_unbacked_bindings is not None
|
||||
new_unbacked_bindings = (
|
||||
resolve_unbacked_bindings(
|
||||
shape_env, node.meta.get("unbacked_bindings", {})
|
||||
).keys()
|
||||
- expr_to_proxy.keys()
|
||||
resolved_unbacked_bindings.keys() - expr_to_proxy.keys()
|
||||
)
|
||||
|
||||
# maybe re-reify expression, replace current node
|
||||
|
@ -314,6 +314,17 @@ class TensorWithFlatten(Protocol):
|
||||
def stride(self, dim: int) -> int:
|
||||
...
|
||||
|
||||
@overload
|
||||
def size(self, dim: None = None) -> Tuple[int, ...]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def size(self, dim: int) -> int:
|
||||
...
|
||||
|
||||
def storage_offset(self) -> int:
|
||||
...
|
||||
|
||||
def dim(self) -> int:
|
||||
...
|
||||
|
||||
|
@ -43,7 +43,7 @@ def try_solve(
|
||||
thing: sympy.Basic,
|
||||
trials: int = 5,
|
||||
floordiv_inequality: bool = True,
|
||||
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
|
||||
) -> Optional[Tuple[sympy.Rel, sympy.Expr]]:
|
||||
mirror = mirror_rel_op(type(expr))
|
||||
|
||||
# Ignore unsupported expressions:
|
||||
|
Reference in New Issue
Block a user