PEP585: Add noqa to necessary tests (#146391)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146391
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
Aaron Orenstein
2025-02-06 07:23:33 -08:00
committed by PyTorch MergeBot
parent b61032fcf7
commit 1f8ff94d4f
7 changed files with 63 additions and 31 deletions
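
For context: PEP 585 (Python 3.9) made the builtin containers subscriptable, so `list[int]` can replace `typing.List[int]`. Ruff's UP006 rule rewrites the deprecated spellings and UP035 flags the deprecated `typing` imports; the `# noqa` comments added in this commit suppress those rules in tests that exercise the pre-PEP585 spellings on purpose. A minimal sketch of the two spellings (plain Python, nothing PyTorch-specific):

import typing

def head_new(xs: list[int]) -> int:  # PEP 585 builtin generic
    return xs[0]

# The deprecated alias, kept deliberately, with the lint suppressed:
def head_old(xs: typing.List[int]) -> int:  # noqa: UP006
    return xs[0]

# Both annotations describe the same runtime type; only the spelling differs.
assert head_new([1, 2]) == head_old([1, 2]) == 1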

@@ -2,7 +2,7 @@
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
from typing import List
from typing import Sequence
import onnx_test_common
import onnxscript
@@ -90,7 +90,11 @@ class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
@onnxscript.script(custom_opset)
def layer_norm(
X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
X,
axes: Sequence[int],
weight: FLOAT[...],
bias: FLOAT[...],
eps: float,
):
mean = op.ReduceMean(X, axes=axes)
D = X - mean # op.Sub(X, mean)

@@ -231,7 +231,7 @@ class TestPybindTypeCasters(common.TestCase):
Our Pybind functions have a signature of the form `() -> return_type`.
"""
# Imports needed for the `eval` below.
from typing import List, Tuple # noqa: F401
from typing import List, Tuple # noqa: F401, UP035
return eval(re.search("-> (.*)\n", func.__doc__).group(1))
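
The `# noqa: F401, UP035` is needed because `List` and `Tuple` are never referenced directly: they only have to be in scope for the `eval` of the annotation extracted from the docstring. A standalone sketch of the same pattern (`func` here is a hypothetical stand-in for a pybind-generated function):

import re
from typing import List, Tuple  # noqa: F401, UP035

def func():
    """A stand-in for a pybind-generated function.

    Signature: () -> Tuple[List[int], int]
    """

# Evaluate the annotation text pulled out of the docstring; without the
# import above, eval would raise NameError even though the names look unused.
hint = eval(re.search("-> (.*)\n", func.__doc__).group(1))
print(hint)  # typing.Tuple[typing.List[int], int]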

@@ -22,7 +22,7 @@ import warnings
from collections import namedtuple
from copy import deepcopy
from math import sqrt
from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Callable, NamedTuple, Optional, Union
import torch
import torch.fx._pytree as fx_pytree
@@ -2270,10 +2270,19 @@ class TestFX(JitTestCase):
graph: torch.fx.Graph = torch.fx.Graph()
x: torch.fx.Node = graph.create_node("placeholder", "x")
b: torch.fx.Node = graph.create_node(
"call_function", target=torch.relu, args=(x,), type_expr=List[float]
"call_function", target=torch.relu, args=(x,), type_expr=list[float]
)
output: torch.fx.Node = graph.output(b)
self.assertTrue('list[float]' in str(graph))
def test_typename_print_pre_pep585(self):
graph : torch.fx.Graph = torch.fx.Graph()
x : torch.fx.Node = graph.create_node('placeholder', 'x')
b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
type_expr=typing.List[float]) # noqa: UP006
output : torch.fx.Node = graph.output(b)
self.assertTrue("typing.List[float]" in str(graph))
def test_layout(self):
@@ -2922,6 +2931,19 @@ class TestFX(JitTestCase):
def forward(self, x: list[str]) -> list[str]:
return self.other(x)
traced = symbolic_trace(ReturnTypeModule())
self.assertIn("-> list[str]", traced._code)
scripted = torch.jit.script(traced)
self.assertIn("-> List[str]", scripted.code)
def test_return_type_exists_pre_pep585(self):
class ReturnTypeModule(torch.nn.Module):
def other(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006
return x
def forward(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006
return self.other(x)
traced = symbolic_trace(ReturnTypeModule())
self.assertIn("-> typing_List[str]", traced._code)
scripted = torch.jit.script(traced)
@@ -3735,7 +3757,7 @@ class TestFX(JitTestCase):
@unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
def test_annotations_empty_tuple(self):
class Foo(torch.nn.Module):
def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
def forward(self, x: typing.Tuple[()], y: typing.Tuple[str, typing.Tuple[()]]): # noqa: UP006
return "foo"
traced = torch.fx.symbolic_trace(Foo())
@@ -4320,10 +4342,10 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
tuple,
type,
typing.Callable,
typing.Dict,
typing.List,
typing.Tuple,
typing.Type,
typing.Dict, # noqa: UP006
typing.List, # noqa: UP006
typing.Tuple, # noqa: UP006
typing.Type, # noqa: UP006
typing.Union,
}

@@ -12,7 +12,7 @@ import tempfile
import typing
import unittest
from types import BuiltinFunctionType
from typing import Callable, List, NamedTuple, Optional, Union
from typing import Callable, NamedTuple, Optional, Union
import torch
import torch.fx.experimental.meta_tracer
@@ -1548,25 +1548,25 @@ class {test_classname}(torch.nn.Module):
(Optional[list[int]], list[int]),
] + [
# pre-PEP585 signatures
(typing.List[int], int),
(typing.List[int], create_type_hint([int, int])),
(typing.List[int], create_type_hint((int, int))),
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
(typing.List[int], int), # noqa: UP006
(typing.List[int], create_type_hint([int, int])), # noqa: UP006
(typing.List[int], create_type_hint((int, int))), # noqa: UP006
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), # noqa: UP006
(
typing.List[torch.Tensor],
typing.List[torch.Tensor], # noqa: UP006
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
),
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
(typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), # noqa: UP006
(typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), # noqa: UP006
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), # noqa: UP006
(
typing.List[torch.Tensor],
typing.List[torch.Tensor], # noqa: UP006
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
),
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),
(Optional[typing.List[int]], typing.List[int]),
(typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), # noqa: UP006
(typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), # noqa: UP006
(Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]), # noqa: UP006
(Optional[typing.List[int]], typing.List[int]), # noqa: UP006
]
for sig_type, arg_type in should_be_equal:
@@ -1575,7 +1575,7 @@ class {test_classname}(torch.nn.Module):
should_fail = [
(int, float),
(Union[int, float], str),
(list[torch.Tensor], List[int]),
(list[torch.Tensor], typing.List[int]), # noqa: UP006
] + [
# pre-PEP585 signatures
(list[torch.Tensor], list[int]),
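
For readers unfamiliar with the helpers exercised above: `create_type_hint` builds a container hint from a list or tuple of element types, and each pair is then checked with `type_matches`. A rough sketch of the intended behavior (assuming both helpers come from `torch.fx.operator_schemas`, and that `type_matches(sig_type, arg_type)` returns a bool):

from torch.fx.operator_schemas import create_type_hint, type_matches

hint = create_type_hint([int, int])      # homogeneous elements -> a List[int]-style hint
print(type_matches(list[int], hint))     # expected: True
print(type_matches(list[int], float))    # expected: False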

@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Tuple, Union
from typing import Any, Union
import torch
from torch.utils._contextlib import (
@@ -386,7 +386,7 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
"""
def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None:
def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
assert isinstance(self.tensors, tuple)
self.prev_versions = tuple(t._version for t in self.tensors)

@@ -455,9 +455,13 @@ class CodeGen:
typename = _type_repr(o)
if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
if origin_type := getattr(o, "__origin__", None):
# list[...], typing.List[...], TensorType[...]
if isinstance(o, typing._GenericAlias): # type: ignore[attr-defined]
# This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(origin_type, origin_type)
origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, "__args__"):

@@ -126,7 +126,9 @@ def _type_repr(obj: object) -> str:
typically enough to uniquely identify a type. For everything
else, we fall back on repr(obj).
"""
if isinstance(obj, type):
# Extension: If we don't ignore GenericAlias then `list[int]` will print
# simply "list".
if isinstance(obj, type) and not isinstance(obj, types.GenericAlias):
if obj.__module__ == "builtins":
return obj.__qualname__
return f"{obj.__module__}.{obj.__qualname__}"