mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Partially addresses #123062. Ran lintrunner on `test/jit` with the command `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623. Approved by: https://github.com/ezyang
1066 lines
33 KiB
Python
1066 lines
33 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import io
|
|
import os
|
|
import sys
|
|
from enum import Enum
|
|
from textwrap import dedent
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
|
|
# Make the helper files in test/ importable
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
|
|
|
if __name__ == "__main__":
    # This module only defines test cases; they are meant to be collected and
    # run through the top-level test/test_jit.py driver, so refuse to run it
    # directly and tell the user the supported invocation instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
|
|
|
|
|
|
class TestUnion(JitTestCase):
    """
    Tests for TorchScript support of `typing.Union`.

    Note: It's important to be able to refine the type of a `Union` to
    one of its internal types. Currently, the way Python expects
    `isinstance` checks to be written differs from the way TorchScript
    expects them (the tests below use `torch.jit.isinstance` for
    subscripted types such as `List[int]`). This means that some functions
    cannot run in both eager mode and script mode, so those test cases
    cannot use `checkScript`. Instead they use separate but equivalent
    functions, or plain assertions, to emulate `checkScript`.
    """
|
|
|
|
def test_check_union_annotation(self):
    """Union/Optional parameters are spelled correctly in graph IR and code."""

    def test_func(a: Union[int, float], b: Optional[int]):
        return 0

    scripted = torch.jit.script(test_func)

    # The graph IR spells a Union as `Union(...)` and Optional[int] as `int?`.
    FileCheck().check("Union(").check("int?").run(str(scripted.graph))
    # The serialized TorchScript source uses the bracketed `Union[...]` form.
    FileCheck().check("Union[").check("Optional[int]").run(str(scripted.code))

    self.checkScript(test_func, (5, 6))

    # Feeding the printed graph back through the IR parser must not fail.
    torch._C.parse_ir(str(scripted.graph))
|
|
|
|
def test_union_with_scalar_values(self):
    """A Union of scalars accepts every member type and rejects anything else."""

    def fn(x: Union[int, float]) -> str:
        return "foo"

    for member_value in (1, 1.0):
        self.checkScript(fn, (member_value,))

    scripted = torch.jit.script(fn)

    # Members are reported in a canonical order (`float, int`), independent
    # of the order in which they were declared.
    rejection = (
        "Expected a member of"
        r" Union\[float, int\] but "
        "instead found type str"
    )
    with self.assertRaisesRegex(RuntimeError, rejection):
        scripted("1")
|
|
|
|
def test_union_with_collections(self):
    """A Union of containers checks element types, not just the container kind."""

    def fn(x: Union[Dict[str, int], List[int]]) -> str:
        return "foo"

    self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
    self.checkScript(fn, ([1, 2, 3],))

    scripted = torch.jit.script(fn)

    # Shared prefix of every rejection message; only the found type varies.
    prefix = (
        "Expected a member of"
        r" Union\[List\[int\], Dict\[str, "
        r"int\]\] but instead found type "
    )
    rejected_inputs = (
        ({"foo": "bar", "baz": "qux"}, r"Dict\[str, str\]"),
        (["foo", "bar", "baz"], r"List\[str\]"),
        ("1", "str"),
    )
    for bad_input, found_type in rejected_inputs:
        with self.assertRaisesRegex(RuntimeError, prefix + found_type):
            scripted(bad_input)
|
|
|
|
def test_union_with_enum(self):
    """A Union may mix an Enum class with a primitive type."""

    class Color(Enum):
        RED = 1
        GREEN = 2

    # NOTE(review): make_global (from jit_utils) presumably exposes `Color`
    # at module scope so TorchScript can resolve it by qualified name.
    make_global(Color)

    def fn(x: Union[str, Color]) -> str:
        return "foo"

    for member_value in (Color.RED, "red"):
        self.checkScript(fn, (member_value,))

    scripted = torch.jit.script(fn)

    # The expected qualified name embeds this module's path (`jit.test_union`).
    rejection = (
        "Expected a member of"
        r" Union\[__torch__.jit.test_union."
        r"Color, str\] but instead found "
        "type int"
    )
    with self.assertRaisesRegex(RuntimeError, rejection):
        scripted(1)
