mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/136972 Approved by: https://github.com/Skylion007 ghstack dependencies: #136917, #136934, #136935
68 lines
1.7 KiB
Python
68 lines
1.7 KiB
Python
from mypy.plugin import Plugin
|
|
from mypy.plugins.common import add_attribute_to_class
|
|
from mypy.types import NoneType, UnionType
|
|
|
|
|
|
class SympyPlugin(Plugin):
    """mypy plugin that refines the static types of a few sympy APIs.

    mypy asks the plugin for hooks by fully-qualified name.  Each ``get_*``
    method below returns a callback for the one name it cares about and
    ``None`` for everything else, which tells mypy to use its default
    behaviour.
    """

    def get_base_class_hook(self, fullname: str):
        """Return the class-definition hook for ``sympy.core.basic.Basic``.

        The returned hook, ``add_assumptions``, attaches the ``is_<assumption>``
        attributes to the class being analyzed.

        Args:
            fullname: Fully-qualified name of a base class, as passed by mypy.

        Returns:
            ``add_assumptions`` for ``sympy.core.basic.Basic``, else ``None``.
        """
        # TODO: This apparently never worked
        # NOTE(review): a possible cause is that mypy passes only the names of a
        # class's *direct* bases to this hook, so subclasses deriving from Basic
        # indirectly are never seen — unverified, confirm before relying on it.
        if fullname == "sympy.core.basic.Basic":
            return add_assumptions
        return None

    def get_attribute_hook(self, fullname: str):
        """Return a hook that types ``Basic.free_symbols`` as ``set[sympy.Symbol]``.

        Args:
            fullname: Fully-qualified ``<class>.<attribute>`` name, as passed
                by mypy.

        Returns:
            A callback producing the refined type for
            ``sympy.core.basic.Basic.free_symbols``, else ``None``.
        """
        if fullname == "sympy.core.basic.Basic.free_symbols":
            return self._free_symbols_type
        return None

    @staticmethod
    def _free_symbols_type(ctx):
        """Build the type ``builtins.set[sympy.Symbol]`` for a ``free_symbols`` access.

        ``ctx.api`` is mypy's ``CheckerPluginInterface``.  That interface
        declares ``named_generic_type`` but not ``named_type``, so the
        non-generic ``Symbol`` is built with ``named_generic_type`` and an empty
        argument list instead of relying on the concrete checker's extra
        methods.
        """
        symbol_type = ctx.api.named_generic_type("sympy.Symbol", [])
        return ctx.api.named_generic_type("builtins.set", [symbol_type])
|
|
|
|
|
|
def add_assumptions(ctx) -> None:
    """Declare sympy's ``is_<assumption>`` attributes on the class in ``ctx``.

    For every name in the assumption list below, an attribute
    ``is_<name>`` with type ``bool | None`` is added to ``ctx.cls`` through
    ``add_attribute_to_class``.  Each attribute gets its own freshly built
    union type.

    Args:
        ctx: The context passed to the hook returned by
            ``SympyPlugin.get_base_class_hook``.  Only ``ctx.api`` and
            ``ctx.cls`` are read; presumably a mypy ``ClassDefContext``.
    """
    # Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
    # (sympy is deliberately not imported, to keep mypy plugin load time low).
    # The order is kept exactly as generated.
    assumption_names = (
        "hermitian",
        "prime",
        "noninteger",
        "negative",
        "antihermitian",
        "infinite",
        "finite",
        "irrational",
        "extended_positive",
        "nonpositive",
        "odd",
        "algebraic",
        "integer",
        "rational",
        "extended_real",
        "nonnegative",
        "transcendental",
        "extended_nonzero",
        "extended_negative",
        "composite",
        "complex",
        "imaginary",
        "nonzero",
        "zero",
        "even",
        "positive",
        "polar",
        "extended_nonpositive",
        "extended_nonnegative",
        "real",
        "commutative",
    )

    api = ctx.api
    target_class = ctx.cls
    for name in assumption_names:
        # ``bool | None``: sympy's three-valued logic (True / False / unknown).
        optional_bool = UnionType([api.named_type("builtins.bool"), NoneType()])
        add_attribute_to_class(api, target_class, "is_" + name, optional_bool)
|
|
|
|
|
|
def plugin(version: str):
    """Entry point mypy calls to obtain this module's plugin class.

    Args:
        version: The running mypy version string.  It is ignored: the same
            ``SympyPlugin`` class is returned for every version.

    Returns:
        The ``SympyPlugin`` class itself (not an instance).
    """
    plugin_class = SympyPlugin
    return plugin_class
|