[BE][CI][Easy] Run lintrunner on generated .pyi stub files (#150732)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150732
Approved by: https://github.com/malfet, https://github.com/cyyever, https://github.com/aorenste
Xuehai Pan
2025-05-27 20:23:17 +08:00
committed by PyTorch MergeBot
parent 0a7eef140b
commit 7ae204c3b6
12 changed files with 95 additions and 128 deletions

View File

@@ -31,6 +31,9 @@ python3 -m tools.pyi.gen_pyi \
--deprecated-functions-path "tools/autograd/deprecated.yaml"
python3 torch/utils/data/datapipes/gen_pyi.py
# Also check generated pyi files
find torch -name '*.pyi' -exec git add --force -- "{}" +
RC=0
# Run lintrunner on all files
if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
@@ -41,6 +44,9 @@ if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
RC=1
fi
# Unstage temporarily added pyi files
find torch -name '*.pyi' -exec git restore --staged -- "{}" +
# Use jq to massage the JSON lint output into GitHub Actions workflow commands.
jq --raw-output \
'"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' \

View File

@@ -212,6 +212,11 @@ select = [
"__init__.py" = [
"F401",
]
"*.pyi" = [
"PYI011", # typed-argument-default-in-stub
"PYI021", # docstring-in-stub
"PYI053", # string-or-bytes-too-long
]
"functorch/notebooks/**" = [
"F401",
]
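These three suppressions cover patterns that legitimately occur in generated stubs. A rough sketch of what each rule flags in a .pyi file (semantics paraphrased from ruff's rule names; the identifiers below are illustrative):

# PYI011 (typed-argument-default-in-stub): complex defaults should be "...",
# since stub defaults are documentation rather than executed code.
def load(path: str = "/some/long/default") -> None: ...  # flagged
def load_ok(path: str = ...) -> None: ...                # preferred

# PYI021 (docstring-in-stub): docstrings belong in the implementation.
def step() -> None:
    """Run one step."""  # flagged inside a .pyi

# PYI053 (string-or-bytes-too-long): long string/bytes literals should be "...".
BANNER: str = "a literal longer than fifty characters here would be flagged"

Ignoring these rules for *.pyi lets the generated stubs keep such patterns without per-line noqa comments.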

View File

@@ -113,9 +113,9 @@ class MpsMemoryLeakCheck:
self.caching_allocator_before = torch.mps.current_allocated_memory()
self.driver_before = torch.mps.driver_allocated_memory()
def __exit__(self, exec_type, exec_value, traceback):
def __exit__(self, exc_type, exc_value, traceback):
# Don't check for leaks if an exception was thrown
if exec_type is not None:
if exc_type is not None:
return
# Compares caching allocator before/after statistics
# An increase in allocated memory is a discrepancy indicating a possible memory leak
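The exec_* to exc_* rename matches the conventional context-manager signature from PEP 343. A minimal sketch of the pattern this class follows, assuming simplified names:

from types import TracebackType

class LeakCheck:
    def __enter__(self) -> "LeakCheck":
        # snapshot allocator statistics here
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if exc_type is not None:
            return  # the with-body raised; skip the leak comparison
        # compare before/after allocator statistics here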

View File

@@ -1,20 +1,11 @@
# ${generated_comment}
# mypy: disable-error-code="type-arg"
# mypy: allow-untyped-defs
# ruff: noqa: F401,PYI054
import builtins
from collections.abc import Sequence
from types import EllipsisType
from typing import (
Any,
Callable,
ContextManager,
Iterator,
Literal,
NamedTuple,
overload,
Sequence,
TypeVar,
)
from typing import Any, Callable, Literal, overload, TypeVar
import torch
from torch import (

View File

@@ -1,8 +1,9 @@
# ${generated_comment}
# mypy: disable-error-code="type-arg"
# mypy: allow-untyped-defs
# ruff: noqa: F401
import builtins
from collections.abc import Iterable, Iterator, Sequence
from enum import Enum, IntEnum
from pathlib import Path
from types import EllipsisType
@@ -10,17 +11,13 @@ from typing import (
Any,
AnyStr,
Callable,
ContextManager,
Generic,
IO,
Iterable,
Iterator,
Literal,
NamedTuple,
overload,
Protocol,
runtime_checkable,
Sequence,
SupportsIndex,
TypeVar,
)
@@ -71,15 +68,15 @@ from torch.utils._python_dispatch import TorchDispatchMode
# This module is defined in torch/csrc/Module.cpp
K = TypeVar("K")
T = TypeVar("T")
S = TypeVar("S", bound=torch.Tensor)
P = ParamSpec("P")
ReturnVal = TypeVar("ReturnVal", covariant=True) # return value (always covariant)
_T_co = TypeVar("_T_co", covariant=True)
K = TypeVar("K") # noqa: PYI001
T = TypeVar("T") # noqa: PYI001
S = TypeVar("S", bound=torch.Tensor) # noqa: PYI001
P = ParamSpec("P") # noqa: PYI001
R = TypeVar("R", covariant=True) # return value (always covariant) # noqa: PYI001
T_co = TypeVar("T_co", covariant=True) # noqa: PYI001
@runtime_checkable
class _NestedSequence(Protocol[_T_co]):
class _NestedSequence(Protocol[T_co]):
"""A protocol for representing nested sequences.
References::
@@ -88,10 +85,10 @@ class _NestedSequence(Protocol[_T_co]):
"""
def __len__(self, /) -> _int: ...
def __getitem__(self, index: _int, /) -> _T_co | _NestedSequence[_T_co]: ...
def __getitem__(self, index: _int, /) -> T_co | _NestedSequence[T_co]: ...
def __contains__(self, x: object, /) -> _bool: ...
def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
def __iter__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
def __reversed__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ...
def count(self, value: Any, /) -> _int: ...
def index(self, value: Any, /) -> _int: ...
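For context on the # noqa: PYI001 suppressions above: ruff's PYI001 expects TypeVars and ParamSpecs declared in stub files to carry a leading underscore. A sketch of the rule's preference:

from typing import TypeVar

_T = TypeVar("_T")  # PYI001-compliant: private, underscore-prefixed
T = TypeVar("T")    # flagged in a .pyi; this template silences it per line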
@@ -146,7 +143,7 @@ class Stream:
def record_event(self, event: Event | None = None) -> Event: ...
def __hash__(self) -> _int: ...
def __eq__(self, other: object) -> _bool: ...
def __enter__(self) -> Stream: ...
def __enter__(self) -> Self: ...
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
# Defined in torch/csrc/Event.cpp
@@ -321,14 +318,14 @@ def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
def unify_type_list(types: list[JitType]) -> JitType: ...
def _freeze_module(
module: ScriptModule,
preserved_attrs: list[str] = [],
preserved_attrs: list[str] = ...,
freeze_interfaces: _bool = True,
preserveParameters: _bool = True,
) -> ScriptModule: ...
def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
def _jit_pass_optimize_for_inference(
module: torch.jit.ScriptModule,
other_methods: list[str] = [],
other_methods: list[str] = ...,
) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
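The [] -> ... rewrites above follow the stub convention for default values; a minimal sketch with hypothetical names:

# In a .pyi, defaults are never evaluated, so a mutable literal like []
# implies runtime behavior that does not exist; "..." is the stub idiom.
def fn_old(attrs: list[str] = []) -> None: ...  # the form being removed
def fn_new(attrs: list[str] = ...) -> None: ... # the form being introduced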
@@ -759,7 +756,7 @@ class AliasDb: ...
class _InsertPoint:
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
# Defined in torch/csrc/jit/ir/ir.h
class Use:
@@ -1078,8 +1075,8 @@ class LiteScriptModule:
def run_method(self, method_name: str, *input): ...
# NOTE: switch to collections.abc.Callable in python 3.9
class ScriptFunction(Generic[P, ReturnVal]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
class ScriptFunction(Generic[P, R]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
def save(self, filename: str, _extra_files: dict[str, bytes]) -> None: ...
def save_to_buffer(self, _extra_files: dict[str, bytes]) -> bytes: ...
@property
@@ -1092,9 +1089,9 @@ class ScriptFunction(Generic[P, ReturnVal]):
def qualified_name(self) -> str: ...
# NOTE: switch to collections.abc.Callable in python 3.9
class ScriptMethod(Generic[P, ReturnVal]):
class ScriptMethod(Generic[P, R]):
graph: Graph
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
@property
def owner(self) -> ScriptModule: ...
@property
@@ -1375,7 +1372,7 @@ class _LinalgBackend:
# members. There is a chance this is due to a recent change in the semantics
# of enum membership. If so, use `member = value` to mark an enum member,
# instead of `member: type`
class BatchNormBackend(Enum): ... # type: ignore[misc]
class BatchNormBackend(Enum): ... # type: ignore[misc]
def _get_blas_preferred_backend() -> _BlasBackend: ...
def _set_blas_preferred_backend(arg: _BlasBackend): ...
@@ -1400,7 +1397,7 @@ class _ROCmFABackend:
# There is a chance this is due to a recent change in the semantics of enum
# membership. If so, use `member = value` to mark an enum member, instead of
# `member: type`
class ConvBackend(Enum): ... # type: ignore[misc]
class ConvBackend(Enum): ... # type: ignore[misc]
class Tag(Enum):
${tag_attributes}
@@ -1481,9 +1478,7 @@ def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...
def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ...
def _pop_torch_dispatch_stack(
mode_key: _TorchDispatchModeKey | None = None,
) -> Any: ...
def _pop_torch_dispatch_stack(mode_key: _TorchDispatchModeKey | None = None) -> Any: ...
def _get_dispatch_mode(mode_key: _TorchDispatchModeKey | None) -> Any: ...
def _unset_dispatch_mode(mode: _TorchDispatchModeKey) -> TorchDispatchMode | None: ...
def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ...
@@ -1494,42 +1489,42 @@ def _activate_gpu_trace() -> None: ...
class _DisableTorchDispatch:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _EnableTorchFunction:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _EnablePythonDispatcher:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _DisablePythonDispatcher:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _EnablePreDispatch:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _DisableFuncTorch:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _DisableAutocast:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _InferenceMode:
def __init__(self, enabled: _bool) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
def _set_autograd_fallback_mode(mode: str) -> None: ...
def _get_autograd_fallback_mode() -> str: ...
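The recurring *args: Any -> *exc_info: object rewrite uses the compact typed form of __exit__. A sketch of the idea:

class _Guard:
    def __enter__(self) -> None: ...
    # One star parameter stands in for (exc_type, exc_value, traceback) when
    # none of them are inspected; typing it as object rather than Any keeps
    # accidental attribute access on them visible to the type checker.
    def __exit__(self, *exc_info: object) -> None: ...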
@@ -1783,32 +1778,32 @@ def _commit_update(a: Tensor) -> None: ...
class _ExcludeDispatchKeyGuard:
def __init__(self, keyset: DispatchKeySet) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _IncludeDispatchKeyGuard:
def __init__(self, k: DispatchKey) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _ForceDispatchKeyGuard:
def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _PreserveDispatchKeyGuard:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _AutoDispatchBelowAutograd:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
class _AutoDispatchBelowADInplaceOrView:
def __init__(self) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key(
@@ -1827,18 +1822,16 @@ class _TorchDispatchModeKey(Enum):
class _SetExcludeDispatchKeyGuard:
def __init__(self, k: DispatchKey, enabled: _bool) -> None: ...
def __enter__(self): ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...
# Defined in torch/csrc/utils/schema_info.h
class _SchemaInfo:
def __init__(self, schema: _int) -> None: ...
@overload
def is_mutable(self) -> _bool: ...
@overload
def is_mutable(self, name: str) -> _bool: ...
def has_argument(self, name: str) -> _bool: ...
# Defined in torch/csrc/utils/init.cpp
@@ -2431,7 +2424,7 @@ def _create_graph_by_tracing(
strict: Any,
force_outplace: Any,
self: Any = None,
argument_names: list[str] = [],
argument_names: list[str] = ...,
) -> tuple[Graph, Stack]: ...
def _tracer_warn_use_python(): ...
def _get_tracing_state() -> TracingState: ...
@@ -2458,8 +2451,6 @@ class InferredType:
def success(self) -> _bool: ...
def reason(self) -> str: ...
R = TypeVar("R", bound=JitType)
class Type(JitType):
def str(self) -> _str: ...
def containedTypes(self) -> list[JitType]: ...
@@ -2558,16 +2549,18 @@ class UnionType(JitType):
class ClassType(JitType):
def __init__(self, qualified_name: str) -> None: ...
def qualified_name(self) ->str: ...
def qualified_name(self) -> str: ...
class InterfaceType(JitType):
def __init__(self, qualified_name: str) -> None: ...
def getMethod(self, name: str) -> FunctionSchema | None: ...
def getMethodNames(self) -> list[str]: ...
class OptionalType(JitType, Generic[R]):
def __init__(self, a: JitType) -> None: ...
def getElementType(self) -> JitType: ...
JitTypeT = TypeVar("JitTypeT", bound=JitType) # noqa: PYI001
class OptionalType(JitType, Generic[JitTypeT]):
def __init__(self, a: JitTypeT) -> None: ...
def getElementType(self) -> JitTypeT: ...
@staticmethod
def ofTensor() -> OptionalType: ...
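Binding the new JitTypeT TypeVar to JitType lets OptionalType report the precise element type instead of collapsing to plain JitType. A self-contained sketch of the mechanism, with a hypothetical IntType subclass:

from typing import Generic, TypeVar

class JitType: ...
class IntType(JitType): ...

JitTypeT = TypeVar("JitTypeT", bound=JitType)

class OptionalType(JitType, Generic[JitTypeT]):
    def __init__(self, a: JitTypeT) -> None:
        self._elem = a

    def getElementType(self) -> JitTypeT:
        return self._elem

# A checker now infers IntType rather than bare JitType:
elem = OptionalType(IntType()).getElementType()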
@@ -2681,17 +2674,18 @@ def _fuse_to_static_module(
# Defined in torch/csrc/fx/node.cpp
def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ...
def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ...
class _NodeBase:
_erased: _bool
_prev: FxNode
_next: FxNode
def __init__(
self,
graph: Any,
name: str,
op: str,
target: Any,
return_type: Any,
self,
graph: Any,
name: str,
op: str,
target: Any,
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...

View File

@@ -3,7 +3,7 @@
import datetime
from enum import Enum
from types import TracebackType
from typing import Callable, Optional
from typing import Callable
class Aggregation(Enum):
VALUE = ...
@@ -48,9 +48,9 @@ class _WaitCounterTracker:
def __enter__(self) -> None: ...
def __exit__(
self,
exec_type: Optional[type[BaseException]] = None,
exec_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None: ...
class _WaitCounter:

View File

@@ -1,7 +1,8 @@
# ${generated_comment}
# mypy: disable-error-code="type-arg"
from typing import Literal, overload, Sequence
from collections.abc import Sequence
from typing import Literal, overload
from torch import memory_format, Tensor
from torch.types import _bool, _device, _dtype, _int, _size

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Literal, Optional
from typing import Literal
from typing_extensions import TypeAlias
from torch._C import device, dtype, layout
@@ -73,7 +73,7 @@ class ProfilerConfig:
with_flops: bool,
with_modules: bool,
experimental_config: _ExperimentalConfig,
trace_id: Optional[str] = None,
trace_id: str | None = None,
) -> None: ...
class _ProfilerEvent:
@@ -243,4 +243,4 @@ class _RecordFunctionFast:
keyword_values: dict | None = None,
) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ...
def __exit__(self, *exc_info: object) -> None: ...

View File

@@ -1,31 +1,11 @@
# ${generated_comment}
# mypy: allow-untyped-defs
from typing import (
Any,
Callable,
ContextManager,
Final,
Iterator,
Literal,
NamedTuple,
NoReturn,
overload,
Sequence,
TypeVar,
)
from typing import Final, NoReturn
from typing_extensions import Self
from torch import (
contiguous_format,
Generator,
inf,
memory_format,
strided,
SymInt,
Tensor,
)
from torch.types import (
from torch import SymInt, Tensor
from torch.types import ( # noqa: F401
_bool,
_device,
_dtype,

View File

@@ -1,7 +1,8 @@
# ${generated_comment}
# mypy: allow-untyped-defs
from typing import Any, Callable, Literal, overload, Sequence
from collections.abc import Sequence
from typing import Any, Callable, Literal, overload
from typing_extensions import TypeAlias
from torch import Tensor
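This hunk, like several above, converges on the modern typing style; a brief sketch of the conventions the regenerated stubs now use:

from collections.abc import Sequence  # ABCs import from collections.abc

def first(xs: Sequence[int], default: int | None = None) -> int | None:
    # "int | None" (PEP 604) replaces Optional[int]; built-in generics such
    # as list[str] and dict[str, int] (PEP 585) replace List/Dict.
    return xs[0] if xs else default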

View File

@@ -2383,7 +2383,7 @@ class CudaNonDefaultStream:
device_type=deviceStream.device_type)
torch._C._cuda_setDevice(beforeDevice)
def __exit__(self, exec_type, exec_value, traceback):
def __exit__(self, exc_type, exc_value, traceback):
# After completing CUDA test load previously active streams on all
# CUDA devices.
beforeDevice = torch.cuda.current_device()
@@ -2431,9 +2431,9 @@ class CudaMemoryLeakCheck:
driver_mem_allocated = bytes_total - bytes_free
self.driver_befores.append(driver_mem_allocated)
def __exit__(self, exec_type, exec_value, traceback):
def __exit__(self, exc_type, exc_value, traceback):
# Don't check for leaks if an exception was thrown
if exec_type is not None:
if exc_type is not None:
return
# Compares caching allocator before/after statistics

View File

@@ -5,19 +5,8 @@
# Note that, for mypy, the .pyi file takes precedence over the .py file, so we must define the interface for other
# classes/objects here, even though we are not injecting extra code into them at the moment.
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
)
from collections.abc import Iterable, Iterator
from typing import Any, Callable, Literal, Optional, TypeVar, Union
from torch.utils.data import Dataset, default_collate, IterableDataset
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
@@ -27,19 +16,19 @@ _T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
UNTRACABLE_DATAFRAME_PIPES: Any
class DataChunk(List[_T]):
items: List[_T]
class DataChunk(list[_T]):
items: list[_T]
def __init__(self, items: Iterable[_T]) -> None: ...
def as_str(self, indent: str = "") -> str: ...
def __iter__(self) -> Iterator[_T]: ...
def raw_iterator(self) -> Iterator[_T]: ...
class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
functions: Dict[str, Callable] = ...
reduce_ex_hook: Optional[Callable] = ...
getstate_hook: Optional[Callable] = ...
str_hook: Optional[Callable] = ...
repr_hook: Optional[Callable] = ...
functions: dict[str, Callable] = ...
reduce_ex_hook: Callable | None = ...
getstate_hook: Callable | None = ...
str_hook: Callable | None = ...
repr_hook: Callable | None = ...
def __getattr__(self, attribute_name: Any): ...
@classmethod
def register_function(cls, function_name: Any, function: Any) -> None: ...
@@ -58,7 +47,7 @@ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
${MapDataPipeMethods}
class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
functions: Dict[str, Callable] = ...
functions: dict[str, Callable] = ...
reduce_ex_hook: Optional[Callable] = ...
getstate_hook: Optional[Callable] = ...
str_hook: Optional[Callable] = ...