|
|
|
|
def test_union_in_class_constructor(self):
    """A Union-typed argument can be forwarded to a scripted class constructor."""

    @torch.jit.script  # noqa: B903
    class A:  # noqa: B903
        def __init__(self, x: Union[int, str]) -> None:
            self.x = x

    def fn(x: Union[str, int]) -> A:
        return A(x)

    # Eager mode: each Union member round-trips through the constructor.
    for member_value in ("foo", 1):
        self.assertEqual(fn(member_value).x, member_value)

    scripted = torch.jit.script(fn)

    rejection = (
        "Expected a member of"
        r" Union\[int, str\] but instead "
        r"found type List\[str\]"
    )
    with self.assertRaisesRegex(RuntimeError, rejection):
        scripted(["foo", "bar", "baz"])
|
|
|
|
def test_union_return_type(self):
    """A function may declare a Union return type and return one of its members."""

    def fn(x: int) -> Union[int, str]:
        return "foo"

    self.checkScript(fn, (1,))
|
|
|
|
def test_union_as_annotation(self):
    """A local variable may be annotated with a Union type and returned."""

    def fn() -> Union[int, str]:
        x: Union[int, str] = "foo"
        return x

    self.checkScript(fn, ())
|
|
|
|
def test_union_as_annotation_in_typed_container(self):
    """Values of each Union member can be appended to a `List[Union[...]]`."""

    def fn() -> None:
        l: List[Union[int, str]] = []
        u1: Union[int, str] = "foo"
        u2: Union[int, str] = 1
        l.append(u1)
        l.append(u2)

    self.checkScript(fn, ())
|
|
|
|
def test_union_as_annotation_py2(self):
    """A Union return type given via a `# type:` signature comment is accepted."""

    # The `# type:` comment inside `fn` is the annotation under test, so it
    # must stay exactly as written.
    def fn():
        # type: () -> Union[int, str]
        x: Union[int, str] = "foo"
        return x

    self.checkScript(fn, ())
|
|
|
|
def test_union_as_internal_tuple_type(self):
    """Union types may appear as element types of a Tuple annotation."""

    def fn():
        t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
        return t

    self.checkScript(fn, ())
|
|
|
|
def test_union_variable_can_be_reassigned(self):
    """A Union-annotated local can be rebound to values of different member types.

    `x` starts as a `str`, is rebound to an `int` and passed where an `int`
    parameter is required (`aux1`), and is finally rebound to a `str`.
    """

    @torch.jit.script
    def aux1(i: int):
        return int(i**2)

    @torch.jit.script
    def aux2(s: str):
        return s + s

    def fn() -> Union[int, str]:
        x: Union[int, str] = "foo"
        i: int = 1
        x = i
        y: int = aux1(x)
        z: str = aux2(str(y))
        x = z
        return x

    self.checkScript(fn, ())
|
|
|
|
def test_union_does_not_replace_existing_annotated_type(self):
    """An explicit `List[int]` annotation is not silently widened to a Union."""

    def fn():
        x: List[int] = [1, 2, 3]
        x.append("foo")
        return x

    with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
        # The error is presumably raised while scripting; the call is only
        # reached if compilation unexpectedly succeeds.
        torch.jit.script(fn)()
|
|
|
|
def test_union_does_not_replace_existing_annotated_type_union(self):
    """An explicit `List[Union[int, str]]` annotation is not widened to admit float."""

    def fn():
        x: List[Union[int, str]] = [1, "foo", 3]
        x.append(2.0)
        return x

    with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
        # The error is presumably raised while scripting; the call is only
        # reached if compilation unexpectedly succeeds.
        torch.jit.script(fn)()
|
|
|
|
def test_union_does_not_replace_existing_annotated_type_empty_container(self):
    """An annotated empty `List[int]` keeps its element type on append."""

    def fn():
        x: List[int] = []
        x.append("foo")
        return x

    with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
        # The error is presumably raised while scripting; the call is only
        # reached if compilation unexpectedly succeeds.
        torch.jit.script(fn)()
|
|
|
|
def test_unions_of_unions_are_flattened(self):
    """A nested Union is printed as a single flat Union in the graph."""

    @torch.jit.script
    def fn(x: Union[Union[int, str], float]) -> str:
        return "foo"

    FileCheck().check("x : Union(float, int, str)").run(fn.graph)
