pytorch/test/test_jit_py3.py
Zachary DeVito 83c347ff4a Remove prim::Constant op (#32804)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32804

Constants are interpreter primitives, so the prim::Constant op was never
actually used. This removes it and cleans up some of the logic around it.
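
A minimal sketch of what this means in practice (mine, not part of the
diff): prim::Constant still appears as a node kind in every graph, but
the interpreter materializes the value directly instead of dispatching
to a registered op, which is why the registration removed here was dead:

    import torch

    @torch.jit.script
    def f(x):
        return x + 1

    # The literal 1 appears as a prim::Constant node in the IR, e.g.
    #   %2 : int = prim::Constant[value=1]()
    # but no operator lookup is ever performed for it.
    print(f.graph)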

This also fixes constant propagation so that a failure to look up an op
no longer silently stops propagation. Instead, only errors raised inside
the op implementation itself will do so.
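
To see the pass this affects, a minimal sketch (again mine, assuming the
torch._C._jit_pass_constant_propagation entry point the JIT tests use):

    import torch

    @torch.jit.script
    def fold_me():
        return 1 + 2

    # Before the pass: two prim::Constant inputs feeding an aten::add.
    torch._C._jit_pass_constant_propagation(fold_me.graph)
    # After the pass: the add is folded into a single constant. With this
    # change, a failed op lookup no longer silently stops folding for that
    # node; only an error raised inside the op implementation itself does.
    print(fold_me.graph)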

Test Plan: Imported from OSS

Differential Revision: D19673156

Pulled By: zdevito

fbshipit-source-id: 7beee59a6a67a6c2f8261d86bd505280fefa999e
2020-02-18 15:06:56 -08:00

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
from typing import NamedTuple, List, Optional, Any, Dict
from jit.test_module_interface import TestModuleInterface # noqa: F401
import unittest
import sys
import torch
import torch.testing._internal.jit_utils
import torch.nn as nn

class TestScriptPy3(JitTestCase):
    def test_joined_str(self):
        def func(x):
            hello, test = "Hello", "test"
            print(f"{hello + ' ' + test}, I'm a {test}")  # noqa: E999
            print(f"format blank")
            hi = 'hi'
            print(f"stuff before {hi}")
            print(f"{hi} stuff after")
            return x + 1

        x = torch.arange(4., requires_grad=True)
        # TODO: Add support for f-strings in string parser frontend
        # self.checkScript(func, [x], optimize=True, capture_output=True)

        with self.capture_stdout() as captured:
            out = func(x)

        scripted = torch.jit.script(func)
        with self.capture_stdout() as captured_script:
            out_script = scripted(x)

        self.assertAlmostEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    def test_optional_dict_construct(self):
        class M(torch.nn.Module):
            def use(self, buffer: Dict[str, Optional[torch.Tensor]]):
                return buffer["prev_key"]

            def forward(self, x):
                prev_key = torch.rand(2, 3)
                next_key = torch.rand(2, 3)
                saved_state: Dict[str, Optional[torch.Tensor]] = {
                    "prev_key": prev_key,
                    "next_key": next_key,
                }
                return self.use(saved_state)

        self.checkModule(M(), (torch.rand(2, 2),))

    def test_kwarg_support(self):
        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
            class M(torch.nn.Module):
                def forward(self, *, n_tokens: int, device_name: str = 2):
                    pass
            torch.jit.script(M())

        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str):
                return n_tokens, device_name

        sm = torch.jit.script(M())

        with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
            sm()

        input = (3, 'hello')
        self.assertEqual(sm(*input), input)

    def test_named_tuple(self):
        class FeatureVector(NamedTuple):
            float_features: float
            sequence_features: List[float]
            time_since_first: float

        @torch.jit.script
        def foo(x) -> float:
            fv = FeatureVector(3.0, [3.0], 3.0)  # noqa
            rv = fv.float_features
            for val in fv.sequence_features:
                rv += val
            rv *= fv.time_since_first
            return rv

        self.assertEqual(foo(torch.rand(3, 4)), 18.0)

    def test_named_tuple_constant(self):
        class Tup(NamedTuple):
            a: int
            b: int

        @torch.jit.script
        def foo():
            return Tup(1, 2)

        self.assertEqual(foo(), Tup(1, 2))

    @unittest.skipIf(sys.version_info < (3, 6), "dict not ordered")
    def test_dict_preserves_order(self):
        def dict_ordering():
            a : Dict[int, int] = {}
            for i in range(1000):
                a[i] = i + 1
            return a

        self.checkScript(dict_ordering, ())
        di = torch.jit.script(dict_ordering)()
        res = list(di.items())
        for i in range(1000):
            key, value = res[i]
            self.assertTrue(key == i and value == i + 1)

    def test_list_unification_hint(self):
        with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"):
            @torch.jit.script
            def x():
                b : int = [2, 3]
                return b

    def test_return_named_tuple(self):
        class FeatureVector(NamedTuple):
            float_features: float
            sequence_features: List[float]
            time_since_first: float

        @torch.jit.script
        def foo(x):
            fv = FeatureVector(3.0, [3.0], 3.0)
            return fv

        out = foo(torch.rand(3, 4))
        out = foo(torch.rand(3, 4))
        self.assertEqual(out.float_features, 3.0)
        self.assertEqual(out.sequence_features, [3.0])
        self.assertEqual(out.time_since_first, 3.0)

    def test_ignore_with_types(self):
        @torch.jit.ignore
        def fn(x: Dict[str, Optional[torch.Tensor]]):
            return x + 10

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
                self.dropout_modality(in_batch)
                fn(in_batch)
                return torch.tensor(1)

            @torch.jit.ignore
            def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
                return in_batch

        sm = torch.jit.script(M())
        FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))

    def test_python_callable(self):
        class MyPythonClass(object):
            @torch.jit.ignore
            def __call__(self, *args) -> str:
                return str(type(args[0]))

        the_class = MyPythonClass()

        @torch.jit.script
        def fn(x):
            return the_class(x)

        # This doesn't involve the string frontend, so don't use checkScript
        x = torch.ones(2)
        self.assertEqual(fn(x), the_class(x))

    def test_bad_types(self):
        @torch.jit.ignore
        def fn(my_arg):
            return my_arg + 10

        with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
            @torch.jit.script
            def other_fn(x):
                return fn('2')

    def test_named_tuple_slice_unpack(self):
        class MyCoolNamedTuple(NamedTuple):
            a : int
            b : float
            c : List[int]

        @torch.jit.script
        def foo(a : int, b : float, c : List[int]):
            tup = MyCoolNamedTuple(a, b, c)  # noqa
            my_a, my_b, my_c = tup
            return tup[:1], my_a, my_c

        self.assertEqual(foo(3, 3.5, [6]), ((3,), 3, [6]))

    def test_named_tuple_lower(self):
        class MyCoolNamedTuple(NamedTuple):
            a : int
            b : float
            c : List[int]

        @torch.jit.script
        def foo(a : int):
            tup = MyCoolNamedTuple(a, 3.14, [9])  # noqa
            return tup

        # NamedTuple construction shows up as TupleConstruct in the IR;
        # the lowering pass should eliminate every such node.
        FileCheck().check('TupleConstruct').run(foo.graph)
        torch._C._jit_pass_lower_all_tuples(foo.graph)
        FileCheck().check_not('TupleConstruct').run(foo.graph)

    def test_named_tuple_type_annotation(self):
        global MyCoolNamedTuple  # see [local resolution in python]

        class MyCoolNamedTuple(NamedTuple):
            a : int
            b : float
            c : List[int]

        @torch.jit.script
        def foo(x : MyCoolNamedTuple) -> MyCoolNamedTuple:
            return x

        mnt = MyCoolNamedTuple(42, 420.0, [666])
        self.assertEqual(foo(mnt), mnt)

    def test_named_tuple_wrong_types(self):
        class MyCoolNamedTuple(NamedTuple):
            a : int
            b : float
            c : List[int]

        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int' for argument 'a'"
                                                  " but instead found type 'str'"):
            @torch.jit.script
            def foo():
                tup = MyCoolNamedTuple('foo', 'bar', 'baz')  # noqa
                return tup

    def test_named_tuple_kwarg_construct(self):
        class MyCoolNamedTuple(NamedTuple):
            a : int
            b : float
            c : List[int]

        @torch.jit.script
        def foo():
            tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9)  # noqa
            return tup

        tup = foo()
        self.assertEqual(tup.a, 9)
        self.assertEqual(tup.b, 3.5)
        self.assertEqual(tup.c, [1, 2, 3])

    def test_named_tuple_default_error(self):
        class MyCoolNamedTuple(NamedTuple):
            a : int
            b : float
            c : List[int] = [3, 4, 5]

        with self.assertRaisesRegex(RuntimeError, 'Default values are currently not supported'):
            @torch.jit.script
            def foo():
                tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9)  # noqa
                return tup

    @unittest.skipIf(True, "broken while these tests were not in CI")
    def test_named_tuple_serialization(self):
        class MyCoolNamedTuple(NamedTuple):
            a : int
            b : float
            c : List[int]

        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self):
                return MyCoolNamedTuple(3, 3.5, [3, 4, 5])

        mm = MyMod()
        mm.save('foo.zip')
        torch.testing._internal.jit_utils.clear_class_registry()
        loaded = torch.jit.load('foo.zip')

        out = mm()
        out_loaded = loaded()

        for name in ['a', 'b', 'c']:
            self.assertEqual(getattr(out_loaded, name), getattr(out, name))

    def test_type_annotate_py3(self):
        def fn():
            a : List[int] = []
            b : torch.Tensor = torch.ones(2, 2)
            c : Optional[torch.Tensor] = None
            d : Optional[torch.Tensor] = torch.ones(3, 4)
            for _ in range(10):
                a.append(4)
                c = torch.ones(2, 2)
            d = None
            return a, b, c, d

        self.checkScript(fn, ())

        def wrong_type():
            wrong : List[int] = [0.5]
            return wrong

        with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
            torch.jit.script(wrong_type)

    def test_parser_bug(self):
        def parser_bug(o: Optional[torch.Tensor]):
            pass

    def test_mismatched_annotation(self):
        with self.assertRaisesRegex(RuntimeError, 'annotated with type'):
            @torch.jit.script
            def foo():
                x : str = 4
                return x

    def test_reannotate(self):
        with self.assertRaisesRegex(RuntimeError, 'declare and annotate'):
            @torch.jit.script
            def foo():
                x = 5
                if True:
                    x : Optional[int] = 7

    def test_any_in_class_fails(self):
        class MyCoolNamedTuple(NamedTuple):
            a : Any
            b : float
            c : List[int]

        with self.assertRaisesRegex(RuntimeError, "contains an Any"):
            @torch.jit.script
            def foo():
                return MyCoolNamedTuple(4, 5.5, [3])
            print(foo.graph)

    def test_export_opnames_interface(self):
        global OneTwoModule

        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x, y):
                # type: (Tensor, Tensor) -> Tensor
                pass

            def two(self, x):
                # type: (Tensor) -> Tensor
                pass

            def forward(self, x):
                # type: (Tensor) -> Tensor
                pass

        class FooMod(nn.Module):
            def one(self, x, y):
                # type: (Tensor, Tensor) -> Tensor
                return x + y

            def two(self, x):
                # type: (Tensor) -> Tensor
                return 2 * x

            def forward(self, x):
                # type: (Tensor) -> Tensor
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x, y):
                # type: (Tensor, Tensor) -> Tensor
                return x * y

            def two(self, x):
                # type: (Tensor) -> Tensor
                return 2 / x

            def forward(self, x):
                # type: (Tensor) -> Tensor
                return self.two(self.one(x, x))

        class M(nn.Module):
            sub : OneTwoModule

            def __init__(self):
                super(M, self).__init__()
                self.sub = BarMod()

            def forward(self, x):
                # type: (Tensor) -> Tensor
                return self.sub.forward(x)

        def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
            return mod_list[0].forward(x) + mod_list[1].forward(x)

        # With BarMod assigned to the interface-typed attribute, its ops
        # (mul, plus reciprocal from the scalar division) are exported.
        scripted_M_mod = torch.jit.script(M())
        self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
                         ['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'])

        # Swapping FooMod in through the interface type changes the op set.
        scripted_M_mod.sub = torch.jit.script(FooMod())
        self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
                         ['aten::add.Tensor', 'aten::mul.Scalar'])


if __name__ == '__main__':
    run_tests()