mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972 Approved by: https://github.com/Skylion007 ghstack dependencies: #136934, #136935
This commit is contained in:
committed by
PyTorch MergeBot
parent
b85f21fc1d
commit
cc8f1cddd4
@ -5,10 +5,18 @@ from mypy.types import NoneType, UnionType
|
|||||||
|
|
||||||
class SympyPlugin(Plugin):
|
class SympyPlugin(Plugin):
|
||||||
def get_base_class_hook(self, fullname: str):
|
def get_base_class_hook(self, fullname: str):
|
||||||
|
# TODO: This apparently never worked
|
||||||
if fullname == "sympy.core.basic.Basic":
|
if fullname == "sympy.core.basic.Basic":
|
||||||
return add_assumptions
|
return add_assumptions
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_attribute_hook(self, fullname: str):
|
||||||
|
if fullname == "sympy.core.basic.Basic.free_symbols":
|
||||||
|
return lambda ctx: ctx.api.named_generic_type(
|
||||||
|
"builtins.set", [ctx.api.named_type("sympy.Symbol")]
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def add_assumptions(ctx) -> None:
|
def add_assumptions(ctx) -> None:
|
||||||
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
|
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
|
||||||
|
@ -660,7 +660,7 @@ class OutputGraph:
|
|||||||
|
|
||||||
assert arg.fake_tensor is not None
|
assert arg.fake_tensor is not None
|
||||||
|
|
||||||
def bind_symint(s, prop):
|
def bind_symint(s: torch.SymInt, prop):
|
||||||
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
|
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
|
||||||
return
|
return
|
||||||
s0 = s.node.expr
|
s0 = s.node.expr
|
||||||
@ -677,6 +677,7 @@ class OutputGraph:
|
|||||||
source=prop,
|
source=prop,
|
||||||
)
|
)
|
||||||
set_example_value(proxy.node, s)
|
set_example_value(proxy.node, s)
|
||||||
|
assert isinstance(s, torch.SymInt)
|
||||||
proxy.node.meta["grapharg"] = GraphArg(
|
proxy.node.meta["grapharg"] = GraphArg(
|
||||||
prop,
|
prop,
|
||||||
s,
|
s,
|
||||||
|
@ -154,7 +154,7 @@ class GuardBuilderBase:
|
|||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class SLoc:
|
class SLoc:
|
||||||
framework_loc: Union[traceback.FrameSummary, str]
|
framework_loc: Optional[Union[traceback.FrameSummary, str]]
|
||||||
maybe_user_loc: Optional[str]
|
maybe_user_loc: Optional[str]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@ -170,7 +170,7 @@ class SLoc:
|
|||||||
|
|
||||||
|
|
||||||
class ShapeGuard(NamedTuple):
|
class ShapeGuard(NamedTuple):
|
||||||
expr: sympy.Expr
|
expr: sympy.logic.boolalg.Boolean
|
||||||
sloc: SLoc
|
sloc: SLoc
|
||||||
|
|
||||||
|
|
||||||
|
@ -2044,7 +2044,9 @@ class PythonWrapperCodegen(CodeGen):
|
|||||||
if isinstance(x, int):
|
if isinstance(x, int):
|
||||||
return x
|
return x
|
||||||
val = V.graph._shape_env._maybe_evaluate_static(x)
|
val = V.graph._shape_env._maybe_evaluate_static(x)
|
||||||
return int(val)
|
if val is None:
|
||||||
|
return val
|
||||||
|
return int(val) # type: ignore[call-overload]
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||||||
resolve_unbacked_bindings,
|
resolve_unbacked_bindings,
|
||||||
RuntimeAssert,
|
RuntimeAssert,
|
||||||
ShapeEnv,
|
ShapeEnv,
|
||||||
|
SympyBoolean,
|
||||||
SymTypes,
|
SymTypes,
|
||||||
)
|
)
|
||||||
from torch.fx.graph import Graph
|
from torch.fx.graph import Graph
|
||||||
@ -352,7 +353,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
shape_env.freeze_runtime_asserts()
|
shape_env.freeze_runtime_asserts()
|
||||||
# We're going to mutate ras_by_symbol as we finish generating them
|
# We're going to mutate ras_by_symbol as we finish generating them
|
||||||
self.ras_by_symbol: Dict[
|
self.ras_by_symbol: Dict[
|
||||||
sympy.Symbol, List[RuntimeAssert]
|
Optional[sympy.Symbol], List[RuntimeAssert]
|
||||||
] = shape_env.deferred_runtime_asserts.copy()
|
] = shape_env.deferred_runtime_asserts.copy()
|
||||||
self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
|
self.bound_unbacked_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||||
self.sizevars = SizeVarAllocator(shape_env)
|
self.sizevars = SizeVarAllocator(shape_env)
|
||||||
@ -1595,7 +1596,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
# This is all doable, it just hasn't been done yet.
|
# This is all doable, it just hasn't been done yet.
|
||||||
shape_env = V.graph.sizevars.shape_env
|
shape_env = V.graph.sizevars.shape_env
|
||||||
|
|
||||||
def make_assert(expr: Expr, msg: str) -> None:
|
def make_assert(expr: SympyBoolean, msg: str) -> None:
|
||||||
assert_op = ir.AssertScalar(expr, msg)
|
assert_op = ir.AssertScalar(expr, msg)
|
||||||
self.register_buffer(assert_op, set_name=True)
|
self.register_buffer(assert_op, set_name=True)
|
||||||
self.register_operation(assert_op)
|
self.register_operation(assert_op)
|
||||||
@ -1634,6 +1635,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
unbacked_bindings = resolve_unbacked_bindings(
|
unbacked_bindings = resolve_unbacked_bindings(
|
||||||
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
|
V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
|
||||||
)
|
)
|
||||||
|
assert unbacked_bindings is not None
|
||||||
# When we do lowering, it is possible we reallocate unbacked SymInts.
|
# When we do lowering, it is possible we reallocate unbacked SymInts.
|
||||||
# So we need to line up the unbacked SymInts when performing the test
|
# So we need to line up the unbacked SymInts when performing the test
|
||||||
# here
|
# here
|
||||||
|
@ -5913,9 +5913,11 @@ class FallbackKernel(ExternKernelAlloc):
|
|||||||
|
|
||||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||||
if unbacked_bindings := getattr(self, "unbacked_bindings", None):
|
if unbacked_bindings := getattr(self, "unbacked_bindings", None):
|
||||||
return resolve_unbacked_bindings(
|
resolved = resolve_unbacked_bindings(
|
||||||
V.graph.sizevars.shape_env, unbacked_bindings
|
V.graph.sizevars.shape_env, unbacked_bindings
|
||||||
).keys()
|
)
|
||||||
|
assert resolved is not None
|
||||||
|
return resolved.keys() # type: ignore[return-value]
|
||||||
else:
|
else:
|
||||||
return OrderedSet()
|
return OrderedSet()
|
||||||
|
|
||||||
|
@ -2676,6 +2676,7 @@ def _local_scalar_dense(data):
|
|||||||
unbacked_bindings = resolve_unbacked_bindings(
|
unbacked_bindings = resolve_unbacked_bindings(
|
||||||
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
|
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
|
||||||
)
|
)
|
||||||
|
assert unbacked_bindings is not None
|
||||||
assert len(unbacked_bindings) == 1, unbacked_bindings
|
assert len(unbacked_bindings) == 1, unbacked_bindings
|
||||||
# NB: Have to be very careful here. V.graph.current_node.meta["val"]
|
# NB: Have to be very careful here. V.graph.current_node.meta["val"]
|
||||||
# seemingly also contains a symbol which you want to do binding for,
|
# seemingly also contains a symbol which you want to do binding for,
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -346,11 +346,12 @@ def insert_deferred_runtime_asserts(
|
|||||||
# this guards against deleting calls that produce unbacked bindings we haven't yet seen.
|
# this guards against deleting calls that produce unbacked bindings we haven't yet seen.
|
||||||
# in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
|
# in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
|
||||||
# (is backed), but produces an unbacked symbol. In this case keep the node alive.
|
# (is backed), but produces an unbacked symbol. In this case keep the node alive.
|
||||||
|
resolved_unbacked_bindings = resolve_unbacked_bindings(
|
||||||
|
shape_env, node.meta.get("unbacked_bindings", {})
|
||||||
|
)
|
||||||
|
assert resolved_unbacked_bindings is not None
|
||||||
new_unbacked_bindings = (
|
new_unbacked_bindings = (
|
||||||
resolve_unbacked_bindings(
|
resolved_unbacked_bindings.keys() - expr_to_proxy.keys()
|
||||||
shape_env, node.meta.get("unbacked_bindings", {})
|
|
||||||
).keys()
|
|
||||||
- expr_to_proxy.keys()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# maybe re-reify expression, replace current node
|
# maybe re-reify expression, replace current node
|
||||||
|
@ -314,6 +314,17 @@ class TensorWithFlatten(Protocol):
|
|||||||
def stride(self, dim: int) -> int:
|
def stride(self, dim: int) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def size(self, dim: None = None) -> Tuple[int, ...]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def size(self, dim: int) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
def storage_offset(self) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
def dim(self) -> int:
|
def dim(self) -> int:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ def try_solve(
|
|||||||
thing: sympy.Basic,
|
thing: sympy.Basic,
|
||||||
trials: int = 5,
|
trials: int = 5,
|
||||||
floordiv_inequality: bool = True,
|
floordiv_inequality: bool = True,
|
||||||
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
|
) -> Optional[Tuple[sympy.Rel, sympy.Expr]]:
|
||||||
mirror = mirror_rel_op(type(expr))
|
mirror = mirror_rel_op(type(expr))
|
||||||
|
|
||||||
# Ignore unsupported expressions:
|
# Ignore unsupported expressions:
|
||||||
|
Reference in New Issue
Block a user