diff --git a/test/onnx/test_onnxscript_runtime.py b/test/onnx/test_onnxscript_runtime.py index 963c6ef7302c..2eb2405535c9 100644 --- a/test/onnx/test_onnxscript_runtime.py +++ b/test/onnx/test_onnxscript_runtime.py @@ -2,7 +2,7 @@ """Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime.""" -from typing import List +from typing import Sequence import onnx_test_common import onnxscript @@ -90,7 +90,11 @@ class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime): @onnxscript.script(custom_opset) def layer_norm( - X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float + X, + axes: Sequence[int], + weight: FLOAT[...], + bias: FLOAT[...], + eps: float, ): mean = op.ReduceMean(X, axes=axes) D = X - mean # op.Sub(X, mean) diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 985166c48c66..eca96fe1137c 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -231,7 +231,7 @@ class TestPybindTypeCasters(common.TestCase): Our Pybind functions have a signature of the form `() -> return_type`. """ # Imports needed for the `eval` below. - from typing import List, Tuple # noqa: F401 + from typing import List, Tuple # noqa: F401, UP035 return eval(re.search("-> (.*)\n", func.__doc__).group(1)) diff --git a/test/test_fx.py b/test/test_fx.py index c702adbfaf33..92bbf709c2ea 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -22,7 +22,7 @@ import warnings from collections import namedtuple from copy import deepcopy from math import sqrt -from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Callable, NamedTuple, Optional, Union import torch import torch.fx._pytree as fx_pytree @@ -2270,10 +2270,19 @@ class TestFX(JitTestCase): graph: torch.fx.Graph = torch.fx.Graph() x: torch.fx.Node = graph.create_node("placeholder", "x") b: torch.fx.Node = graph.create_node( - "call_function", target=torch.relu, args=(x,), type_expr=List[float] + "call_function", target=torch.relu, args=(x,), type_expr=list[float] ) output: torch.fx.Node = graph.output(b) + self.assertTrue('list[float]' in str(graph)) + + def test_typename_print_pre_pep585(self): + graph : torch.fx.Graph = torch.fx.Graph() + x : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), + type_expr=typing.List[float]) # noqa: UP006 + output : torch.fx.Node = graph.output(b) + self.assertTrue("typing.List[float]" in str(graph)) def test_layout(self): @@ -2922,6 +2931,19 @@ class TestFX(JitTestCase): def forward(self, x: list[str]) -> list[str]: return self.other(x) + traced = symbolic_trace(ReturnTypeModule()) + self.assertIn("-> list[str]", traced._code) + scripted = torch.jit.script(traced) + self.assertIn("-> List[str]", scripted.code) + + def test_return_type_exists_pre_pep585(self): + class ReturnTypeModule(torch.nn.Module): + def other(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006 + return x + + def forward(self, x: typing.List[str]) -> typing.List[str]: # noqa: UP006 + return self.other(x) + traced = symbolic_trace(ReturnTypeModule()) self.assertIn("-> typing_List[str]", traced._code) scripted = torch.jit.script(traced) @@ -3735,7 +3757,7 @@ class TestFX(JitTestCase): @unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11") def test_annotations_empty_tuple(self): class Foo(torch.nn.Module): - def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]): + def forward(self, x: typing.Tuple[()], y: typing.Tuple[str, typing.Tuple[()]]): # noqa: UP006 return "foo" traced = torch.fx.symbolic_trace(Foo()) @@ -4320,10 +4342,10 @@ class TestFXAPIBackwardCompatibility(JitTestCase): tuple, type, typing.Callable, - typing.Dict, - typing.List, - typing.Tuple, - typing.Type, + typing.Dict, # noqa: UP006 + typing.List, # noqa: UP006 + typing.Tuple, # noqa: UP006 + typing.Type, # noqa: UP006 typing.Union, } diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 24a8a128d8fc..434de5243c13 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -12,7 +12,7 @@ import tempfile import typing import unittest from types import BuiltinFunctionType -from typing import Callable, List, NamedTuple, Optional, Union +from typing import Callable, NamedTuple, Optional, Union import torch import torch.fx.experimental.meta_tracer @@ -1548,25 +1548,25 @@ class {test_classname}(torch.nn.Module): (Optional[list[int]], list[int]), ] + [ # pre-PEP585 signatures - (typing.List[int], int), - (typing.List[int], create_type_hint([int, int])), - (typing.List[int], create_type_hint((int, int))), - (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), + (typing.List[int], int), # noqa: UP006 + (typing.List[int], create_type_hint([int, int])), # noqa: UP006 + (typing.List[int], create_type_hint((int, int))), # noqa: UP006 + (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), # noqa: UP006 ( - typing.List[torch.Tensor], + typing.List[torch.Tensor], # noqa: UP006 create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), ), - (typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), - (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), - (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), + (typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), # noqa: UP006 + (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), # noqa: UP006 + (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), # noqa: UP006 ( - typing.List[torch.Tensor], + typing.List[torch.Tensor], # noqa: UP006 create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), ), - (typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), - (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), - (Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]), - (Optional[typing.List[int]], typing.List[int]), + (typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), # noqa: UP006 + (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), # noqa: UP006 + (Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]), # noqa: UP006 + (Optional[typing.List[int]], typing.List[int]), # noqa: UP006 ] for sig_type, arg_type in should_be_equal: @@ -1575,7 +1575,7 @@ class {test_classname}(torch.nn.Module): should_fail = [ (int, float), (Union[int, float], str), - (list[torch.Tensor], List[int]), + (list[torch.Tensor], typing.List[int]), # noqa: UP006 ] + [ # pre-PEP585 signatures (list[torch.Tensor], list[int]), diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 6aa932c01136..73c072948198 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Tuple, Union +from typing import Any, Union import torch from torch.utils._contextlib import ( @@ -386,7 +386,7 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager): """ - def __init__(self, tensors: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> None: + def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None: self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors assert isinstance(self.tensors, tuple) self.prev_versions = tuple(t._version for t in self.tensors) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 89f97b2419cb..548b4fda9206 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -455,9 +455,13 @@ class CodeGen: typename = _type_repr(o) - if hasattr(o, "__origin__"): - # This is a generic type, e.g. typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + if origin_type := getattr(o, "__origin__", None): + # list[...], typing.List[...], TensorType[...] + + if isinstance(o, typing._GenericAlias): # type: ignore[attr-defined] + # This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(origin_type, origin_type) + origin_typename = add_global(_type_repr(origin_type), origin_type) if hasattr(o, "__args__"): diff --git a/torch/fx/node.py b/torch/fx/node.py index 39c7f82d8c24..22e19ddbeae7 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -126,7 +126,9 @@ def _type_repr(obj: object) -> str: typically enough to uniquely identify a type. For everything else, we fall back on repr(obj). """ - if isinstance(obj, type): + # Extension: If we don't ignore GenericAlias then `list[int]` will print + # simply "list". + if isinstance(obj, type) and not isinstance(obj, types.GenericAlias): if obj.__module__ == "builtins": return obj.__qualname__ return f"{obj.__module__}.{obj.__qualname__}"