mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - torch/fx (#145166)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
6374332d33
commit
0b2a3687b9
@ -3,6 +3,7 @@
|
||||
|
||||
import builtins
|
||||
import contextlib
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
@ -2767,7 +2768,7 @@ class TestFX(JitTestCase):
|
||||
return self.other(x)
|
||||
|
||||
traced = symbolic_trace(ReturnTypeModule())
|
||||
self.assertIn("-> typing_List[str]", traced._code)
|
||||
self.assertIn("-> list[str]", traced._code)
|
||||
scripted = torch.jit.script(traced)
|
||||
self.assertIn("-> List[str]", scripted.code)
|
||||
|
||||
@ -3566,8 +3567,8 @@ class TestFX(JitTestCase):
|
||||
|
||||
traced(x, y)
|
||||
|
||||
FileCheck().check("_Tuple[()]") \
|
||||
.check("typing_Tuple[str,typing_Tuple[()]]") \
|
||||
FileCheck().check("tuple[()]") \
|
||||
.check("tuple[str,tuple[()]]") \
|
||||
.run(traced.code)
|
||||
|
||||
scripted = torch.jit.script(traced)
|
||||
@ -4063,45 +4064,62 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
|
||||
|
||||
return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
|
||||
|
||||
def _annotation_type_to_stable_str(self, t, sig_str):
|
||||
_trivial_mappings = {
|
||||
str : 'str',
|
||||
int : 'int',
|
||||
float: 'float',
|
||||
bool: 'bool',
|
||||
torch.dtype: 'torch.dtype',
|
||||
torch.Tensor: 'torch.Tensor',
|
||||
torch.device: 'torch.device',
|
||||
torch.memory_format: 'torch.memory_format',
|
||||
slice: 'slice',
|
||||
torch.nn.Module: 'torch.nn.modules.module.Module',
|
||||
torch.fx.Graph : 'torch.fx.graph.Graph',
|
||||
torch.fx.Node : 'torch.fx.node.Node',
|
||||
torch.fx.Proxy : 'torch.fx.proxy.Proxy',
|
||||
torch.fx.node.Target : 'torch.fx.node.Target',
|
||||
torch.fx.node.Argument : 'torch.fx.node.Argument',
|
||||
torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
|
||||
torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
|
||||
torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
|
||||
Ellipsis : '...',
|
||||
typing.Any: 'Any',
|
||||
type(None): 'NoneType',
|
||||
None: 'None',
|
||||
typing.Iterator: 'Iterator',
|
||||
collections.abc.Iterator: 'Iterator',
|
||||
}
|
||||
|
||||
_UNBOUND_TYPES = {
|
||||
dict,
|
||||
list,
|
||||
tuple,
|
||||
type,
|
||||
typing.Callable,
|
||||
typing.Dict,
|
||||
typing.List,
|
||||
typing.Tuple,
|
||||
typing.Type,
|
||||
typing.Union,
|
||||
}
|
||||
|
||||
def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False):
|
||||
if t is inspect.Signature.empty:
|
||||
return ''
|
||||
|
||||
# Forward ref
|
||||
if isinstance(t, str):
|
||||
return f"'{t}'"
|
||||
if recursive:
|
||||
return t
|
||||
else:
|
||||
return f"'{t}'"
|
||||
if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
|
||||
return t.__forward_arg__
|
||||
if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
|
||||
return t.__forward_arg__
|
||||
|
||||
trivial_mappings = {
|
||||
str : 'str',
|
||||
int : 'int',
|
||||
float: 'float',
|
||||
bool: 'bool',
|
||||
torch.dtype: 'torch.dtype',
|
||||
torch.Tensor: 'torch.Tensor',
|
||||
torch.device: 'torch.device',
|
||||
torch.memory_format: 'torch.memory_format',
|
||||
slice: 'slice',
|
||||
torch.nn.Module: 'torch.nn.modules.module.Module',
|
||||
torch.fx.Graph : 'torch.fx.graph.Graph',
|
||||
torch.fx.Node : 'torch.fx.node.Node',
|
||||
torch.fx.Proxy : 'torch.fx.proxy.Proxy',
|
||||
torch.fx.node.Target : 'torch.fx.node.Target',
|
||||
torch.fx.node.Argument : 'torch.fx.node.Argument',
|
||||
torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
|
||||
torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
|
||||
torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
|
||||
Ellipsis : '...',
|
||||
typing.Any: 'Any',
|
||||
type(None): 'NoneType',
|
||||
None: 'None',
|
||||
typing.Iterator: 'Iterator',
|
||||
}
|
||||
|
||||
mapping = trivial_mappings.get(t, None)
|
||||
mapping = self._trivial_mappings.get(t, None)
|
||||
if mapping:
|
||||
return mapping
|
||||
|
||||
@ -4115,14 +4133,14 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
|
||||
if all(isinstance(ct, typing.TypeVar) for ct in contained):
|
||||
contained = []
|
||||
|
||||
contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
|
||||
contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained]
|
||||
contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
|
||||
|
||||
|
||||
origin = getattr(t, '__origin__', None)
|
||||
if origin is None:
|
||||
# Unbound types don't have `__origin__` in some Python versions, so fix that up here.
|
||||
origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
|
||||
origin = t if t in self._UNBOUND_TYPES else origin
|
||||
|
||||
if origin in {tuple, typing.Tuple}:
|
||||
return f'Tuple{contained_type_str}'
|
||||
@ -4130,7 +4148,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
|
||||
# Annoying hack to detect Optional
|
||||
if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
|
||||
not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
|
||||
return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
|
||||
return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]'
|
||||
return f'Union{contained_type_str}'
|
||||
if origin in {dict, typing.Dict}:
|
||||
return f'Dict{contained_type_str}'
|
||||
|
Reference in New Issue
Block a user