PEP585 update - torch/fx (#145166)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-19 19:32:07 -08:00
committed by PyTorch MergeBot
parent 6374332d33
commit 0b2a3687b9
57 changed files with 904 additions and 917 deletions

View File

@ -3,6 +3,7 @@
import builtins
import contextlib
import collections
import copy
import functools
import inspect
@ -2767,7 +2768,7 @@ class TestFX(JitTestCase):
return self.other(x)
traced = symbolic_trace(ReturnTypeModule())
self.assertIn("-> typing_List[str]", traced._code)
self.assertIn("-> list[str]", traced._code)
scripted = torch.jit.script(traced)
self.assertIn("-> List[str]", scripted.code)
@ -3566,8 +3567,8 @@ class TestFX(JitTestCase):
traced(x, y)
FileCheck().check("_Tuple[()]") \
.check("typing_Tuple[str,typing_Tuple[()]]") \
FileCheck().check("tuple[()]") \
.check("tuple[str,tuple[()]]") \
.run(traced.code)
scripted = torch.jit.script(traced)
@ -4063,45 +4064,62 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
def _annotation_type_to_stable_str(self, t, sig_str):
_trivial_mappings = {
str : 'str',
int : 'int',
float: 'float',
bool: 'bool',
torch.dtype: 'torch.dtype',
torch.Tensor: 'torch.Tensor',
torch.device: 'torch.device',
torch.memory_format: 'torch.memory_format',
slice: 'slice',
torch.nn.Module: 'torch.nn.modules.module.Module',
torch.fx.Graph : 'torch.fx.graph.Graph',
torch.fx.Node : 'torch.fx.node.Node',
torch.fx.Proxy : 'torch.fx.proxy.Proxy',
torch.fx.node.Target : 'torch.fx.node.Target',
torch.fx.node.Argument : 'torch.fx.node.Argument',
torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
Ellipsis : '...',
typing.Any: 'Any',
type(None): 'NoneType',
None: 'None',
typing.Iterator: 'Iterator',
collections.abc.Iterator: 'Iterator',
}
_UNBOUND_TYPES = {
dict,
list,
tuple,
type,
typing.Callable,
typing.Dict,
typing.List,
typing.Tuple,
typing.Type,
typing.Union,
}
def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False):
if t is inspect.Signature.empty:
return ''
# Forward ref
if isinstance(t, str):
return f"'{t}'"
if recursive:
return t
else:
return f"'{t}'"
if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
return t.__forward_arg__
if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
return t.__forward_arg__
trivial_mappings = {
str : 'str',
int : 'int',
float: 'float',
bool: 'bool',
torch.dtype: 'torch.dtype',
torch.Tensor: 'torch.Tensor',
torch.device: 'torch.device',
torch.memory_format: 'torch.memory_format',
slice: 'slice',
torch.nn.Module: 'torch.nn.modules.module.Module',
torch.fx.Graph : 'torch.fx.graph.Graph',
torch.fx.Node : 'torch.fx.node.Node',
torch.fx.Proxy : 'torch.fx.proxy.Proxy',
torch.fx.node.Target : 'torch.fx.node.Target',
torch.fx.node.Argument : 'torch.fx.node.Argument',
torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
Ellipsis : '...',
typing.Any: 'Any',
type(None): 'NoneType',
None: 'None',
typing.Iterator: 'Iterator',
}
mapping = trivial_mappings.get(t, None)
mapping = self._trivial_mappings.get(t, None)
if mapping:
return mapping
@ -4115,14 +4133,14 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
if all(isinstance(ct, typing.TypeVar) for ct in contained):
contained = []
contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained]
contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
origin = getattr(t, '__origin__', None)
if origin is None:
# Unbound types don't have `__origin__` in some Python versions, so fix that up here.
origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
origin = t if t in self._UNBOUND_TYPES else origin
if origin in {tuple, typing.Tuple}:
return f'Tuple{contained_type_str}'
@ -4130,7 +4148,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
# Annoying hack to detect Optional
if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]'
return f'Union{contained_type_str}'
if origin in {dict, typing.Dict}:
return f'Dict{contained_type_str}'