|
|
|
|
def test_unions_of_a_single_argument_vanish(self):
    """`Union[T]` with a single member is typed as plain `T`."""

    @torch.jit.script
    def fn(x: Union[int]) -> str:
        return "foo"

    FileCheck().check("x : int").run(fn.graph)
|
|
|
|
def test_union_redundant_arguments_are_skipped(self):
    """Duplicate members of a Union appear only once in the graph type."""

    @torch.jit.script
    def fn(x: Union[int, str, int]) -> str:
        return "foo"

    FileCheck().check("x : Union(int, str)").run(fn.graph)
|
|
|
|
def test_union_redundant_arguments_are_skipped_optional(self):
    """Repeated `Optional[...]` members contribute a single `NoneType`."""

    @torch.jit.script
    def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
        return "foo"

    FileCheck().check("x : Union(float, int, NoneType)").run(fn.graph)
|
|
|
|
def test_union_redundant_arguments_are_skipped_subtyping(self):
    """A member that is a subtype of another member is dropped from the Union."""

    # `Tuple[int, int]` is subsumed by `Tuple[Optional[int], int]`, so only
    # the wider tuple type should survive.
    @torch.jit.script
    def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
        return "foo"

    FileCheck().check("x : Union((int?, int), str)").run(fn.graph)
|
|
|
|
def test_union_redundant_arguments_are_skipped_container(self):
    """Duplicate container members of a Union are collapsed."""

    @torch.jit.script
    def fn(x: Union[List[str], List[float], List[str]]) -> str:
        return "foo"

    FileCheck().check("x : Union(float[], str[])").run(fn.graph)
|
|
|
|
def test_union_argument_order_is_ignored(self):
    """Unions differing only in member order get the same graph type."""

    @torch.jit.script
    def fn1(x: Union[int, str]) -> str:
        return "foo"

    @torch.jit.script
    def fn2(x: Union[str, int]) -> str:
        return "foo"

    for scripted_fn in (fn1, fn2):
        FileCheck().check("x : Union(int, str)").run(scripted_fn.graph)
|
|
|
|
def test_union_argument_order_is_ignored_container(self):
    """Container Unions differing only in member order get the same graph type."""

    @torch.jit.script
    def fn1(x: Union[List[str], List[int]]) -> str:
        return "foo"

    @torch.jit.script
    def fn2(x: Union[List[int], List[str]]) -> str:
        return "foo"

    for scripted_fn in (fn1, fn2):
        FileCheck().check("x : Union(int[], str[])").run(scripted_fn.graph)
|
|
|
|
def test_union_T_None_is_equivalent_to_optional_T(self):
    """`Union[T, None]` and `Optional[T]` are interchangeable in both directions.

    First, `Optional[int]` values are passed to a `Union[int, None]`
    parameter (`inner`). Then `Union[int, None]` values are passed to an
    `Optional[int]` parameter (`inner2`). Each direction must compile and
    return 5 + 5.
    """

    @torch.jit.script
    def inner(x: Union[int, None]) -> int:
        if x is not None:
            return x
        else:
            return 5

    @torch.jit.script
    def fn1() -> int:
        a: Optional[int] = 5
        b: Optional[int] = None
        a_ = inner(a)
        b_ = inner(b)
        return a_ + b_

    self.assertEqual(fn1(), 10)

    @torch.jit.script
    def inner2(x: Optional[int]) -> int:
        if x is not None:
            return x
        else:
            return 5

    @torch.jit.script
    def fn2() -> int:
        a: Union[int, None] = 5
        b: Union[int, None] = None
        # Must call `inner2`: calling `inner` here (as before) re-tested the
        # first direction and left the `Optional[int]` parameter unexercised.
        a_ = inner2(a)
        b_ = inner2(b)
        return a_ + b_

    self.assertEqual(fn2(), 10)
