diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index 847a6d63c41a..ef4741444f94 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -31,6 +31,9 @@ python3 -m tools.pyi.gen_pyi \ --deprecated-functions-path "tools/autograd/deprecated.yaml" python3 torch/utils/data/datapipes/gen_pyi.py +# Also check generated pyi files +find torch -name '*.pyi' -exec git add --force -- "{}" + + RC=0 # Run lintrunner on all files if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then @@ -41,6 +44,9 @@ if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} RC=1 fi +# Unstage temporally added pyi files +find torch -name '*.pyi' -exec git restore --staged -- "{}" + + # Use jq to massage the JSON lint output into GitHub Actions workflow commands. jq --raw-output \ '"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' \ diff --git a/pyproject.toml b/pyproject.toml index 3ede84f0891a..054eb4d6ecb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -212,6 +212,11 @@ select = [ "__init__.py" = [ "F401", ] +"*.pyi" = [ + "PYI011", # typed-argument-default-in-stub + "PYI021", # docstring-in-stub + "PYI053", # string-or-bytes-too-long +] "functorch/notebooks/**" = [ "F401", ] diff --git a/test/test_mps.py b/test/test_mps.py index 3b015ea0e40c..054fc9e22550 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -113,9 +113,9 @@ class MpsMemoryLeakCheck: self.caching_allocator_before = torch.mps.current_allocated_memory() self.driver_before = torch.mps.driver_allocated_memory() - def __exit__(self, exec_type, exec_value, traceback): + def __exit__(self, exc_type, exc_value, traceback): # Don't check for leaks if an exception was thrown - if exec_type is not None: + if exc_type is not None: return # Compares caching allocator before/after statistics # An increase in allocated memory is a discrepancy indicating a possible memory leak diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in index 9fe361a5ace6..374f5661060e 100644 --- a/torch/_C/_VariableFunctions.pyi.in +++ b/torch/_C/_VariableFunctions.pyi.in @@ -1,20 +1,11 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" # mypy: allow-untyped-defs +# ruff: noqa: F401,PYI054 -import builtins +from collections.abc import Sequence from types import EllipsisType -from typing import ( - Any, - Callable, - ContextManager, - Iterator, - Literal, - NamedTuple, - overload, - Sequence, - TypeVar, -) +from typing import Any, Callable, Literal, overload, TypeVar import torch from torch import ( diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 249b6deb911e..821b977f60f5 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1,8 +1,9 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" # mypy: allow-untyped-defs +# ruff: noqa: F401 -import builtins +from collections.abc import Iterable, Iterator, Sequence from enum import Enum, IntEnum from pathlib import Path from types import EllipsisType @@ -10,17 +11,13 @@ from typing import ( Any, AnyStr, Callable, - ContextManager, Generic, IO, - Iterable, - Iterator, Literal, NamedTuple, overload, Protocol, runtime_checkable, - Sequence, SupportsIndex, TypeVar, ) @@ -71,15 +68,15 @@ from torch.utils._python_dispatch import TorchDispatchMode # This module is defined in torch/csrc/Module.cpp -K = TypeVar("K") -T = TypeVar("T") -S = TypeVar("S", bound=torch.Tensor) -P = ParamSpec("P") -ReturnVal = TypeVar("ReturnVal", covariant=True) # return value (always covariant) -_T_co = TypeVar("_T_co", covariant=True) +K = TypeVar("K") # noqa: PYI001 +T = TypeVar("T") # noqa: PYI001 +S = TypeVar("S", bound=torch.Tensor) # noqa: PYI001 +P = ParamSpec("P") # noqa: PYI001 +R = TypeVar("R", covariant=True) # return value (always covariant) # noqa: PYI001 +T_co = TypeVar("T_co", covariant=True) # noqa: PYI001 @runtime_checkable -class _NestedSequence(Protocol[_T_co]): +class _NestedSequence(Protocol[T_co]): """A protocol for representing nested sequences. References:: @@ -88,10 +85,10 @@ class _NestedSequence(Protocol[_T_co]): """ def __len__(self, /) -> _int: ... - def __getitem__(self, index: _int, /) -> _T_co | _NestedSequence[_T_co]: ... + def __getitem__(self, index: _int, /) -> T_co | _NestedSequence[T_co]: ... def __contains__(self, x: object, /) -> _bool: ... - def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... - def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... + def __iter__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ... + def __reversed__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ... def count(self, value: Any, /) -> _int: ... def index(self, value: Any, /) -> _int: ... @@ -146,7 +143,7 @@ class Stream: def record_event(self, event: Event | None = None) -> Event: ... def __hash__(self) -> _int: ... def __eq__(self, other: object) -> _bool: ... - def __enter__(self) -> Stream: ... + def __enter__(self) -> Self: ... def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... # Defined in torch/csrc/Event.cpp @@ -321,14 +318,14 @@ def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ... def unify_type_list(types: list[JitType]) -> JitType: ... def _freeze_module( module: ScriptModule, - preserved_attrs: list[str] = [], + preserved_attrs: list[str] = ..., freeze_interfaces: _bool = True, preserveParameters: _bool = True, ) -> ScriptModule: ... def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ... def _jit_pass_optimize_for_inference( module: torch.jit.ScriptModule, - other_methods: list[str] = [], + other_methods: list[str] = ..., ) -> None: ... def _jit_pass_fold_frozen_conv_bn(graph: Graph): ... def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ... @@ -759,7 +756,7 @@ class AliasDb: ... class _InsertPoint: def __enter__(self) -> None: ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... # Defined in torch/csrc/jit/ir/ir.h class Use: @@ -1078,8 +1075,8 @@ class LiteScriptModule: def run_method(self, method_name: str, *input): ... # NOTE: switch to collections.abc.Callable in python 3.9 -class ScriptFunction(Generic[P, ReturnVal]): - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ... +class ScriptFunction(Generic[P, R]): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... def save(self, filename: str, _extra_files: dict[str, bytes]) -> None: ... def save_to_buffer(self, _extra_files: dict[str, bytes]) -> bytes: ... @property @@ -1092,9 +1089,9 @@ class ScriptFunction(Generic[P, ReturnVal]): def qualified_name(self) -> str: ... # NOTE: switch to collections.abc.Callable in python 3.9 -class ScriptMethod(Generic[P, ReturnVal]): +class ScriptMethod(Generic[P, R]): graph: Graph - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... @property def owner(self) -> ScriptModule: ... @property @@ -1375,7 +1372,7 @@ class _LinalgBackend: # members. There is a chance this is due to a recent change in the semantics # of enum membership. If so, use `member = value` to mark an enum member, # instead of `member: type` -class BatchNormBackend(Enum): ... # type: ignore[misc] +class BatchNormBackend(Enum): ... # type: ignore[misc] def _get_blas_preferred_backend() -> _BlasBackend: ... def _set_blas_preferred_backend(arg: _BlasBackend): ... @@ -1400,7 +1397,7 @@ class _ROCmFABackend: # There is a chance this is due to a recent change in the semantics of enum # membership. If so, use `member = value` to mark an enum member, instead of # `member: type` -class ConvBackend(Enum): ... # type: ignore[misc] +class ConvBackend(Enum): ... # type: ignore[misc] class Tag(Enum): ${tag_attributes} @@ -1481,9 +1478,7 @@ def _get_function_stack_at(idx: _int) -> Any: ... def _len_torch_function_stack() -> _int: ... def _set_torch_dispatch_mode(cls: Any) -> None: ... def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ... -def _pop_torch_dispatch_stack( - mode_key: _TorchDispatchModeKey | None = None, -) -> Any: ... +def _pop_torch_dispatch_stack(mode_key: _TorchDispatchModeKey | None = None) -> Any: ... def _get_dispatch_mode(mode_key: _TorchDispatchModeKey | None) -> Any: ... def _unset_dispatch_mode(mode: _TorchDispatchModeKey) -> TorchDispatchMode | None: ... def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ... @@ -1494,42 +1489,42 @@ def _activate_gpu_trace() -> None: ... class _DisableTorchDispatch: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _EnableTorchFunction: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _EnablePythonDispatcher: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _DisablePythonDispatcher: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _EnablePreDispatch: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _DisableFuncTorch: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _DisableAutocast: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _InferenceMode: def __init__(self, enabled: _bool) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... def _set_autograd_fallback_mode(mode: str) -> None: ... def _get_autograd_fallback_mode() -> str: ... @@ -1783,32 +1778,32 @@ def _commit_update(a: Tensor) -> None: ... class _ExcludeDispatchKeyGuard: def __init__(self, keyset: DispatchKeySet) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _IncludeDispatchKeyGuard: def __init__(self, k: DispatchKey) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _ForceDispatchKeyGuard: def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _PreserveDispatchKeyGuard: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _AutoDispatchBelowAutograd: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... class _AutoDispatchBelowADInplaceOrView: def __init__(self) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ... def _dispatch_get_registrations_for_dispatch_key( @@ -1827,18 +1822,16 @@ class _TorchDispatchModeKey(Enum): class _SetExcludeDispatchKeyGuard: def __init__(self, k: DispatchKey, enabled: _bool) -> None: ... def __enter__(self): ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... # Defined in torch/csrc/utils/schema_info.h class _SchemaInfo: def __init__(self, schema: _int) -> None: ... - @overload def is_mutable(self) -> _bool: ... @overload def is_mutable(self, name: str) -> _bool: ... - def has_argument(self, name: str) -> _bool: ... # Defined in torch/csrc/utils/init.cpp @@ -2431,7 +2424,7 @@ def _create_graph_by_tracing( strict: Any, force_outplace: Any, self: Any = None, - argument_names: list[str] = [], + argument_names: list[str] = ..., ) -> tuple[Graph, Stack]: ... def _tracer_warn_use_python(): ... def _get_tracing_state() -> TracingState: ... @@ -2458,8 +2451,6 @@ class InferredType: def success(self) -> _bool: ... def reason(self) -> str: ... -R = TypeVar("R", bound=JitType) - class Type(JitType): def str(self) -> _str: ... def containedTypes(self) -> list[JitType]: ... @@ -2558,16 +2549,18 @@ class UnionType(JitType): class ClassType(JitType): def __init__(self, qualified_name: str) -> None: ... - def qualified_name(self) ->str: ... + def qualified_name(self) -> str: ... class InterfaceType(JitType): def __init__(self, qualified_name: str) -> None: ... def getMethod(self, name: str) -> FunctionSchema | None: ... def getMethodNames(self) -> list[str]: ... -class OptionalType(JitType, Generic[R]): - def __init__(self, a: JitType) -> None: ... - def getElementType(self) -> JitType: ... +JitTypeT = TypeVar("JitTypeT", bound=JitType) # noqa: PYI001 + +class OptionalType(JitType, Generic[JitTypeT]): + def __init__(self, a: JitTypeT) -> None: ... + def getElementType(self) -> JitTypeT: ... @staticmethod def ofTensor() -> OptionalType: ... @@ -2681,17 +2674,18 @@ def _fuse_to_static_module( # Defined in torch/csrc/fx/node.cpp def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ... def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ... + class _NodeBase: _erased: _bool _prev: FxNode _next: FxNode def __init__( - self, - graph: Any, - name: str, - op: str, - target: Any, - return_type: Any, + self, + graph: Any, + name: str, + op: str, + target: Any, + return_type: Any, ) -> None: ... def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ... diff --git a/torch/_C/_monitor.pyi b/torch/_C/_monitor.pyi index d28c373e528b..be6f0f64f97d 100644 --- a/torch/_C/_monitor.pyi +++ b/torch/_C/_monitor.pyi @@ -3,7 +3,7 @@ import datetime from enum import Enum from types import TracebackType -from typing import Callable, Optional +from typing import Callable class Aggregation(Enum): VALUE = ... @@ -48,9 +48,9 @@ class _WaitCounterTracker: def __enter__(self) -> None: ... def __exit__( self, - exec_type: Optional[type[BaseException]] = None, - exec_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: ... class _WaitCounter: diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in index 26e043188a8d..0e7207b3ffb7 100644 --- a/torch/_C/_nn.pyi.in +++ b/torch/_C/_nn.pyi.in @@ -1,7 +1,8 @@ # ${generated_comment} # mypy: disable-error-code="type-arg" -from typing import Literal, overload, Sequence +from collections.abc import Sequence +from typing import Literal, overload from torch import memory_format, Tensor from torch.types import _bool, _device, _dtype, _int, _size diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index 48b14cc4b467..5e2870f72b47 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Literal, Optional +from typing import Literal from typing_extensions import TypeAlias from torch._C import device, dtype, layout @@ -73,7 +73,7 @@ class ProfilerConfig: with_flops: bool, with_modules: bool, experimental_config: _ExperimentalConfig, - trace_id: Optional[str] = None, + trace_id: str | None = None, ) -> None: ... class _ProfilerEvent: @@ -243,4 +243,4 @@ class _RecordFunctionFast: keyword_values: dict | None = None, ) -> None: ... def __enter__(self) -> None: ... - def __exit__(self, *args: Any) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... diff --git a/torch/_C/return_types.pyi.in b/torch/_C/return_types.pyi.in index 5559b3c40d27..a9d7ad73479d 100644 --- a/torch/_C/return_types.pyi.in +++ b/torch/_C/return_types.pyi.in @@ -1,31 +1,11 @@ # ${generated_comment} # mypy: allow-untyped-defs -from typing import ( - Any, - Callable, - ContextManager, - Final, - Iterator, - Literal, - NamedTuple, - NoReturn, - overload, - Sequence, - TypeVar, -) +from typing import Final, NoReturn from typing_extensions import Self -from torch import ( - contiguous_format, - Generator, - inf, - memory_format, - strided, - SymInt, - Tensor, -) -from torch.types import ( +from torch import SymInt, Tensor +from torch.types import ( # noqa: F401 _bool, _device, _dtype, diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 93d126b9a4f2..53c3a3c61f03 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,7 +1,8 @@ # ${generated_comment} # mypy: allow-untyped-defs -from typing import Any, Callable, Literal, overload, Sequence +from collections.abc import Sequence +from typing import Any, Callable, Literal, overload from typing_extensions import TypeAlias from torch import Tensor diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 6bab8c6e1234..2a1846bbbf39 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2383,7 +2383,7 @@ class CudaNonDefaultStream: device_type=deviceStream.device_type) torch._C._cuda_setDevice(beforeDevice) - def __exit__(self, exec_type, exec_value, traceback): + def __exit__(self, exc_type, exc_value, traceback): # After completing CUDA test load previously active streams on all # CUDA devices. beforeDevice = torch.cuda.current_device() @@ -2431,9 +2431,9 @@ class CudaMemoryLeakCheck: driver_mem_allocated = bytes_total - bytes_free self.driver_befores.append(driver_mem_allocated) - def __exit__(self, exec_type, exec_value, traceback): + def __exit__(self, exc_type, exc_value, traceback): # Don't check for leaks if an exception was thrown - if exec_type is not None: + if exc_type is not None: return # Compares caching allocator before/after statistics diff --git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index 42ba9e53c8b4..73cfa120e494 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -5,19 +5,8 @@ # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other # classes/objects here, even though we are not injecting extra code into them at the moment. -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Literal, - Optional, - Type, - TypeVar, - Union, -) +from collections.abc import Iterable, Iterator +from typing import Any, Callable, Literal, Optional, TypeVar, Union from torch.utils.data import Dataset, default_collate, IterableDataset from torch.utils.data.datapipes._hook_iterator import _SnapshotState @@ -27,19 +16,19 @@ _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) UNTRACABLE_DATAFRAME_PIPES: Any -class DataChunk(List[_T]): - items: List[_T] +class DataChunk(list[_T]): + items: list[_T] def __init__(self, items: Iterable[_T]) -> None: ... def as_str(self, indent: str = "") -> str: ... def __iter__(self) -> Iterator[_T]: ... def raw_iterator(self) -> Iterator[_T]: ... class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): - functions: Dict[str, Callable] = ... - reduce_ex_hook: Optional[Callable] = ... - getstate_hook: Optional[Callable] = ... - str_hook: Optional[Callable] = ... - repr_hook: Optional[Callable] = ... + functions: dict[str, Callable] = ... + reduce_ex_hook: Callable | None = ... + getstate_hook: Callable | None = ... + str_hook: Callable | None = ... + repr_hook: Callable | None = ... def __getattr__(self, attribute_name: Any): ... @classmethod def register_function(cls, function_name: Any, function: Any) -> None: ... @@ -58,7 +47,7 @@ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): ${MapDataPipeMethods} class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): - functions: Dict[str, Callable] = ... + functions: dict[str, Callable] = ... reduce_ex_hook: Optional[Callable] = ... getstate_hook: Optional[Callable] = ... str_hook: Optional[Callable] = ...