mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE]: Apply PYI autofixes to various types (#107521)
Applies some autofixes from the ruff PYI rules to improve the typing of PyTorch. I haven't enabled most of these ruff rules yet as they do not have autofixes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107521 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
24f0b552e1
commit
b1e8e01e50
@ -1,5 +1,6 @@
|
||||
import collections
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union, overload
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import google.protobuf.message
|
||||
@ -10,9 +11,9 @@ from . import core
|
||||
|
||||
# pybind11 will automatically accept either Python str or bytes for C++ APIs
|
||||
# that accept std::string.
|
||||
_PybindStr = Union[str, bytes]
|
||||
_PerOpEnginePrefType = Dict[int, Dict[str, List[str]]]
|
||||
_EnginePrefType = Dict[int, List[str]]
|
||||
_PybindStr: TypeAlias = Union[str, bytes]
|
||||
_PerOpEnginePrefType: TypeAlias = Dict[int, Dict[str, List[str]]]
|
||||
_EnginePrefType: TypeAlias = Dict[int, List[str]]
|
||||
|
||||
Int8Tensor = collections.namedtuple(
|
||||
'Int8Tensor', ['data', 'scale', 'zero_point']
|
||||
@ -52,7 +53,7 @@ class Workspace:
|
||||
@overload
|
||||
def __init__(self) -> None: ...
|
||||
@overload
|
||||
def __init__(self, workspace: "Workspace") -> None: ...
|
||||
def __init__(self, workspace: Workspace) -> None: ...
|
||||
@property
|
||||
def blobs(self) -> Dict[str, Blob]: ...
|
||||
def create_blob(self, name: _PybindStr) -> Blob: ...
|
||||
@ -86,7 +87,7 @@ class Workspace:
|
||||
) -> bool: ...
|
||||
def remove_blob(self, blob: Any) -> None: ...
|
||||
|
||||
current: "Workspace"
|
||||
current: Workspace
|
||||
|
||||
|
||||
class Argument:
|
||||
@ -100,7 +101,7 @@ class Argument:
|
||||
|
||||
class OpSchema:
|
||||
@staticmethod
|
||||
def get(key: str) -> "OpSchema": ...
|
||||
def get(key: str) -> OpSchema: ...
|
||||
@property
|
||||
def args(self) -> List[Argument]: ...
|
||||
@property
|
||||
|
@ -10,13 +10,13 @@ class QueryTensorQparam:
|
||||
def max(self) -> float: ...
|
||||
|
||||
class HistogramNetObserver:
|
||||
pass
|
||||
...
|
||||
|
||||
class OutputColumnMaxHistogramNetObserver:
|
||||
pass
|
||||
...
|
||||
|
||||
class RegisterQuantizationParamsWithHistogramNetObserver:
|
||||
pass
|
||||
...
|
||||
|
||||
def ClearNetObservers() -> None: ...
|
||||
def ObserveMinMaxOfOutput(min_max_file_name: str, dump_freq: int = -1, delimiter: str = " ") -> None: ...
|
||||
|
@ -2,6 +2,7 @@ from enum import Enum
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from torch._C import device, dtype, layout
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
# defined in torch/csrc/profiler/python/init.cpp
|
||||
|
||||
@ -134,8 +135,8 @@ class _TensorMetadata:
|
||||
@property
|
||||
def strides(self) -> List[int]: ...
|
||||
|
||||
Scalar = Union[int, float, bool, complex]
|
||||
Input = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
|
||||
Scalar: TypeAlias = Union[int, float, bool, complex]
|
||||
Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
|
||||
|
||||
class _ExtraFields_TorchOp:
|
||||
name: str
|
||||
|
@ -769,7 +769,7 @@ class SetVariable(VariableTracker):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.underlying_value)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SetVariable.SetElement):
|
||||
return False
|
||||
if isinstance(self.vt, variables.TensorVariable):
|
||||
|
@ -1,7 +1,5 @@
|
||||
import warnings
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["detect_anomaly", "set_detect_anomaly"]
|
||||
@ -88,7 +86,7 @@ class detect_anomaly:
|
||||
def __enter__(self) -> None:
|
||||
torch.set_anomaly_enabled(True, self.check_nan)
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
def __exit__(self, *args: object) -> None:
|
||||
torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
|
||||
|
||||
|
||||
@ -117,5 +115,5 @@ class set_detect_anomaly:
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
def __exit__(self, *args: object) -> None:
|
||||
torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
|
||||
|
@ -224,7 +224,7 @@ class saved_tensors_hooks:
|
||||
self.pack_hook, self.unpack_hook
|
||||
)
|
||||
|
||||
def __exit__(self, *args: Any):
|
||||
def __exit__(self, *args: object):
|
||||
torch._C._autograd._pop_saved_tensors_default_hooks()
|
||||
|
||||
|
||||
|
@ -39,7 +39,7 @@ class Namespace(metaclass=abc.ABCMeta):
|
||||
return self.id < other.id
|
||||
return False
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, Namespace):
|
||||
return self.id == other.id
|
||||
return False
|
||||
|
@ -99,7 +99,7 @@ class TypePromotionRule(abc.ABC):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
...
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
|
@ -91,7 +91,7 @@ class _Storage:
|
||||
def __repr__(self) -> str:
|
||||
return f"{hex(self.ptr):>18} ({self.allocation_id})"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, _Storage) and self.allocation_id == other.allocation_id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
|
@ -75,5 +75,3 @@ class Storage:
|
||||
|
||||
def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
...
|
||||
|
Reference in New Issue
Block a user