|
|
|
|
def test_union_optional_of_union_is_flattened(self):
    """`Optional[Union[...]]` is flattened into one Union that includes NoneType."""

    @torch.jit.script
    def fn(flag: int) -> Union[str, int, None]:
        y: Union[int, str, None] = "foo"
        if flag == 0:
            x: Optional[Union[int, str]] = y
        elif flag == 1:
            x: Optional[Union[int, str]] = 1
        else:
            x: Optional[Union[int, str]] = None
        return x

    # Can't use `checkScript` because it will flag the fact that
    # the original code has `Optional[Union[int, str]]` but the
    # saved/loaded code has `Union[int, NoneType, str]` (even
    # though this is exactly what we want)
    for flag, expected in ((0, "foo"), (1, 1), (2, None)):
        self.assertEqual(fn(flag), expected)

    # Round-trip through save/load, then inspect the re-parsed source.
    saved = io.BytesIO()
    torch.jit.save(fn, saved)
    loaded = torch.jit.load(io.BytesIO(saved.getvalue()))

    # FileCheck matches checks in order, so this requires the flattened
    # spelling to appear twice in the serialized code.
    FileCheck().check("Union[int, NoneType, str]").check(
        "Union[int, NoneType, str]"
    ).run(loaded.code)
|
|
|
|
def test_union_subclasses_larger_union(self):
    """A value typed as a smaller Union may be returned as a larger Union."""

    def fn() -> Union[int, str, torch.Tensor]:
        x: Union[int, str] = "foo"
        return x

    self.checkScript(fn, ())
|
|
|
|
# TODO: We would like to eventually support this. The issue is being
# tracked at https://github.com/pytorch/pytorch/issues/58167
def test_union_as_dict_key(self):
    """A Union is currently rejected as a dict key type at compile time."""

    def fn():
        x: Dict[Union[int, str], str] = {}
        x["foo"] = "bar"
        x[1] = 2
        return x[1]

    unsupported_key = (
        "only int, float, "
        "complex, Tensor, device and string keys "
        "are supported"
    )
    with self.assertRaisesRegex(RuntimeError, unsupported_key):
        torch.jit.script(fn)
|
|
|
|
def test_union_as_dict_value(self):
    """A Union is accepted as the value type of a Dict."""

    def fn():
        x: Dict[str, Union[int, str]] = {}
        x["foo"] = "bar"
        x["baz"] = 2
        return x["baz"]

    self.checkScript(fn, ())
|
|
|
|
def test_union_module_with_union_instance_variable(self):
    """An nn.Module attribute annotated as a Union can be reassigned in forward."""

    class M(torch.nn.Module):
        x: Union[int, str]

        def __init__(self, x: Union[int, str]):
            super().__init__()
            self.x: Union[int, str] = x

        def forward(self, y: Union[int, str]):
            self.x = y
            return self.x

    # (constructor value, forward input): one case per Union member.
    for init_value, forward_arg in ((2, 1), ("bar", "foo")):
        self.checkModule(M(init_value), (forward_arg,))
|
|
|
|
def test_union_module_with_union_class_variable(self):
    """An nn.Module with a Union-annotated class-level attribute can be scripted."""

    # NOTE(review): `__init__` and `forward` bind a *local* `x`, never
    # `self.x`, so the class attribute is declared but never read or written.
    # Confirm whether `self.x = ...` was intended before relying on this test.
    class M(torch.nn.Module):
        x: Union[int, str] = "foo"

        def __init__(self, y: int):
            super().__init__()
            x = y

        def forward(self, z: str):
            x = z
            return x

    self.checkModule(M(1), ("foo",))
|
|
|
|
def test_union_type_refinement(self):
    """`isinstance` narrows a Union-typed value inside the true branch."""

    # `x + "bar"` only compiles if `x` has been refined to `str`.
    def fn(x: Union[int, str]) -> str:
        if isinstance(x, str):
            z = x + "bar"
            return x
        else:
            return "baz"

    for member_value in ("foo", 1):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_type_refinement_union_rhs(self):
    """`torch.jit.isinstance` accepts a Union as the type to test against."""

    def fn(x: int) -> str:
        if torch.jit.isinstance(x, Union[int, str]):
            return "bar"
        else:
            return "baz"

    self.checkScript(fn, (1,))
