[BE]: Apply PYI autofixes to various types (#107521)

Applies the available autofixes from the ruff PYI rules to improve the typing of PyTorch. I haven't enabled most of the remaining PYI rules yet, as they do not have autofixes and would need manual fixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107521
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2023-08-20 02:42:18 +00:00
committed by PyTorch MergeBot
parent 24f0b552e1
commit b1e8e01e50
10 changed files with 20 additions and 22 deletions

View File

@ -1,5 +1,6 @@
import collections
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union, overload
from typing_extensions import TypeAlias
import numpy as np
import google.protobuf.message
@ -10,9 +11,9 @@ from . import core
# pybind11 will automatically accept either Python str or bytes for C++ APIs
# that accept std::string.
_PybindStr = Union[str, bytes]
_PerOpEnginePrefType = Dict[int, Dict[str, List[str]]]
_EnginePrefType = Dict[int, List[str]]
_PybindStr: TypeAlias = Union[str, bytes]
_PerOpEnginePrefType: TypeAlias = Dict[int, Dict[str, List[str]]]
_EnginePrefType: TypeAlias = Dict[int, List[str]]
Int8Tensor = collections.namedtuple(
'Int8Tensor', ['data', 'scale', 'zero_point']
@ -52,7 +53,7 @@ class Workspace:
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, workspace: "Workspace") -> None: ...
def __init__(self, workspace: Workspace) -> None: ...
@property
def blobs(self) -> Dict[str, Blob]: ...
def create_blob(self, name: _PybindStr) -> Blob: ...
@ -86,7 +87,7 @@ class Workspace:
) -> bool: ...
def remove_blob(self, blob: Any) -> None: ...
current: "Workspace"
current: Workspace
class Argument:
@ -100,7 +101,7 @@ class Argument:
class OpSchema:
@staticmethod
def get(key: str) -> "OpSchema": ...
def get(key: str) -> OpSchema: ...
@property
def args(self) -> List[Argument]: ...
@property

View File

@ -10,13 +10,13 @@ class QueryTensorQparam:
def max(self) -> float: ...
class HistogramNetObserver:
pass
...
class OutputColumnMaxHistogramNetObserver:
pass
...
class RegisterQuantizationParamsWithHistogramNetObserver:
pass
...
def ClearNetObservers() -> None: ...
def ObserveMinMaxOfOutput(min_max_file_name: str, dump_freq: int = -1, delimiter: str = " ") -> None: ...

View File

@ -2,6 +2,7 @@ from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union
from torch._C import device, dtype, layout
from typing_extensions import TypeAlias
# defined in torch/csrc/profiler/python/init.cpp
@ -134,8 +135,8 @@ class _TensorMetadata:
@property
def strides(self) -> List[int]: ...
Scalar = Union[int, float, bool, complex]
Input = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
Scalar: TypeAlias = Union[int, float, bool, complex]
Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
class _ExtraFields_TorchOp:
name: str

View File

@ -769,7 +769,7 @@ class SetVariable(VariableTracker):
def __hash__(self) -> int:
return hash(self.underlying_value)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, SetVariable.SetElement):
return False
if isinstance(self.vt, variables.TensorVariable):

View File

@ -1,7 +1,5 @@
import warnings
from typing import Any
import torch
__all__ = ["detect_anomaly", "set_detect_anomaly"]
@ -88,7 +86,7 @@ class detect_anomaly:
def __enter__(self) -> None:
torch.set_anomaly_enabled(True, self.check_nan)
def __exit__(self, *args: Any) -> None:
def __exit__(self, *args: object) -> None:
torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
@ -117,5 +115,5 @@ class set_detect_anomaly:
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> None:
def __exit__(self, *args: object) -> None:
torch.set_anomaly_enabled(self.prev, self.prev_check_nan)

View File

@ -224,7 +224,7 @@ class saved_tensors_hooks:
self.pack_hook, self.unpack_hook
)
def __exit__(self, *args: Any):
def __exit__(self, *args: object):
torch._C._autograd._pop_saved_tensors_default_hooks()

View File

@ -39,7 +39,7 @@ class Namespace(metaclass=abc.ABCMeta):
return self.id < other.id
return False
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, Namespace):
return self.id == other.id
return False

View File

@ -99,7 +99,7 @@ class TypePromotionRule(abc.ABC):
...
@abc.abstractmethod
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
...
def is_valid(self) -> bool:

View File

@ -91,7 +91,7 @@ class _Storage:
def __repr__(self) -> str:
return f"{hex(self.ptr):>18} ({self.allocation_id})"
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, _Storage) and self.allocation_id == other.allocation_id
def __hash__(self) -> int:

View File

@ -75,5 +75,3 @@ class Storage:
def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body]
...
...