mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix more JIT tests under Python-3.9 (#51182)
Summary: Mostly replace `global Foo` with `make_global(Foo)` The only real fix is generating Subscript annotation, which is a follow up from https://github.com/pytorch/pytorch/pull/48676 Fixes https://github.com/pytorch/pytorch/issues/49617 Pull Request resolved: https://github.com/pytorch/pytorch/pull/51182 Reviewed By: gmagogsfm Differential Revision: D26095244 Pulled By: malfet fbshipit-source-id: 0e043d9a2cf43fff71dfbb341f708cd7af87c39a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9b6d463704
commit
00adc7b07f
@ -11,7 +11,7 @@ from typing import Any
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
import torch.testing._internal.jit_utils
|
||||
from torch.testing._internal.common_utils import IS_SANDCASTLE
|
||||
from typing import List, Tuple, Iterable, Optional, Dict
|
||||
@ -143,12 +143,12 @@ class TestClassType(JitTestCase):
|
||||
self.attr = x
|
||||
|
||||
def test_class_type_as_param(self):
|
||||
global FooTest # see [local resolution in python]
|
||||
|
||||
class FooTest(object): # noqa: B903
|
||||
def __init__(self, x):
|
||||
self.attr = x
|
||||
|
||||
make_global(FooTest) # see [local resolution in python]
|
||||
|
||||
@torch.jit.script
|
||||
def fn(foo: FooTest) -> torch.Tensor:
|
||||
return foo.attr
|
||||
@ -279,13 +279,13 @@ class TestClassType(JitTestCase):
|
||||
self.assertEqual(2 * input, output)
|
||||
|
||||
def test_python_interop(self):
|
||||
global Foo # see [local resolution in python]
|
||||
|
||||
class Foo(object): # noqa: B903
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
@torch.jit.script
|
||||
def use_foo(foo: Foo) -> Foo:
|
||||
return foo
|
||||
@ -305,13 +305,13 @@ class TestClassType(JitTestCase):
|
||||
self.assertEqual(y, f2.y)
|
||||
|
||||
def test_class_specialization(self):
|
||||
global Foo # see [local resolution in python]
|
||||
|
||||
class Foo(object): # noqa: B903
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor:
|
||||
a, b = tup
|
||||
return foo.x + foo2.y + a.x + b.y
|
||||
@ -329,8 +329,6 @@ class TestClassType(JitTestCase):
|
||||
FileCheck().check_count("prim::GetAttr", 4).run(graphstr)
|
||||
|
||||
def test_class_sorting(self):
|
||||
global Foo # see [local resolution in python]
|
||||
|
||||
class Foo(object): # noqa: B903
|
||||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
@ -342,6 +340,8 @@ class TestClassType(JitTestCase):
|
||||
def getVal(self):
|
||||
return self.x
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
def test(li: List[Foo], reverse: bool = False) -> Tuple[List[int], List[int]]:
|
||||
li_sorted = sorted(li)
|
||||
ret_sorted = torch.jit.annotate(List[int], [])
|
||||
@ -500,8 +500,6 @@ class TestClassType(JitTestCase):
|
||||
self.assertEqual(3 * input, output)
|
||||
|
||||
def test_interface(self):
|
||||
global Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2
|
||||
|
||||
@torch.jit.script
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
@ -571,6 +569,8 @@ class TestClassType(JitTestCase):
|
||||
def two(self, x: int) -> int:
|
||||
return 3
|
||||
|
||||
make_global(Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2)
|
||||
|
||||
def use_them(x):
|
||||
a = Foo()
|
||||
b = Bar()
|
||||
@ -652,8 +652,6 @@ class TestClassType(JitTestCase):
|
||||
# NamedTuple inheritance errors
|
||||
|
||||
def test_overloaded_fn(self):
|
||||
global Foo, MyClass # see [local resolution in python]
|
||||
|
||||
@torch.jit.script
|
||||
class Foo(object):
|
||||
def __init__(self, x):
|
||||
@ -673,6 +671,8 @@ class TestClassType(JitTestCase):
|
||||
a = Foo(torch.ones([3, 3]))
|
||||
return len(a), -a * torch.zeros([3, 3])
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
self.checkScript(test_overload, ())
|
||||
# unary ops tested above
|
||||
|
||||
@ -737,6 +737,8 @@ class TestClassType(JitTestCase):
|
||||
return self.x * val * 3
|
||||
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
def add():
|
||||
return MyClass(4) + 3
|
||||
def sub(): # noqa: E306
|
||||
@ -787,8 +789,6 @@ class TestClassType(JitTestCase):
|
||||
return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
|
||||
|
||||
def test_cast_overloads(self):
|
||||
global Foo # see [local resolution in python]
|
||||
|
||||
@torch.jit.script
|
||||
class Foo(object):
|
||||
def __init__(self, val: float) -> None:
|
||||
@ -806,6 +806,8 @@ class TestClassType(JitTestCase):
|
||||
def __str__(self):
|
||||
return str(self.val)
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
def test(foo: Foo) -> Tuple[int, float, bool]:
|
||||
if foo:
|
||||
pass
|
||||
@ -914,8 +916,6 @@ class TestClassType(JitTestCase):
|
||||
self.assertEqual(m.w, m_loaded.w)
|
||||
|
||||
def test_py_class_to_ivalue_missing_attribute(self):
|
||||
global Foo # see [local resolution in python]
|
||||
|
||||
class Foo(object):
|
||||
i : int
|
||||
f : float
|
||||
@ -924,6 +924,8 @@ class TestClassType(JitTestCase):
|
||||
self.i = i
|
||||
self.f = f
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
@torch.jit.script
|
||||
def test_fn(x : Foo) -> float:
|
||||
return x.i + x.f
|
||||
@ -1132,8 +1134,6 @@ class TestClassType(JitTestCase):
|
||||
"""
|
||||
Test static methods on class types.
|
||||
"""
|
||||
global ClassWithStaticMethod
|
||||
|
||||
@torch.jit.script
|
||||
class ClassWithStaticMethod:
|
||||
def __init__(self, a: int, b: int):
|
||||
@ -1164,14 +1164,14 @@ class TestClassType(JitTestCase):
|
||||
def test_function(a: int, b: int) -> 'ClassWithStaticMethod':
|
||||
return ClassWithStaticMethod.create_from(a, b)
|
||||
|
||||
make_global(ClassWithStaticMethod)
|
||||
|
||||
self.checkScript(test_function, (1, 2))
|
||||
|
||||
def test_classmethod(self):
|
||||
"""
|
||||
Test classmethods on class types.
|
||||
"""
|
||||
global ClassWithClassMethod
|
||||
|
||||
@torch.jit.script
|
||||
class ClassWithClassMethod:
|
||||
def __init__(self, a: int):
|
||||
@ -1184,6 +1184,8 @@ class TestClassType(JitTestCase):
|
||||
def create(cls, a: int) -> 'ClassWithClassMethod':
|
||||
return cls(a)
|
||||
|
||||
make_global(ClassWithClassMethod)
|
||||
|
||||
def test_function(a: int) -> 'ClassWithClassMethod':
|
||||
x = ClassWithClassMethod(a)
|
||||
# Support calling classmethod with an instance
|
||||
|
@ -9,7 +9,7 @@ from typing import Any, List
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
@ -18,24 +18,20 @@ if __name__ == '__main__':
|
||||
|
||||
class TestEnum(JitTestCase):
|
||||
def test_enum_value_types(self):
|
||||
global IntEnum
|
||||
|
||||
class IntEnum(Enum):
|
||||
FOO = 1
|
||||
BAR = 2
|
||||
|
||||
global FloatEnum
|
||||
|
||||
class FloatEnum(Enum):
|
||||
FOO = 1.2
|
||||
BAR = 2.3
|
||||
|
||||
global StringEnum
|
||||
|
||||
class StringEnum(Enum):
|
||||
FOO = "foo as in foo bar"
|
||||
BAR = "bar as in foo bar"
|
||||
|
||||
make_global(IntEnum, FloatEnum, StringEnum)
|
||||
|
||||
@torch.jit.script
|
||||
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
|
||||
return (a.name, b.name, c.name)
|
||||
@ -46,12 +42,12 @@ class TestEnum(JitTestCase):
|
||||
.check("StringEnum") \
|
||||
.run(str(supported_enum_types.graph))
|
||||
|
||||
global TensorEnum
|
||||
|
||||
class TensorEnum(Enum):
|
||||
FOO = torch.tensor(0)
|
||||
BAR = torch.tensor(1)
|
||||
|
||||
make_global(TensorEnum)
|
||||
|
||||
def unsupported_enum_types(a: TensorEnum):
|
||||
return a.name
|
||||
|
||||
@ -59,12 +55,12 @@ class TestEnum(JitTestCase):
|
||||
torch.jit.script(unsupported_enum_types)
|
||||
|
||||
def test_enum_comp(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
@torch.jit.script
|
||||
def enum_comp(x: Color, y: Color) -> bool:
|
||||
return x == y
|
||||
@ -75,8 +71,6 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)
|
||||
|
||||
def test_enum_comp_diff_classes(self):
|
||||
global Foo, Bar
|
||||
|
||||
class Foo(Enum):
|
||||
ITEM1 = 1
|
||||
ITEM2 = 2
|
||||
@ -85,6 +79,8 @@ class TestEnum(JitTestCase):
|
||||
ITEM1 = 1
|
||||
ITEM2 = 2
|
||||
|
||||
make_global(Foo, Bar)
|
||||
|
||||
@torch.jit.script
|
||||
def enum_comp(x: Foo) -> bool:
|
||||
return x == Bar.ITEM1
|
||||
@ -98,12 +94,12 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(enum_comp(Foo.ITEM1), False)
|
||||
|
||||
def test_heterogenous_value_type_enum_error(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = "green"
|
||||
|
||||
make_global(Color)
|
||||
|
||||
def enum_comp(x: Color, y: Color) -> bool:
|
||||
return x == y
|
||||
|
||||
@ -111,12 +107,12 @@ class TestEnum(JitTestCase):
|
||||
torch.jit.script(enum_comp)
|
||||
|
||||
def test_enum_name(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
@torch.jit.script
|
||||
def enum_name(x: Color) -> str:
|
||||
return x.name
|
||||
@ -131,12 +127,12 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
|
||||
|
||||
def test_enum_value(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
@torch.jit.script
|
||||
def enum_value(x: Color) -> int:
|
||||
return x.value
|
||||
@ -151,12 +147,12 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
|
||||
|
||||
def test_enum_as_const(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
@torch.jit.script
|
||||
def enum_const(x: Color) -> bool:
|
||||
return x == Color.RED
|
||||
@ -171,12 +167,12 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(enum_const(Color.GREEN), False)
|
||||
|
||||
def test_non_existent_enum_value(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
def enum_const(x: Color) -> bool:
|
||||
if x == Color.PURPLE:
|
||||
return True
|
||||
@ -187,12 +183,12 @@ class TestEnum(JitTestCase):
|
||||
torch.jit.script(enum_const)
|
||||
|
||||
def test_enum_ivalue_type(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
@torch.jit.script
|
||||
def is_color_enum(x: Any):
|
||||
return isinstance(x, Color)
|
||||
@ -207,8 +203,6 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(is_color_enum(1), False)
|
||||
|
||||
def test_closed_over_enum_constant(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
@ -240,8 +234,6 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(closed_over_aliased_value(), Color.RED.value)
|
||||
|
||||
def test_enum_as_module_attribute(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
@ -268,8 +260,6 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(scripted(), Color.RED.value)
|
||||
|
||||
def test_string_enum_as_module_attribute(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = "red"
|
||||
GREEN = "green"
|
||||
@ -282,18 +272,19 @@ class TestEnum(JitTestCase):
|
||||
def forward(self):
|
||||
return (self.e.name, self.e.value)
|
||||
|
||||
make_global(Color)
|
||||
m = TestModule(Color.RED)
|
||||
scripted = torch.jit.script(m)
|
||||
|
||||
self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))
|
||||
|
||||
def test_enum_return(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
|
||||
make_global(Color)
|
||||
|
||||
@torch.jit.script
|
||||
def return_enum(cond: bool):
|
||||
if cond:
|
||||
@ -305,8 +296,6 @@ class TestEnum(JitTestCase):
|
||||
self.assertEqual(return_enum(False), Color.GREEN)
|
||||
|
||||
def test_enum_module_return(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
@ -319,6 +308,7 @@ class TestEnum(JitTestCase):
|
||||
def forward(self):
|
||||
return self.e
|
||||
|
||||
make_global(Color)
|
||||
m = TestModule(Color.RED)
|
||||
scripted = torch.jit.script(m)
|
||||
|
||||
@ -333,8 +323,6 @@ class TestEnum(JitTestCase):
|
||||
|
||||
|
||||
def test_enum_iterate(self):
|
||||
global Color
|
||||
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
@ -347,6 +335,7 @@ class TestEnum(JitTestCase):
|
||||
res.append(e.value)
|
||||
return res
|
||||
|
||||
make_global(Color)
|
||||
scripted = torch.jit.script(iterate_enum)
|
||||
|
||||
FileCheck() \
|
||||
|
@ -4,7 +4,7 @@ import sys
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -29,8 +29,6 @@ class TestWith(JitTestCase):
|
||||
Check that with statements that use the 'as' keyword to bind expressions
|
||||
to targets work as expected.
|
||||
"""
|
||||
global Context
|
||||
|
||||
@torch.jit.script
|
||||
class Context(object):
|
||||
"""
|
||||
@ -50,6 +48,8 @@ class TestWith(JitTestCase):
|
||||
def __exit__(self, type: Any, value: Any, tb: Any):
|
||||
self.count.sub_(0.3)
|
||||
|
||||
make_global(Context)
|
||||
|
||||
def test_basic(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Basic test with one with-statement."""
|
||||
|
||||
@ -185,8 +185,6 @@ class TestWith(JitTestCase):
|
||||
Check that with statements that do not use the 'as' keyword to bind expressions
|
||||
to targets work as expected.
|
||||
"""
|
||||
global Context
|
||||
|
||||
@torch.jit.script
|
||||
class Context(object):
|
||||
"""
|
||||
@ -206,6 +204,8 @@ class TestWith(JitTestCase):
|
||||
def __exit__(self, type: Any, value: Any, tb: Any):
|
||||
self.count.sub_(0.3)
|
||||
|
||||
make_global(Context)
|
||||
|
||||
def test_basic(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Basic test with one with-statement."""
|
||||
|
||||
@ -341,8 +341,6 @@ class TestWith(JitTestCase):
|
||||
Check that exceptions thrown in the bodies of with-statements are
|
||||
handled correctly.
|
||||
"""
|
||||
global Context
|
||||
|
||||
@torch.jit.script
|
||||
class Context(object):
|
||||
"""
|
||||
@ -362,6 +360,8 @@ class TestWith(JitTestCase):
|
||||
def __exit__(self, type: Any, value: Any, tb: Any):
|
||||
self.count.sub_(0.3)
|
||||
|
||||
make_global(Context)
|
||||
|
||||
@torch.jit.script
|
||||
def method_that_raises() -> torch.Tensor:
|
||||
raise Exception("raised exception")
|
||||
|
@ -65,7 +65,7 @@ from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WIT
|
||||
freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \
|
||||
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs
|
||||
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
|
||||
_trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, \
|
||||
_trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, make_global, \
|
||||
execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
|
||||
RUN_CUDA
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, nn_functional_tests, get_script_args, \
|
||||
@ -6609,8 +6609,6 @@ a")
|
||||
.check("in foo").check("in baz").run(str(cm.exception))
|
||||
|
||||
def test_error_stacktrace_interface(self):
|
||||
global IFace
|
||||
|
||||
@torch.jit.script
|
||||
def baz(c, b):
|
||||
return c + b
|
||||
@ -6634,6 +6632,8 @@ a")
|
||||
# type: (Tensor, Tensor) -> Tensor
|
||||
pass
|
||||
|
||||
make_global(IFace)
|
||||
|
||||
@torch.jit.script
|
||||
def as_interface(x):
|
||||
# type: (IFace) -> IFace
|
||||
|
@ -241,7 +241,9 @@ def get_annotation_str(annotation):
|
||||
elif isinstance(annotation, ast.Attribute):
|
||||
return '.'.join([get_annotation_str(annotation.value), annotation.attr])
|
||||
elif isinstance(annotation, ast.Subscript):
|
||||
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(annotation.slice.value)}]" # type: ignore
|
||||
# In Python3.9+ subscript indicies are not wrapped in ast.Index
|
||||
subscript_slice = annotation.slice if sys.version_info >= (3, 9) else annotation.slice.value # type: ignore
|
||||
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
|
||||
elif isinstance(annotation, ast.Tuple):
|
||||
return ','.join([get_annotation_str(elt) for elt in annotation.elts])
|
||||
elif isinstance(annotation, ast.Constant) or isinstance(annotation, ast.NameConstant):
|
||||
|
Reference in New Issue
Block a user