|
|
|
|
def test_union_type_refinement_tuple_rhs(self):
    """`isinstance` with a tuple of types refines a Union, including nested checks."""

    def fn(x: Union[int, float, List[str]]) -> str:
        if isinstance(x, (int, float)):
            if isinstance(x, int):
                return str(x)
            else:
                return "foo"
        else:
            if len(x):
                return x[0]
            else:
                return "bar"

    # One input per Union member.
    for member_value in (1, 1.0, ["a", "b", "c"]):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_type_refinement_tuple_rhs_noncontained_type(self):
    """A tuple check may name types (here `float`) that are not in the Union."""

    def fn(x: Union[int, List[str]]) -> str:
        if isinstance(x, (int, float)):
            y = x + x
            return str(y)
        else:
            if len(x):
                return x[0]
            else:
                return "bar"

    for member_value in (1, ["a", "b", "c"]):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_type_refinement_tuple_rhs_union(self):
    """`torch.jit.isinstance` accepts a tuple whose entries include a Union."""

    @torch.jit.script
    def fn(x: int) -> str:
        if torch.jit.isinstance(x, (Union[int, str], float)):
            y = x + x
            return str(y)
        else:
            return "foo"

    # TODO: There's currently an unrelated bug in
    # `torch.jit.isinstance` that makes it fail for tuple literals.
    # Posted here: https://github.com/pytorch/pytorch/issues/60095
    # Change `assertEqual` to `checkScript` when the bug is fixed
    self.assertEqual(fn(1), "2")
|
|
|
|
def test_union_type_refinement_statically_false(self):
    """An isinstance check that can never hold for the static type is folded away."""

    @torch.jit.script
    def fn(x: int) -> str:
        if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
            z = x + "foo"
            return z
        else:
            return "bar"

    # None of the tested types can match `int`, so the compiled graph should
    # contain no branching blocks at all.
    FileCheck().check_not("block0()").check_not("block1()").run(fn.graph)
|
|
|
|
def test_union_type_refinement_statically_true(self):
    """An isinstance check covering every Union member is folded away."""

    @torch.jit.script
    def fn(x: Union[List[int], int]) -> Union[List[int], int]:
        if not torch.jit.isinstance(x, (int, List[int])):
            return x
        else:
            l = [1, 2, 3]
            y: Union[List[int], int] = l
            return y

    # The check always holds for `Union[List[int], int]`, so the compiled
    # graph should contain no branching blocks at all.
    FileCheck().check_not("block0()").check_not("block1()").run(fn.graph)
|
|
|
|
def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
    """A tuple check overlapping the Union in one member refines to that member."""

    def fn(x: Union[List[int], int]) -> int:
        if torch.jit.isinstance(x, (int, float, str)):
            # We should know that `x` is an `int` here
            z = x + 1
            return z
        else:
            return 100

    for member_value in ([1, 2, 3], 1):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_type_refinement_partial_static_refinement_union_rhs(self):
    """A Union check overlapping the value's Union in one member refines to it."""

    def fn(x: Union[List[int], int]) -> int:
        if torch.jit.isinstance(x, Union[int, float, str]):
            # We should know that `x` is an `int` here
            z = x + 1
            return z
        else:
            return 100

    for member_value in ([1, 2, 3], 1):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_type_refinement_internal_declaration(self):
    """Refinement works on a Union local declared (as None) inside the function."""

    def fn(flag: bool) -> str:
        x: Union[int, str, None] = None
        if flag:
            y = "foo"
        else:
            y = 1
        if isinstance(x, str):
            return x
        else:
            return "bar"

    for flag in (True, False):
        self.checkScript(fn, (flag,))
|
|
|
|
def test_union_branching_with_union_return_and_homogenous_types(self):
    """Branches returning the same member type satisfy a Union return annotation."""

    def fn(x: int) -> Union[int, str]:
        if x % 2:
            return "foo"
        else:
            return "bar"

    # An odd and an even input, so both branches run.
    for arg in (1, 8):
        self.checkScript(fn, (arg,))
|
|
|
|
def test_union_branching_does_not_autoinfer_undeclared_union(self):
    """Branches assigning different types to an unannotated local are rejected."""

    def fn(x: int) -> str:
        if x % 2:
            y = "foo"
        else:
            y = x
        if isinstance(y, str):
            return y
        else:
            return "bar"

    branch_mismatch = (
        "y is set to type str"
        " in the true branch and type int "
        "in the false branch"
    )
    with self.assertRaisesRegex(RuntimeError, branch_mismatch):
        torch.jit.script(fn)
|
|
|
|
def test_union_branching_does_not_widen_existing_inferred_type(self):
    """A local inferred as `str` cannot later be assigned an `int`."""

    def fn(x: int) -> str:
        y = "foo"
        if x % 2:
            y = "bar"
        else:
            y = x
        if isinstance(y, str):
            return y
        else:
            return "baz"

    reassignment_error = (
        "previously had type "
        "str but is now being assigned to a"
        " value of type int"
    )
    with self.assertRaisesRegex(RuntimeError, reassignment_error):
        torch.jit.script(fn)
