mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Turn on type-checking in torch.fx.experimental.symbolic_shapes (#136972)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972 Approved by: https://github.com/Skylion007 ghstack dependencies: #136934, #136935
This commit is contained in:
committed by
PyTorch MergeBot
parent
b85f21fc1d
commit
cc8f1cddd4
@ -5,10 +5,18 @@ from mypy.types import NoneType, UnionType
|
||||
|
||||
class SympyPlugin(Plugin):
|
||||
def get_base_class_hook(self, fullname: str):
|
||||
# TODO: This apparently never worked
|
||||
if fullname == "sympy.core.basic.Basic":
|
||||
return add_assumptions
|
||||
return None
|
||||
|
||||
def get_attribute_hook(self, fullname: str):
|
||||
if fullname == "sympy.core.basic.Basic.free_symbols":
|
||||
return lambda ctx: ctx.api.named_generic_type(
|
||||
"builtins.set", [ctx.api.named_type("sympy.Symbol")]
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def add_assumptions(ctx) -> None:
|
||||
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
|
||||
|
Reference in New Issue
Block a user