mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Adds suppressions so pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Almost there! Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the project-excludes field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: INFO 0 errors (6,884 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913 Approved by: https://github.com/oulgen
120 lines
3.4 KiB
Python
120 lines
3.4 KiB
Python
"""
|
|
Python polyfills for operator
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import operator
|
|
from typing import Any, Callable, overload, TYPE_CHECKING, TypeVar
|
|
from typing_extensions import TypeVarTuple, Unpack
|
|
|
|
from ..decorators import substitute_in_graph
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable
|
|
|
|
|
|
# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`)
|
|
# Names exported by this module.
__all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"]


# Type variables used only in the annotations of the functions below.
# `_T`, `_T1`, `_T2` and `_Ts` annotate arguments: the `itemgetter` overload
# parameters and the `countOf` element type.
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_Ts = TypeVarTuple("_Ts")
# `_U`, `_U1`, `_U2` and `_Us` annotate the values produced by the callables
# that the `attrgetter` / `itemgetter` overloads return.
_U = TypeVar("_U")
_U1 = TypeVar("_U1")
_U2 = TypeVar("_U2")
_Us = TypeVarTuple("_Us")
|
|
|
|
|
|
@overload
# pyrefly: ignore  # inconsistent-overload
def attrgetter(attr: str, /) -> Callable[[Any], _U]: ...


@overload
# pyrefly: ignore  # inconsistent-overload
def attrgetter(
    attr1: str, attr2: str, /, *attrs: str
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...


# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter
@substitute_in_graph(operator.attrgetter, is_embedded_type=True)  # type: ignore[arg-type,misc]
def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]:
    """Build a callable that reads one or more attributes from its argument.

    Each name may be dotted (``"a.b.c"``), in which case the attributes are
    looked up one after another, starting from the object passed to the
    returned callable.

    Args:
        *attrs: One or more attribute names.

    Returns:
        A one-argument callable. Given a single name it returns that
        attribute's value; given several names it returns a tuple with one
        value per name, in the order the names were supplied.

    Raises:
        TypeError: If no name is given, or if any name is not a ``str``.
    """
    if not attrs:
        raise TypeError("attrgetter expected 1 argument, got 0")

    if not all(isinstance(path, str) for path in attrs):
        raise TypeError("attribute name must be a string")

    def walk(target: Any, path: str) -> Any:
        # Follow a dotted path one component at a time.
        current = target
        for component in path.split("."):
            current = getattr(current, component)
        return current

    if len(attrs) > 1:

        def getter(obj: Any) -> tuple[Any, ...]:
            return tuple(walk(obj, path) for path in attrs)

    else:
        (single,) = attrs

        def getter(obj: Any) -> Any:  # type: ignore[misc]
            return walk(obj, single)

    return getter
|
|
|
|
|
|
@overload
# pyrefly: ignore  # inconsistent-overload
def itemgetter(item: _T, /) -> Callable[[Any], _U]: ...


@overload
# pyrefly: ignore  # inconsistent-overload
def itemgetter(
    item1: _T1, item2: _T2, /, *items: Unpack[_Ts]
) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ...


# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter
@substitute_in_graph(operator.itemgetter, is_embedded_type=True)  # type: ignore[arg-type,misc]
def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]:
    """Build a callable that subscripts its argument with the given key(s).

    Args:
        *items: One or more keys/indices to look up with ``obj[key]``.

    Returns:
        A one-argument callable. Given a single key it returns ``obj[key]``;
        given several keys it returns a tuple of the looked-up values, in the
        order the keys were supplied.

    Raises:
        TypeError: If no key is given.
    """
    if not items:
        raise TypeError("itemgetter expected 1 argument, got 0")

    if len(items) > 1:

        def getter(obj: Any) -> tuple[Any, ...]:
            return tuple(obj[key] for key in items)

    else:
        (single,) = items

        def getter(obj: Any) -> Any:  # type: ignore[misc]
            return obj[single]

    return getter
|
|
|
|
|
|
# Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller
|
|
@substitute_in_graph(operator.methodcaller, is_embedded_type=True)  # type: ignore[arg-type]
def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]:
    """Build a callable that invokes the method ``name`` on its argument.

    Args:
        name: Name of the method to call (positional-only).
        *args: Positional arguments forwarded to the method on every call.
        **kwargs: Keyword arguments forwarded to the method on every call.

    Returns:
        A one-argument callable that returns
        ``getattr(obj, name)(*args, **kwargs)``.

    Raises:
        TypeError: If ``name`` is not a ``str``.
    """
    if isinstance(name, str):

        def caller(obj: Any) -> Any:
            # The method is looked up on each call, not when the caller is built.
            method = getattr(obj, name)
            return method(*args, **kwargs)

        return caller

    raise TypeError("method name must be a string")
|
|
|
|
|
|
# Reference: https://docs.python.org/3/library/operator.html#operator.countOf
|
|
@substitute_in_graph(operator.countOf, can_constant_fold_through=True)  # type: ignore[arg-type,misc]
def countOf(a: Iterable[_T], b: _T, /) -> int:
    """Return how many elements of ``a`` are ``b`` or compare equal to ``b``.

    Args:
        a: Iterable whose elements are examined.
        b: Value to look for.

    Returns:
        The number of elements ``it`` for which ``it is b or it == b`` is
        truthy.
    """
    # Count one per truthy match rather than summing the comparison results:
    # `it == b` may return a non-bool value, and adding that value itself
    # would yield a wrong total or fail outright.
    return sum(1 for it in a if it is b or it == b)
|