|
|
|
|
def test_union_schema_matching_on_internal_type(self):
    """After refinement, operator schemas match against the refined member type."""

    def fn(x: Union[List[int], Dict[str, int]]) -> int:
        if torch.jit.isinstance(x, List[int]):
            return x[0]
        else:
            return list(x.values())[0]

    for member_value in ([1, 2, 3], {"foo": 1, "bar": 2, "baz": 3}):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_subtractive_refinement(self):
    """A negated isinstance check refines to the remaining Union member."""

    def fn(x: Union[List[int], int]) -> int:
        if not isinstance(x, int):
            x.append(1)
            return x[0]
        else:
            return x

    for member_value in (1, [1, 2, 3]):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_subtractive_refinement_with_container(self):
    """A negated container isinstance check refines to the remaining member."""

    def fn(x: Union[List[int], int]) -> int:
        if not torch.jit.isinstance(x, List[int]):
            return x
        else:
            x.append(1)
            return x[0]

    for member_value in (1, [1, 2, 3]):
        self.checkScript(fn, (member_value,))
|
|
|
|
def test_union_memory_aliasing(self):
    """Mutation through a refined alias is visible through the original name.

    `x_alias` is read back out of `z` after `x` was appended to it. Once it
    is refined from `Optional[List[Tensor]]` to `List[Tensor]`, appending
    to it must be reflected in the returned `x`.
    """

    def fn():
        x: List[torch.Tensor] = []
        z: List[Optional[List[torch.Tensor]]] = []
        z.append(x)
        x_alias = z[0]
        if torch.jit.isinstance(x_alias, List[torch.Tensor]):
            x_alias.append(torch.tensor(3))
        return x

    self.checkScript(fn, ())
|
|
|
|
def test_union_serialization_preserves_type_annotations(self):
    """Union annotations on locals survive a save/load round trip.

    This function would fail after being torch.jit.save'd and
    torch.jit.load'd if the Union type annotations were not preserved during
    serialization. The `Union[str, int]` annotation is what makes `y` a
    single Union-typed variable, rather than a `str` in one branch and an
    `int` in the other.
    """

    def fn(x: int) -> str:
        if x % 2:
            y: Union[str, int] = "bar"
        else:
            y: Union[str, int] = x
        if isinstance(y, str):
            return y
        else:
            return "baz"

    # An odd and an even input, so both branches run.
    for arg in (1, 8):
        self.checkScript(fn, (arg,))
|
|
|
|
def _assert_passes(self, template: str, ann: str, lhs: str):
    """Format `template` and check that the resulting `fn` passes `checkScript`.

    Args:
        template: Source text with `{ann}` and `{lhs}` placeholders that
            defines a zero-argument function named `fn`.
        ann: Type-annotation text substituted for `{ann}`.
        lhs: Expression text substituted for `{lhs}`. Despite the name, the
            templates in this file use it as the right-hand side of
            `x: {ann} = {lhs}`.
    """
    code = template.format(ann=ann, lhs=lhs)
    # NOTE(review): string sources given to checkScript are presumably
    # resolved against a calling frame's names; keep this helper's locals
    # free of names the templates reference (torch, List, Dict, ...).
    self.checkScript(code, (), name="fn")
|
|
|
|
def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
    """Format `template` and assert that compiling it raises a matching RuntimeError.

    Args:
        template: Source text with `{ann}` and `{lhs}` placeholders that
            defines a function named `fn`.
        ann: Type-annotation text substituted for `{ann}`.
        lhs: Expression text substituted for `{lhs}` (used as the right-hand
            side of the annotated assignment in this file's templates).
        msg: Regex that the raised RuntimeError's message must match.
    """
    code = template.format(ann=ann, lhs=lhs)
    with self.assertRaisesRegex(RuntimeError, msg):
        # `_frames_up=1` shifts name resolution one extra frame up the stack
        # (presumably to the calling test method, whose globals provide
        # `torch`, `List` and `Dict`).
        cu = torch.jit.CompilationUnit(code, _frames_up=1)
        # NOTE(review): `string_frontend` is never used. The lookup of `fn`
        # may be kept in case the error surfaces lazily; confirm before
        # removing it.
        string_frontend = getattr(cu, "fn")  # noqa: B009
