mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In jit tests:
- Add and use a common raise_on_run_directly method for when a user directly runs a test file that should not be run this way. Print the file the user should have run instead.
- Raise a RuntimeError for tests that have been disabled (not run).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
358 lines
9.8 KiB
Python
358 lines
9.8 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
from enum import Enum
|
|
from typing import Any, List
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
|
|
|
|
# Make the helper files in test/ importable: take this file's real path
# (symlinks resolved), step up two directory levels, and add that
# directory to the module search path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
|
|
|
|
|
class TestEnum(JitTestCase):
    """TorchScript tests for Python ``enum.Enum`` support.

    Covers which value types are accepted, equality comparison, ``.name`` /
    ``.value`` access, enum constants, ``isinstance`` checks, closed-over
    enums, enums as module attributes, returning enums, and iteration.
    """

    # int-, float- and str-valued enums can be scripted; a Tensor-valued enum
    # is rejected at compile time.
    def test_enum_value_types(self):
        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        make_global(IntEnum, FloatEnum, StringEnum)

        @torch.jit.script
        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run(
            str(supported_enum_types.graph)
        )

        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        make_global(TensorEnum)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        # TODO: rewrite code so that the highlight is not empty.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Cannot create Enum with value type 'Tensor'", ""
        ):
            torch.jit.script(unsupported_enum_types)

    # Comparing two enum arguments lowers to aten::eq.
    def test_enum_comp(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        FileCheck().check("aten::eq").run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Color.RED, Color.RED), True)
        self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)

    # Members of two different enum classes compare unequal even when their
    # names and values match.
    def test_enum_comp_diff_classes(self):
        class Foo(Enum):
            ITEM1 = 1
            ITEM2 = 2

        class Bar(Enum):
            ITEM1 = 1
            ITEM2 = 2

        make_global(Foo, Bar)

        @torch.jit.script
        def enum_comp(x: Foo) -> bool:
            return x == Bar.ITEM1

        FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check(
            "aten::eq"
        ).run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Foo.ITEM1), False)

    # An enum whose members have different value types cannot be scripted.
    def test_heterogenous_value_type_enum_error(self):
        class Color(Enum):
            RED = 1
            GREEN = "green"

        make_global(Color)

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        # TODO: rewrite code so that the highlight is not empty.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Could not unify type list", ""
        ):
            torch.jit.script(enum_comp)

    # `.name` lowers to prim::EnumName.
    def test_enum_name(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_name(x: Color) -> str:
            return x.name

        FileCheck().check("Color").check_next("prim::EnumName").check_next(
            "return"
        ).run(str(enum_name.graph))

        self.assertEqual(enum_name(Color.RED), Color.RED.name)
        self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)

    # `.value` lowers to prim::EnumValue.
    def test_enum_value(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_value(x: Color) -> int:
            return x.value

        FileCheck().check("Color").check_next("prim::EnumValue").check_next(
            "return"
        ).run(str(enum_value.graph))

        self.assertEqual(enum_value(Color.RED), Color.RED.value)
        self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)

    # An enum member referenced inside a script function becomes a
    # prim::Constant.
    def test_enum_as_const(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_const(x: Color) -> bool:
            return x == Color.RED

        FileCheck().check(
            "prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
        ).check_next("aten::eq").check_next("return").run(str(enum_const.graph))

        self.assertEqual(enum_const(Color.RED), True)
        self.assertEqual(enum_const(Color.GREEN), False)

    # Referencing an undefined member fails to compile, and the error
    # highlights the offending expression.
    def test_non_existent_enum_value(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def enum_const(x: Color) -> bool:
            if x == Color.PURPLE:
                return True
            else:
                return False

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"
        ):
            torch.jit.script(enum_const)

    # isinstance(x, EnumClass) works on an `Any`-typed argument.
    def test_enum_ivalue_type(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def is_color_enum(x: Any):
            return isinstance(x, Color)

        FileCheck().check(
            "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
        ).check_next("return").run(str(is_color_enum.graph))

        self.assertEqual(is_color_enum(Color.RED), True)
        self.assertEqual(is_color_enum(Color.GREEN), True)
        self.assertEqual(is_color_enum(1), False)

    # An enum class or member captured from the enclosing scope (without
    # make_global) has its `.value` folded into a constant.
    def test_closed_over_enum_constant(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        a = Color

        @torch.jit.script
        def closed_over_aliased_type():
            return a.RED.value

        FileCheck().check(f"prim::Constant[value={a.RED.value}]").check_next(
            "return"
        ).run(str(closed_over_aliased_type.graph))

        self.assertEqual(closed_over_aliased_type(), Color.RED.value)

        b = Color.RED

        @torch.jit.script
        def closed_over_aliased_value():
            return b.value

        FileCheck().check(f"prim::Constant[value={b.value}]").check_next(
            "return"
        ).run(str(closed_over_aliased_value.graph))

        self.assertEqual(closed_over_aliased_value(), Color.RED.value)

    # An enum stored as an nn.Module attribute is read with prim::GetAttr,
    # and its value with prim::EnumValue.
    def test_enum_as_module_attribute(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e.value

        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        FileCheck().check("TestModule").check_next("Color").check_same(
            'prim::GetAttr[name="e"]'
        ).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED.value)

    # A str-valued enum module attribute exposes both `.name` and `.value`.
    def test_string_enum_as_module_attribute(self):
        class Color(Enum):
            RED = "red"
            GREEN = "green"

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return (self.e.name, self.e.value)

        make_global(Color)
        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))

    # Script functions can return enum members.
    def test_enum_return(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def return_enum(cond: bool):
            if cond:
                return Color.RED
            else:
                return Color.GREEN

        self.assertEqual(return_enum(True), Color.RED)
        self.assertEqual(return_enum(False), Color.GREEN)

    # A scripted module's forward can return its enum attribute directly.
    def test_enum_module_return(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e

        make_global(Color)
        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        FileCheck().check("TestModule").check_next("Color").check_same(
            'prim::GetAttr[name="e"]'
        ).check_next("return").run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED)

    # `for e in EnumClass` is supported inside script functions.
    def test_enum_iterate(self):
        class Color(Enum):
            RED = 1
            GREEN = 2
            BLUE = 3

        def iterate_enum(x: Color):
            res: List[int] = []
            for e in Color:
                if e != x:
                    res.append(e.value)
            return res

        make_global(Color)
        scripted = torch.jit.script(iterate_enum)

        FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same(
            "Color.RED"
        ).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph))

        # Iteration follows Python's Enum definition order (RED, GREEN, BLUE),
        # so the remaining members always come back in that relative order.
        self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
        self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value])
        self.assertEqual(scripted(Color.BLUE), [Color.RED.value, Color.GREEN.value])

    # Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
    def test_enum_explicit_script(self):
        @torch.jit.script
        class Color(Enum):
            RED = 1
            GREEN = 2

        torch.jit.script(Color)

    # Regression test for https://github.com/pytorch/pytorch/issues/108933
    # An enum with a mixed-in value type (`int, Enum`) must be scriptable and
    # must behave correctly when the scripted function is called.
    def test_typed_enum(self):
        class Color(int, Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def is_red(x: Color) -> bool:
            return x == Color.RED

        self.assertTrue(is_red(Color.RED))
        self.assertFalse(is_red(Color.GREEN))
|
|
|
|
|
|
if __name__ == "__main__":
    # These tests are meant to be collected through test/test_jit.py. Running
    # this file directly is refused, and the user is pointed at that file.
    raise_on_run_directly("test/test_jit.py")
|