mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585: Add noqa to necessary tests (#146391)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146391 Approved by: https://github.com/justinchuby, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
b61032fcf7
commit
1f8ff94d4f
@ -2,7 +2,7 @@
|
||||
|
||||
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
|
||||
|
||||
from typing import List
|
||||
from typing import Sequence
|
||||
|
||||
import onnx_test_common
|
||||
import onnxscript
|
||||
@ -90,7 +90,11 @@ class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
|
||||
|
||||
@onnxscript.script(custom_opset)
|
||||
def layer_norm(
|
||||
X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
|
||||
X,
|
||||
axes: Sequence[int],
|
||||
weight: FLOAT[...],
|
||||
bias: FLOAT[...],
|
||||
eps: float,
|
||||
):
|
||||
mean = op.ReduceMean(X, axes=axes)
|
||||
D = X - mean # op.Sub(X, mean)
|
||||
|
@ -231,7 +231,7 @@ class TestPybindTypeCasters(common.TestCase):
|
||||
Our Pybind functions have a signature of the form `() -> return_type`.
|
||||
"""
|
||||
# Imports needed for the `eval` below.
|
||||
from typing import List, Tuple # noqa: F401
|
||||
from typing import List, Tuple # noqa: F401, UP035
|
||||
|
||||
return eval(re.search("-> (.*)\n", func.__doc__).group(1))
|
||||
|
||||
|
@ -22,7 +22,7 @@ import warnings
|
||||
from collections import namedtuple
|
||||
from copy import deepcopy
|
||||
from math import sqrt
|
||||
from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Any, Callable, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.fx._pytree as fx_pytree
|
||||
@ -2270,10 +2270,19 @@ class TestFX(JitTestCase):
|
||||
graph: torch.fx.Graph = torch.fx.Graph()
|
||||
x: torch.fx.Node = graph.create_node("placeholder", "x")
|
||||
b: torch.fx.Node = graph.create_node(
|
||||
"call_function", target=torch.relu, args=(x,), type_expr=List[float]
|
||||
"call_function", target=torch.relu, args=(x,), type_expr=list[float]
|
||||
)
|
||||
output: torch.fx.Node = graph.output(b)
|
||||
|
||||
self.assertTrue('list[float]' in str(graph))
|
||||
|
||||
def test_typename_print_pre_pep585(self):
|
||||
graph : torch.fx.Graph = torch.fx.Graph()
|
||||
x : torch.fx.Node = graph.create_node('placeholder', 'x')
|
||||
b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
|
||||
type_expr=typing.List[float]) # noqa: UP006
|
||||
output : torch.fx.Node = graph.output(b)
|
||||
|
||||
self.assertTrue("typing.List[float]" in str(graph))
|
||||
|
||||
def test_layout(self):
|
||||
@ -2922,6 +2931,19 @@ class TestFX(JitTestCase):
|
||||
def forward(self, x: list[str]) -> list[str]:
|
||||
return self.other(x)
|
||||
|
||||
traced = symbolic_trace(ReturnTypeModule())
|
||||
self.assertIn("-> list[str]", traced._code)
|
||||
scripted = torch.jit.script(traced)
|
||||
self.assertIn("-> List[str]", scripted.code)
|
||||
|
||||
def test_return_type_exists_pre_pep585(self):
|
||||
class ReturnTypeModule(torch.nn.Module):
|
||||
def other(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006
|
||||
return x
|
||||
|
||||
def forward(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006
|
||||
return self.other(x)
|
||||
|
||||
traced = symbolic_trace(ReturnTypeModule())
|
||||
self.assertIn("-> typing_List[str]", traced._code)
|
||||
scripted = torch.jit.script(traced)
|
||||
@ -3735,7 +3757,7 @@ class TestFX(JitTestCase):
|
||||
@unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
|
||||
def test_annotations_empty_tuple(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
|
||||
def forward(self, x: typing.Tuple[()], y: typing.Tuple[str, typing.Tuple[()]]): # noqa: UP006
|
||||
return "foo"
|
||||
|
||||
traced = torch.fx.symbolic_trace(Foo())
|
||||
@ -4320,10 +4342,10 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
|
||||
tuple,
|
||||
type,
|
||||
typing.Callable,
|
||||
typing.Dict,
|
||||
typing.List,
|
||||
typing.Tuple,
|
||||
typing.Type,
|
||||
typing.Dict, # noqa: UP006
|
||||
typing.List, # noqa: UP006
|
||||
typing.Tuple, # noqa: UP006
|
||||
typing.Type, # noqa: UP006
|
||||
typing.Union,
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ import tempfile
|
||||
import typing
|
||||
import unittest
|
||||
from types import BuiltinFunctionType
|
||||
from typing import Callable, List, NamedTuple, Optional, Union
|
||||
from typing import Callable, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.fx.experimental.meta_tracer
|
||||
@ -1548,25 +1548,25 @@ class {test_classname}(torch.nn.Module):
|
||||
(Optional[list[int]], list[int]),
|
||||
] + [
|
||||
# pre-PEP585 signatures
|
||||
(typing.List[int], int),
|
||||
(typing.List[int], create_type_hint([int, int])),
|
||||
(typing.List[int], create_type_hint((int, int))),
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
|
||||
(typing.List[int], int), # noqa: UP006
|
||||
(typing.List[int], create_type_hint([int, int])), # noqa: UP006
|
||||
(typing.List[int], create_type_hint((int, int))), # noqa: UP006
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), # noqa: UP006
|
||||
(
|
||||
typing.List[torch.Tensor],
|
||||
typing.List[torch.Tensor], # noqa: UP006
|
||||
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
|
||||
),
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), # noqa: UP006
|
||||
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), # noqa: UP006
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), # noqa: UP006
|
||||
(
|
||||
typing.List[torch.Tensor],
|
||||
typing.List[torch.Tensor], # noqa: UP006
|
||||
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
|
||||
),
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
|
||||
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),
|
||||
(Optional[typing.List[int]], typing.List[int]),
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), # noqa: UP006
|
||||
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), # noqa: UP006
|
||||
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]), # noqa: UP006
|
||||
(Optional[typing.List[int]], typing.List[int]), # noqa: UP006
|
||||
]
|
||||
|
||||
for sig_type, arg_type in should_be_equal:
|
||||
@ -1575,7 +1575,7 @@ class {test_classname}(torch.nn.Module):
|
||||
should_fail = [
|
||||
(int, float),
|
||||
(Union[int, float], str),
|
||||
(list[torch.Tensor], List[int]),
|
||||
(list[torch.Tensor], typing.List[int]), # noqa: UP006
|
||||
] + [
|
||||
# pre-PEP585 signatures
|
||||
(list[torch.Tensor], list[int]),
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Tuple, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
from torch.utils._contextlib import (
|
||||
@ -386,7 +386,7 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
|
||||
def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
|
||||
self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
|
||||
assert isinstance(self.tensors, tuple)
|
||||
self.prev_versions = tuple(t._version for t in self.tensors)
|
||||
|
@ -455,9 +455,13 @@ class CodeGen:
|
||||
|
||||
typename = _type_repr(o)
|
||||
|
||||
if hasattr(o, "__origin__"):
|
||||
# This is a generic type, e.g. typing.List[torch.Tensor]
|
||||
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
|
||||
if origin_type := getattr(o, "__origin__", None):
|
||||
# list[...], typing.List[...], TensorType[...]
|
||||
|
||||
if isinstance(o, typing._GenericAlias): # type: ignore[attr-defined]
|
||||
# This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor]
|
||||
origin_type = _origin_type_map.get(origin_type, origin_type)
|
||||
|
||||
origin_typename = add_global(_type_repr(origin_type), origin_type)
|
||||
|
||||
if hasattr(o, "__args__"):
|
||||
|
@ -126,7 +126,9 @@ def _type_repr(obj: object) -> str:
|
||||
typically enough to uniquely identify a type. For everything
|
||||
else, we fall back on repr(obj).
|
||||
"""
|
||||
if isinstance(obj, type):
|
||||
# Extension: If we don't ignore GenericAlias then `list[int]` will print
|
||||
# simply "list".
|
||||
if isinstance(obj, type) and not isinstance(obj, types.GenericAlias):
|
||||
if obj.__module__ == "builtins":
|
||||
return obj.__qualname__
|
||||
return f"{obj.__module__}.{obj.__qualname__}"
|
||||
|
Reference in New Issue
Block a user