|
|
|
|
def test_union_with_list_assignment(self):
    """Assigning list literals and comprehensions to Union-annotated locals.

    Each case formats `template` with a Union annotation and one right-hand
    side expression from `lhs`. An expected message of `None` means the code
    must script successfully; otherwise compilation must raise a RuntimeError
    matching that regex. Cases run in the listed order.
    """
    template = dedent(
        """
        def fn():
            x: {ann} = {lhs}
            if torch.jit.isinstance(x, List[torch.Tensor]):
                x.append(torch.tensor(3))
            return x
        """
    )

    lhs = {
        "list_literal_empty": "[]",
        "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
        "list_literal_of_str": '["foo", "bar", "baz"]',
        "list_literal_of_mixed": "[torch.arange(5), 1]",
        "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
        "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
        "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
    }

    # NOTE: `_assert_raises` compiles with `_frames_up=1`, so the names below
    # share a frame with the template code. None of them shadow names the
    # template uses (torch, List).
    two_lists = "Union[List[str], List[torch.Tensor]]"
    no_list = "Union[int, torch.Tensor]"
    tensor_list = "Union[List[torch.Tensor], int]"

    no_inner_list = "Expected an Union type annotation with an " "inner List type"
    element_mismatch = (
        r"List type annotation `List\[Tensor\]` did "
        "not match the types of the given list "
        "elements"
    )
    mixed_comprehension = "Arguments for call are not valid"

    cases = (
        # Union[List[str], List[torch.Tensor]]
        (
            two_lists,
            "list_literal_empty",
            "there are multiple possible List type "
            "candidates in the Union annotation",
        ),
        (two_lists, "list_literal_of_tensor", None),
        (two_lists, "list_literal_of_str", None),
        (
            two_lists,
            "list_literal_of_mixed",
            "none of those types match the types of the" " given list elements",
        ),
        (two_lists, "list_comprehension_of_tensor", None),
        (two_lists, "list_comprehension_of_str", None),
        # TODO: Support mixed list comprehensions
        (two_lists, "list_comprehension_of_mixed", mixed_comprehension),
        # Union[int, torch.Tensor]
        (no_list, "list_literal_empty", no_inner_list),
        (no_list, "list_literal_of_tensor", no_inner_list),
        (no_list, "list_comprehension_of_tensor", no_inner_list),
        # Union[List[torch.Tensor], int]
        (tensor_list, "list_literal_empty", None),
        (tensor_list, "list_literal_of_tensor", None),
        (tensor_list, "list_literal_of_str", element_mismatch),
        (tensor_list, "list_literal_of_mixed", element_mismatch),
        (tensor_list, "list_comprehension_of_tensor", None),
        (tensor_list, "list_comprehension_of_str", element_mismatch),
        # TODO(@ansley): Support mixed list comprehensions
        (tensor_list, "list_comprehension_of_mixed", mixed_comprehension),
    )

    for ann, rhs_key, expected_error in cases:
        if expected_error is None:
            self._assert_passes(template, ann, lhs[rhs_key])
        else:
            self._assert_raises(template, ann, lhs[rhs_key], expected_error)
