Fix more JIT tests under Python-3.9 (#51182)

Summary:
Mostly replace `global Foo` with `make_global(Foo)`
The only real fix is generating Subscript annotation, which is a follow up from https://github.com/pytorch/pytorch/pull/48676

Fixes https://github.com/pytorch/pytorch/issues/49617

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51182

Reviewed By: gmagogsfm

Differential Revision: D26095244

Pulled By: malfet

fbshipit-source-id: 0e043d9a2cf43fff71dfbb341f708cd7af87c39a
This commit is contained in:
Nikita Shulga
2021-01-27 10:49:10 -08:00
committed by Facebook GitHub Bot
parent 9b6d463704
commit 00adc7b07f
5 changed files with 62 additions and 69 deletions

View File

@ -11,7 +11,7 @@ from typing import Any
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import JitTestCase, make_global
import torch.testing._internal.jit_utils
from torch.testing._internal.common_utils import IS_SANDCASTLE
from typing import List, Tuple, Iterable, Optional, Dict
@ -143,12 +143,12 @@ class TestClassType(JitTestCase):
self.attr = x
def test_class_type_as_param(self):
global FooTest # see [local resolution in python]
class FooTest(object): # noqa: B903
def __init__(self, x):
self.attr = x
make_global(FooTest) # see [local resolution in python]
@torch.jit.script
def fn(foo: FooTest) -> torch.Tensor:
return foo.attr
@ -279,13 +279,13 @@ class TestClassType(JitTestCase):
self.assertEqual(2 * input, output)
def test_python_interop(self):
global Foo # see [local resolution in python]
class Foo(object): # noqa: B903
def __init__(self, x, y):
self.x = x
self.y = y
make_global(Foo) # see [local resolution in python]
@torch.jit.script
def use_foo(foo: Foo) -> Foo:
return foo
@ -305,13 +305,13 @@ class TestClassType(JitTestCase):
self.assertEqual(y, f2.y)
def test_class_specialization(self):
global Foo # see [local resolution in python]
class Foo(object): # noqa: B903
def __init__(self, x, y):
self.x = x
self.y = y
make_global(Foo) # see [local resolution in python]
def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor:
a, b = tup
return foo.x + foo2.y + a.x + b.y
@ -329,8 +329,6 @@ class TestClassType(JitTestCase):
FileCheck().check_count("prim::GetAttr", 4).run(graphstr)
def test_class_sorting(self):
global Foo # see [local resolution in python]
class Foo(object): # noqa: B903
def __init__(self, x: int) -> None:
self.x = x
@ -342,6 +340,8 @@ class TestClassType(JitTestCase):
def getVal(self):
return self.x
make_global(Foo) # see [local resolution in python]
def test(li: List[Foo], reverse: bool = False) -> Tuple[List[int], List[int]]:
li_sorted = sorted(li)
ret_sorted = torch.jit.annotate(List[int], [])
@ -500,8 +500,6 @@ class TestClassType(JitTestCase):
self.assertEqual(3 * input, output)
def test_interface(self):
global Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2
@torch.jit.script
class Foo(object):
def __init__(self):
@ -571,6 +569,8 @@ class TestClassType(JitTestCase):
def two(self, x: int) -> int:
return 3
make_global(Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2)
def use_them(x):
a = Foo()
b = Bar()
@ -652,8 +652,6 @@ class TestClassType(JitTestCase):
# NamedTuple inheritance errors
def test_overloaded_fn(self):
global Foo, MyClass # see [local resolution in python]
@torch.jit.script
class Foo(object):
def __init__(self, x):
@ -673,6 +671,8 @@ class TestClassType(JitTestCase):
a = Foo(torch.ones([3, 3]))
return len(a), -a * torch.zeros([3, 3])
make_global(Foo) # see [local resolution in python]
self.checkScript(test_overload, ())
# unary ops tested above
@ -737,6 +737,8 @@ class TestClassType(JitTestCase):
return self.x * val * 3
make_global(Foo) # see [local resolution in python]
def add():
return MyClass(4) + 3
def sub(): # noqa: E306
@ -787,8 +789,6 @@ class TestClassType(JitTestCase):
return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
def test_cast_overloads(self):
global Foo # see [local resolution in python]
@torch.jit.script
class Foo(object):
def __init__(self, val: float) -> None:
@ -806,6 +806,8 @@ class TestClassType(JitTestCase):
def __str__(self):
return str(self.val)
make_global(Foo) # see [local resolution in python]
def test(foo: Foo) -> Tuple[int, float, bool]:
if foo:
pass
@ -914,8 +916,6 @@ class TestClassType(JitTestCase):
self.assertEqual(m.w, m_loaded.w)
def test_py_class_to_ivalue_missing_attribute(self):
global Foo # see [local resolution in python]
class Foo(object):
i : int
f : float
@ -924,6 +924,8 @@ class TestClassType(JitTestCase):
self.i = i
self.f = f
make_global(Foo) # see [local resolution in python]
@torch.jit.script
def test_fn(x : Foo) -> float:
return x.i + x.f
@ -1132,8 +1134,6 @@ class TestClassType(JitTestCase):
"""
Test static methods on class types.
"""
global ClassWithStaticMethod
@torch.jit.script
class ClassWithStaticMethod:
def __init__(self, a: int, b: int):
@ -1164,14 +1164,14 @@ class TestClassType(JitTestCase):
def test_function(a: int, b: int) -> 'ClassWithStaticMethod':
return ClassWithStaticMethod.create_from(a, b)
make_global(ClassWithStaticMethod)
self.checkScript(test_function, (1, 2))
def test_classmethod(self):
"""
Test classmethods on class types.
"""
global ClassWithClassMethod
@torch.jit.script
class ClassWithClassMethod:
def __init__(self, a: int):
@ -1184,6 +1184,8 @@ class TestClassType(JitTestCase):
def create(cls, a: int) -> 'ClassWithClassMethod':
return cls(a)
make_global(ClassWithClassMethod)
def test_function(a: int) -> 'ClassWithClassMethod':
x = ClassWithClassMethod(a)
# Support calling classmethod with an instance

View File

@ -9,7 +9,7 @@ from typing import Any, List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@ -18,24 +18,20 @@ if __name__ == '__main__':
class TestEnum(JitTestCase):
def test_enum_value_types(self):
global IntEnum
class IntEnum(Enum):
FOO = 1
BAR = 2
global FloatEnum
class FloatEnum(Enum):
FOO = 1.2
BAR = 2.3
global StringEnum
class StringEnum(Enum):
FOO = "foo as in foo bar"
BAR = "bar as in foo bar"
make_global(IntEnum, FloatEnum, StringEnum)
@torch.jit.script
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
return (a.name, b.name, c.name)
@ -46,12 +42,12 @@ class TestEnum(JitTestCase):
.check("StringEnum") \
.run(str(supported_enum_types.graph))
global TensorEnum
class TensorEnum(Enum):
FOO = torch.tensor(0)
BAR = torch.tensor(1)
make_global(TensorEnum)
def unsupported_enum_types(a: TensorEnum):
return a.name
@ -59,12 +55,12 @@ class TestEnum(JitTestCase):
torch.jit.script(unsupported_enum_types)
def test_enum_comp(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
@torch.jit.script
def enum_comp(x: Color, y: Color) -> bool:
return x == y
@ -75,8 +71,6 @@ class TestEnum(JitTestCase):
self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)
def test_enum_comp_diff_classes(self):
global Foo, Bar
class Foo(Enum):
ITEM1 = 1
ITEM2 = 2
@ -85,6 +79,8 @@ class TestEnum(JitTestCase):
ITEM1 = 1
ITEM2 = 2
make_global(Foo, Bar)
@torch.jit.script
def enum_comp(x: Foo) -> bool:
return x == Bar.ITEM1
@ -98,12 +94,12 @@ class TestEnum(JitTestCase):
self.assertEqual(enum_comp(Foo.ITEM1), False)
def test_heterogenous_value_type_enum_error(self):
global Color
class Color(Enum):
RED = 1
GREEN = "green"
make_global(Color)
def enum_comp(x: Color, y: Color) -> bool:
return x == y
@ -111,12 +107,12 @@ class TestEnum(JitTestCase):
torch.jit.script(enum_comp)
def test_enum_name(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
@torch.jit.script
def enum_name(x: Color) -> str:
return x.name
@ -131,12 +127,12 @@ class TestEnum(JitTestCase):
self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
def test_enum_value(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
@torch.jit.script
def enum_value(x: Color) -> int:
return x.value
@ -151,12 +147,12 @@ class TestEnum(JitTestCase):
self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
def test_enum_as_const(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
@torch.jit.script
def enum_const(x: Color) -> bool:
return x == Color.RED
@ -171,12 +167,12 @@ class TestEnum(JitTestCase):
self.assertEqual(enum_const(Color.GREEN), False)
def test_non_existent_enum_value(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
def enum_const(x: Color) -> bool:
if x == Color.PURPLE:
return True
@ -187,12 +183,12 @@ class TestEnum(JitTestCase):
torch.jit.script(enum_const)
def test_enum_ivalue_type(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
@torch.jit.script
def is_color_enum(x: Any):
return isinstance(x, Color)
@ -207,8 +203,6 @@ class TestEnum(JitTestCase):
self.assertEqual(is_color_enum(1), False)
def test_closed_over_enum_constant(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
@ -240,8 +234,6 @@ class TestEnum(JitTestCase):
self.assertEqual(closed_over_aliased_value(), Color.RED.value)
def test_enum_as_module_attribute(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
@ -268,8 +260,6 @@ class TestEnum(JitTestCase):
self.assertEqual(scripted(), Color.RED.value)
def test_string_enum_as_module_attribute(self):
global Color
class Color(Enum):
RED = "red"
GREEN = "green"
@ -282,18 +272,19 @@ class TestEnum(JitTestCase):
def forward(self):
return (self.e.name, self.e.value)
make_global(Color)
m = TestModule(Color.RED)
scripted = torch.jit.script(m)
self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))
def test_enum_return(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
make_global(Color)
@torch.jit.script
def return_enum(cond: bool):
if cond:
@ -305,8 +296,6 @@ class TestEnum(JitTestCase):
self.assertEqual(return_enum(False), Color.GREEN)
def test_enum_module_return(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
@ -319,6 +308,7 @@ class TestEnum(JitTestCase):
def forward(self):
return self.e
make_global(Color)
m = TestModule(Color.RED)
scripted = torch.jit.script(m)
@ -333,8 +323,6 @@ class TestEnum(JitTestCase):
def test_enum_iterate(self):
global Color
class Color(Enum):
RED = 1
GREEN = 2
@ -347,6 +335,7 @@ class TestEnum(JitTestCase):
res.append(e.value)
return res
make_global(Color)
scripted = torch.jit.script(iterate_enum)
FileCheck() \

View File

@ -4,7 +4,7 @@ import sys
from typing import Any, List
import torch
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import JitTestCase, make_global
# Make the helper files in test/ importable
@ -29,8 +29,6 @@ class TestWith(JitTestCase):
Check that with statements that use the 'as' keyword to bind expressions
to targets work as expected.
"""
global Context
@torch.jit.script
class Context(object):
"""
@ -50,6 +48,8 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
make_global(Context)
def test_basic(x: torch.Tensor) -> torch.Tensor:
"""Basic test with one with-statement."""
@ -185,8 +185,6 @@ class TestWith(JitTestCase):
Check that with statements that do not use the 'as' keyword to bind expressions
to targets work as expected.
"""
global Context
@torch.jit.script
class Context(object):
"""
@ -206,6 +204,8 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
make_global(Context)
def test_basic(x: torch.Tensor) -> torch.Tensor:
"""Basic test with one with-statement."""
@ -341,8 +341,6 @@ class TestWith(JitTestCase):
Check that exceptions thrown in the bodies of with-statements are
handled correctly.
"""
global Context
@torch.jit.script
class Context(object):
"""
@ -362,6 +360,8 @@ class TestWith(JitTestCase):
def __exit__(self, type: Any, value: Any, tb: Any):
self.count.sub_(0.3)
make_global(Context)
@torch.jit.script
def method_that_raises() -> torch.Tensor:
raise Exception("raised exception")

View File

@ -65,7 +65,7 @@ from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WIT
freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
_trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, \
_trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, make_global, \
execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
RUN_CUDA
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, nn_functional_tests, get_script_args, \
@ -6609,8 +6609,6 @@ a")
.check("in foo").check("in baz").run(str(cm.exception))
def test_error_stacktrace_interface(self):
global IFace
@torch.jit.script
def baz(c, b):
return c + b
@ -6634,6 +6632,8 @@ a")
# type: (Tensor, Tensor) -> Tensor
pass
make_global(IFace)
@torch.jit.script
def as_interface(x):
# type: (IFace) -> IFace

View File

@ -241,7 +241,9 @@ def get_annotation_str(annotation):
elif isinstance(annotation, ast.Attribute):
return '.'.join([get_annotation_str(annotation.value), annotation.attr])
elif isinstance(annotation, ast.Subscript):
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(annotation.slice.value)}]" # type: ignore
# In Python3.9+ subscript indicies are not wrapped in ast.Index
subscript_slice = annotation.slice if sys.version_info >= (3, 9) else annotation.slice.value # type: ignore
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
elif isinstance(annotation, ast.Tuple):
return ','.join([get_annotation_str(elt) for elt in annotation.elts])
elif isinstance(annotation, ast.Constant) or isinstance(annotation, ast.NameConstant):