mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
More progress on type checking ValueRanges (#118870)
Type checking Python is a pain. Here are my learnings: * The types for heavily polymorphic code is going to be verbose, no way around it. I originally was hoping I could lean on polymorphism with a bounded TypeVar to compactly write signatures for many of the ValueRanges methods, but I ran into some unworkaroundable mypy bugs. Writing out all the types explicitly and using `@overload` liberally works pretty well, so I think I recommend people do that instead of trying to do fancy things. * Sympy is missing annotations for assumptions, because they are all metaprogrammed. I don't really relish maintaining a typeshed for sympy, so I wrote a small mypy plugin to add them in. * GADT style refinement is... just not a good idea in practice. Mypy easily gets confused whether or not a return value from a refined section is allowed for the outer return type. So many of these have been replaced with less informative implementation types and more informative external types via overloads. Hopefully this is good for use sites. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/118870 Approved by: https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
b92819a039
commit
b816760a2f
59
mypy_plugins/sympy_mypy_plugin.py
Normal file
59
mypy_plugins/sympy_mypy_plugin.py
Normal file
@ -0,0 +1,59 @@
|
||||
from mypy.plugin import Plugin
|
||||
from mypy.plugins.common import add_attribute_to_class
|
||||
from mypy.types import NoneType, UnionType
|
||||
|
||||
|
||||
class SympyPlugin(Plugin):
|
||||
def get_base_class_hook(self, fullname: str):
|
||||
if fullname == "sympy.core.basic.Basic":
|
||||
return add_assumptions
|
||||
return None
|
||||
|
||||
|
||||
def add_assumptions(ctx) -> None:
|
||||
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
|
||||
# (do not import sympy to speedup mypy plugin load time)
|
||||
assumptions = [
|
||||
"hermitian",
|
||||
"prime",
|
||||
"noninteger",
|
||||
"negative",
|
||||
"antihermitian",
|
||||
"infinite",
|
||||
"finite",
|
||||
"irrational",
|
||||
"extended_positive",
|
||||
"nonpositive",
|
||||
"odd",
|
||||
"algebraic",
|
||||
"integer",
|
||||
"rational",
|
||||
"extended_real",
|
||||
"nonnegative",
|
||||
"transcendental",
|
||||
"extended_nonzero",
|
||||
"extended_negative",
|
||||
"composite",
|
||||
"complex",
|
||||
"imaginary",
|
||||
"nonzero",
|
||||
"zero",
|
||||
"even",
|
||||
"positive",
|
||||
"polar",
|
||||
"extended_nonpositive",
|
||||
"extended_nonnegative",
|
||||
"real",
|
||||
"commutative",
|
||||
]
|
||||
for a in assumptions:
|
||||
add_attribute_to_class(
|
||||
ctx.api,
|
||||
ctx.cls,
|
||||
f"is_{a}",
|
||||
UnionType([ctx.api.named_type("builtins.bool"), NoneType()]),
|
||||
)
|
||||
|
||||
|
||||
def plugin(version: str):
|
||||
return SympyPlugin
|
Reference in New Issue
Block a user