|
|
|
|
def test_union_with_dict_assignment(self):
    """Assigning dict literals and `dict(...)` calls to Union-annotated locals.

    Each case formats `template` with a Union annotation and one right-hand
    side expression from `lhs`. An expected message of `None` means the code
    must script successfully; otherwise compilation must raise a RuntimeError
    matching that regex. Cases run in the listed order. Commented-out cases
    are blocked on string-frontend limitations (see the TODOs).
    """
    template = dedent(
        """
        def fn():
            x: {ann} = {lhs}
            if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
                x["foo"] = torch.tensor(3)
            return x
        """
    )

    # The comprehension entries and the internal-aggregate entry previously
    # had unbalanced parentheses, so they could never compile. They are only
    # referenced by the commented-out cases below.
    lhs = {
        "dict_literal_empty": "{}",
        "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
        "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
        "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
        "dict_comprehension_of_str_tensor": (
            "{x : torch.add(y, 1) for x, y in "
            'zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}'
        ),
        "dict_comprehension_of_str_int": (
            "{x : torch.add(y, 1) for x, y in " 'zip(["foo", "bar"], [1, 2])}'
        ),
        "dict_comprehension_of_mixed": (
            "{x : torch.add(y, 1) for x, y in "
            'zip(["foo", "bar"], [torch.arange(3), 2])}'
        ),
        "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
        "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
        "dict_keyword_with_empty_iterable": "dict([])",
        "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)]))',
        "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
        "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
    }

    # NOTE: `_assert_raises` compiles with `_frames_up=1`, so the names below
    # share a frame with the template code. None of them shadow names the
    # template uses (torch, Dict, zip, dict).
    two_dicts = "Union[Dict[str, torch.Tensor], Dict[str, int]]"
    no_dict = "Union[int, torch.Tensor]"
    tensor_dict = "Union[Dict[str, torch.Tensor], int]"

    no_inner_dict = "Expected an Union type annotation with an " "inner Dict type"
    keyword_unsupported = "full type inference is not yet supported"
    value_mismatch = (
        "Type annotation was inferred to be "
        r"`Dict\[str, Tensor\]`, but the type of "
        "values given by the dict literal is"
    )

    cases = (
        # Union[Dict[str, torch.Tensor], Dict[str, int]]
        # NOTE(review): this first case uses a List-only Union, not
        # `two_dicts`, so it only checks that an empty dict literal is
        # rejected when the Union has no Dict member. Confirm whether
        # `two_dicts` was intended.
        ("Union[List[str], List[torch.Tensor]]", "dict_literal_empty", no_inner_dict),
        (two_dicts, "dict_literal_of_str_tensor", None),
        (two_dicts, "dict_literal_of_str_int", None),
        (
            two_dicts,
            "dict_literal_of_mixed",
            "none of those dict types can hold the "
            "types of the given keys and values",
        ),
        # TODO: String frontend does not support tuple unpacking
        # https://github.com/pytorch/pytorch/issues/64096
        # (two_dicts, "dict_comprehension_of_str_tensor", None),
        # (two_dicts, "dict_comprehension_of_str_int", None),
        # (two_dicts, "dict_comprehension_of_mixed", "foobar"),
        # (two_dicts, "dict_keyword_with_internal_aggregate_function", None),
        # TODO(@ansley): Follow-up project needed for full type
        # inference with dict keyword (supported for dict comprehension
        # and dict literal already; should not be a blocker for anyone)
        (two_dicts, "dict_keyword", keyword_unsupported),
        (two_dicts, "dict_keyword_with_iterable", keyword_unsupported),
        (two_dicts, "dict_keyword_with_empty_iterable", keyword_unsupported),
        (two_dicts, "dict_keyword_with_mapping", keyword_unsupported),
        (two_dicts, "dict_keyword_with_mapping_and_kwargs", keyword_unsupported),
        # Union[int, torch.Tensor]
        (no_dict, "dict_literal_empty", no_inner_dict),
        (no_dict, "dict_literal_of_str_tensor", no_inner_dict),
        # See above--string frontend does not support tuple unpacking
        # (no_dict, "dict_comprehension_of_str_tensor", "foobar"),
        # Union[Dict[str, torch.Tensor], int]
        (tensor_dict, "dict_literal_empty", None),
        (tensor_dict, "dict_literal_of_str_tensor", None),
        (tensor_dict, "dict_literal_of_str_int", value_mismatch),
        (tensor_dict, "dict_literal_of_mixed", value_mismatch),
        (tensor_dict, "dict_keyword", None),
        (tensor_dict, "dict_keyword_with_iterable", None),
        (tensor_dict, "dict_keyword_with_empty_iterable", None),
        (tensor_dict, "dict_keyword_with_mapping", None),
        (tensor_dict, "dict_keyword_with_mapping_and_kwargs", None),
        # See above--string frontend does not support tuple unpacking
        # (tensor_dict, "dict_keyword_with_internal_aggregate_function", None),
        # (tensor_dict, "dict_comprehension_of_str_tensor", None),
        # (tensor_dict, "dict_comprehension_of_str_int", "foobar"),
        # (tensor_dict, "dict_comprehension_of_mixed", "foobar"),
    )

    for ann, rhs_key, expected_error in cases:
        if expected_error is None:
            self._assert_passes(template, ann, lhs[rhs_key])
        else:
            self._assert_raises(template, ann, lhs[rhs_key], expected_error)
|