[BE][Ez]: Enable ruff PYI019 (#127684)

Tells PyTorch to use typing_extensions.Self when it's able to.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127684
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2024-06-02 13:38:33 +00:00
committed by PyTorch MergeBot
parent 67ef2683d9
commit c2547dfcc3
3 changed files with 26 additions and 36 deletions

View File

@@ -68,7 +68,6 @@ ignore = [
"PERF401",
"PERF403",
# these ignores are from PYI; please fix!
"PYI019",
"PYI024",
"PYI036",
"PYI041",

View File

@@ -1,14 +1,5 @@
from typing import (
Any,
Iterable,
NamedTuple,
Optional,
overload,
Sequence,
Tuple,
TypeVar,
Union,
)
from typing import Any, Iterable, NamedTuple, Optional, overload, Sequence, Tuple, Union
from typing_extensions import Self
from torch import Tensor
@@ -24,8 +15,6 @@ class PackedSequence_(NamedTuple):
def bind(optional: Any, fn: Any): ...
_T = TypeVar("_T")
class PackedSequence(PackedSequence_):
def __new__(
cls,
@@ -34,39 +23,39 @@ class PackedSequence(PackedSequence_):
sorted_indices: Optional[Tensor] = ...,
unsorted_indices: Optional[Tensor] = ...,
) -> Self: ...
def pin_memory(self: _T) -> _T: ...
def cuda(self: _T, *args: Any, **kwargs: Any) -> _T: ...
def cpu(self: _T) -> _T: ...
def double(self: _T) -> _T: ...
def float(self: _T) -> _T: ...
def half(self: _T) -> _T: ...
def long(self: _T) -> _T: ...
def int(self: _T) -> _T: ...
def short(self: _T) -> _T: ...
def char(self: _T) -> _T: ...
def byte(self: _T) -> _T: ...
def pin_memory(self: Self) -> Self: ...
def cuda(self: Self, *args: Any, **kwargs: Any) -> Self: ...
def cpu(self: Self) -> Self: ...
def double(self: Self) -> Self: ...
def float(self: Self) -> Self: ...
def half(self: Self) -> Self: ...
def long(self: Self) -> Self: ...
def int(self: Self) -> Self: ...
def short(self: Self) -> Self: ...
def char(self: Self) -> Self: ...
def byte(self: Self) -> Self: ...
@overload
def to(
self: _T,
self: Self,
dtype: _dtype,
non_blocking: bool = False,
copy: bool = False,
) -> _T: ...
) -> Self: ...
@overload
def to(
self: _T,
self: Self,
device: Optional[DeviceLikeType] = None,
dtype: Optional[_dtype] = None,
non_blocking: bool = False,
copy: bool = False,
) -> _T: ...
) -> Self: ...
@overload
def to(
self: _T,
self: Self,
other: Tensor,
non_blocking: bool = False,
copy: bool = False,
) -> _T: ...
) -> Self: ...
@property
def is_cuda(self) -> bool: ...
def is_pinned(self) -> bool: ...

View File

@@ -21,6 +21,8 @@ from typing import (
TypeVar,
)
from typing_extensions import Self
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
@@ -92,24 +94,24 @@ class Diagnostic:
)
return sarif_result
def with_location(self: _Diagnostic, location: infra.Location) -> _Diagnostic:
def with_location(self: Self, location: infra.Location) -> Self:
"""Adds a location to the diagnostic."""
self.locations.append(location)
return self
def with_thread_flow_location(
self: _Diagnostic, location: infra.ThreadFlowLocation
) -> _Diagnostic:
self: Self, location: infra.ThreadFlowLocation
) -> Self:
"""Adds a thread flow location to the diagnostic."""
self.thread_flow_locations.append(location)
return self
def with_stack(self: _Diagnostic, stack: infra.Stack) -> _Diagnostic:
def with_stack(self: Self, stack: infra.Stack) -> Self:
"""Adds a stack to the diagnostic."""
self.stacks.append(stack)
return self
def with_graph(self: _Diagnostic, graph: infra.Graph) -> _Diagnostic:
def with_graph(self: Self, graph: infra.Graph) -> Self:
"""Adds a graph to the diagnostic."""
self.graphs.append(graph)
return self