More progress on type checking ValueRanges (#118870)

Type checking Python is a pain. Here are my learnings:

* The types for heavily polymorphic code is going to be verbose, no way around it. I originally was hoping I could lean on polymorphism with a bounded TypeVar to compactly write signatures for many of the ValueRanges methods, but I ran into some unworkaroundable mypy bugs. Writing out all the types explicitly and using `@overload` liberally works pretty well, so I think I recommend people do that instead of trying to do fancy things.
* Sympy is missing annotations for assumptions, because they are all metaprogrammed. I don't really relish maintaining a typeshed for sympy, so I wrote a small mypy plugin to add them in.
* GADT style refinement is... just not a good idea in practice. Mypy easily gets confused whether or not a return value from a refined section is allowed for the outer return type. So many of these have been replaced with less informative implementation types and more informative external types via overloads. Hopefully this is good for use sites.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118870
Approved by: https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
Edward Z. Yang
2024-02-05 09:22:07 -08:00
committed by PyTorch MergeBot
parent b92819a039
commit b816760a2f
4 changed files with 171 additions and 42 deletions

View File

@ -0,0 +1,59 @@
from mypy.plugin import Plugin
from mypy.plugins.common import add_attribute_to_class
from mypy.types import NoneType, UnionType
class SympyPlugin(Plugin):
def get_base_class_hook(self, fullname: str):
if fullname == "sympy.core.basic.Basic":
return add_assumptions
return None
def add_assumptions(ctx) -> None:
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
# (do not import sympy to speedup mypy plugin load time)
assumptions = [
"hermitian",
"prime",
"noninteger",
"negative",
"antihermitian",
"infinite",
"finite",
"irrational",
"extended_positive",
"nonpositive",
"odd",
"algebraic",
"integer",
"rational",
"extended_real",
"nonnegative",
"transcendental",
"extended_nonzero",
"extended_negative",
"composite",
"complex",
"imaginary",
"nonzero",
"zero",
"even",
"positive",
"polar",
"extended_nonpositive",
"extended_nonnegative",
"real",
"commutative",
]
for a in assumptions:
add_attribute_to_class(
ctx.api,
ctx.cls,
f"is_{a}",
UnionType([ctx.api.named_type("builtins.bool"), NoneType()]),
)
def plugin(version: str):
return SympyPlugin