Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972
Approved by: https://github.com/Skylion007
ghstack dependencies: #136934, #136935
This commit is contained in:
Edward Z. Yang
2024-09-30 18:21:41 -07:00
committed by PyTorch MergeBot
parent b85f21fc1d
commit cc8f1cddd4
11 changed files with 318 additions and 189 deletions

View File

@ -5,10 +5,18 @@ from mypy.types import NoneType, UnionType
class SympyPlugin(Plugin):
def get_base_class_hook(self, fullname: str):
# TODO: This apparently never worked
if fullname == "sympy.core.basic.Basic":
return add_assumptions
return None
def get_attribute_hook(self, fullname: str):
if fullname == "sympy.core.basic.Basic.free_symbols":
return lambda ctx: ctx.api.named_generic_type(
"builtins.set", [ctx.api.named_type("sympy.Symbol")]
)
return None
def add_assumptions(ctx) -> None:
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)