[JIT] Allow tuple and list generics (#98703)

As in Python-3.9+ `Dict`, `List`, and `Tuple` from `typing` module are deprecated in favor of their `builtins` counterparts, see [PEP 585](https://peps.python.org/pep-0585/)

Test plan: Run:
```
import torch
from typing import Union

@torch.jit.script
def to_tuple(v: Union[int, tuple[int, int]]) -> tuple[int, int]:
    """Converts int or tuple to tuple of ints."""
    if torch.jit.isinstance(v, int):
        return v, v
    else:
        return v

print(to_tuple(1), to_tuple((3, 4)))
```

It's almost impossible to add test to an existing CI, as test script will not be parseable by Python-3.8, which is a oldest supported Python version

Fixes https://github.com/pytorch/pytorch/issues/98521

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98703
Approved by: https://github.com/kit1980
This commit is contained in:
Nikita Shulga
2023-04-09 22:58:58 +00:00
committed by PyTorch MergeBot
parent 2400cb1d57
commit 7e0c26d4d8

View File

@ -79,6 +79,9 @@ class SourceLoader:
loader = SourceLoader()
IS_PY39_PLUS = sys.version_info >= (3, 9)
def createResolutionCallbackFromEnv(lookup_base):
"""
Creates a resolution callback that will look up qualified names in an
@ -339,7 +342,7 @@ def get_annotation_str(annotation):
return ".".join([get_annotation_str(annotation.value), annotation.attr])
elif isinstance(annotation, ast.Subscript):
# In Python3.9+ subscript indicies are not wrapped in ast.Index
subscript_slice = annotation.slice if sys.version_info >= (3, 9) else annotation.slice.value # type: ignore[attr-defined]
subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value # type: ignore[attr-defined]
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
elif isinstance(annotation, ast.Tuple):
return ",".join([get_annotation_str(elt) for elt in annotation.elts])
@ -983,10 +986,11 @@ def is_tuple(ann) -> bool:
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
if not hasattr(ann, "__module__"):
return False
return ann.__module__ == "typing" and (
getattr(ann, "__origin__", None) is Tuple
or getattr(ann, "__origin__", None) is tuple
)
ann_origin = getattr(ann, "__origin__", None)
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
return True
return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
def is_list(ann) -> bool:
@ -995,10 +999,11 @@ def is_list(ann) -> bool:
if not hasattr(ann, "__module__"):
return False
return ann.__module__ == "typing" and (
getattr(ann, "__origin__", None) is List
or getattr(ann, "__origin__", None) is list
)
ann_origin = getattr(ann, "__origin__", None)
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
return True
return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
def is_dict(ann) -> bool:
@ -1007,10 +1012,11 @@ def is_dict(ann) -> bool:
if not hasattr(ann, "__module__"):
return False
return ann.__module__ == "typing" and (
getattr(ann, "__origin__", None) is Dict
or getattr(ann, "__origin__", None) is dict
)
ann_origin = getattr(ann, "__origin__", None)
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
return True
return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
def is_union(ann):