mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32804 Constants are interpreter primitives so the op was not actually used. This cleans up some of the logic around it. This also fixes constant prop such that failures to look up an op do not silently stop constant propagation. Instead, only errors inside the op implementation itself will do this. Test Plan: Imported from OSS Differential Revision: D19673156 Pulled By: zdevito fbshipit-source-id: 7beee59a6a67a6c2f8261d86bd505280fefa999e
413 lines
13 KiB
Python
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
from torch.testing import FileCheck
|
|
from typing import NamedTuple, List, Optional, Any, Dict
|
|
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
|
import unittest
|
|
import sys
|
|
import torch
|
|
import torch.testing._internal.jit_utils
|
|
import torch.nn as nn
|
|
|
|
class TestScriptPy3(JitTestCase):
|
|
def test_joined_str(self):
|
|
def func(x):
|
|
hello, test = "Hello", "test"
|
|
print(f"{hello + ' ' + test}, I'm a {test}") # noqa E999
|
|
print(f"format blank")
|
|
hi = 'hi'
|
|
print(f"stuff before {hi}")
|
|
print(f"{hi} stuff after")
|
|
return x + 1
|
|
|
|
x = torch.arange(4., requires_grad=True)
|
|
# TODO: Add support for f-strings in string parser frontend
|
|
# self.checkScript(func, [x], optimize=True, capture_output=True)
|
|
|
|
with self.capture_stdout() as captured:
|
|
out = func(x)
|
|
|
|
scripted = torch.jit.script(func)
|
|
with self.capture_stdout() as captured_script:
|
|
out_script = func(x)
|
|
|
|
self.assertAlmostEqual(out, out_script)
|
|
self.assertEqual(captured, captured_script)
|
|
|
|
def test_optional_dict_construct(self):
|
|
class M(torch.nn.Module):
|
|
def use(self, buffer: Dict[str, Optional[torch.Tensor]]):
|
|
return buffer["prev_key"]
|
|
|
|
def forward(self, x):
|
|
prev_key = torch.rand(2, 3)
|
|
next_key = torch.rand(2, 3)
|
|
saved_state: Dict[str, Optional[torch.Tensor]] = {
|
|
"prev_key": prev_key,
|
|
"next_key": next_key,
|
|
}
|
|
|
|
return self.use(saved_state)
|
|
|
|
self.checkModule(M(), (torch.rand(2, 2),))
|
|
|
|
def test_kwarg_support(self):
|
|
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
|
|
class M(torch.nn.Module):
|
|
def forward(self, *, n_tokens: int, device_name: str = 2):
|
|
pass
|
|
torch.jit.script(M())
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, *, n_tokens: int, device_name: str):
|
|
return n_tokens, device_name
|
|
|
|
sm = torch.jit.script(M())
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
|
|
sm()
|
|
|
|
input = (3, 'hello')
|
|
self.assertEqual(sm(*input), input)
|
|
|
|
def test_named_tuple(self):
|
|
class FeatureVector(NamedTuple):
|
|
float_features: float
|
|
sequence_features: List[float]
|
|
time_since_first: float
|
|
|
|
@torch.jit.script
|
|
def foo(x) -> float:
|
|
fv = FeatureVector(3.0, [3.0], 3.0) # noqa
|
|
rv = fv.float_features
|
|
for val in fv.sequence_features:
|
|
rv += val
|
|
rv *= fv.time_since_first
|
|
return rv
|
|
|
|
self.assertEqual(foo(torch.rand(3, 4)), 18.0)
|
|
|
|
def test_named_tuple_constant(self):
|
|
class Tup(NamedTuple):
|
|
a: int
|
|
b: int
|
|
|
|
@torch.jit.script
|
|
def foo():
|
|
return Tup(1, 2)
|
|
|
|
self.assertEqual(foo(), Tup(1, 2))
|
|
|
|
@unittest.skipIf(sys.version_info[0] < 3 and sys.version_info[1] < 6, "dict not ordered")
|
|
def test_dict_preserves_order(self):
|
|
def dict_ordering():
|
|
a : Dict[int, int] = {}
|
|
for i in range(1000):
|
|
a[i] = i + 1
|
|
return a
|
|
|
|
self.checkScript(dict_ordering, ())
|
|
di = torch.jit.script(dict_ordering)()
|
|
res = list(di.items())
|
|
for i in range(1000):
|
|
key, value = res[i]
|
|
self.assertTrue(key == i and value == i + 1)
|
|
|
|
def test_list_unification_hint(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"):
|
|
@torch.jit.script
|
|
def x():
|
|
b : int = [2, 3]
|
|
return b
|
|
|
|
def test_return_named_tuple(self):
|
|
class FeatureVector(NamedTuple):
|
|
float_features: float
|
|
sequence_features: List[float]
|
|
time_since_first: float
|
|
|
|
@torch.jit.script
|
|
def foo(x):
|
|
fv = FeatureVector(3.0, [3.0], 3.0)
|
|
return fv
|
|
|
|
out = foo(torch.rand(3, 4))
|
|
out = foo(torch.rand(3, 4))
|
|
self.assertEqual(out.float_features, 3.0)
|
|
self.assertEqual(out.sequence_features, [3.0])
|
|
self.assertEqual(out.time_since_first, 3.0)
|
|
|
|
def test_ignore_with_types(self):
|
|
@torch.jit.ignore
|
|
def fn(x: Dict[str, Optional[torch.Tensor]]):
|
|
return x + 10
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
|
|
def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
|
|
self.dropout_modality(in_batch)
|
|
fn(in_batch)
|
|
return torch.tensor(1)
|
|
|
|
@torch.jit.ignore
|
|
def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
|
|
return in_batch
|
|
|
|
sm = torch.jit.script(M())
|
|
FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))
|
|
|
|
def test_python_callable(self):
|
|
class MyPythonClass(object):
|
|
@torch.jit.ignore
|
|
def __call__(self, *args) -> str:
|
|
return str(type(args[0]))
|
|
|
|
the_class = MyPythonClass()
|
|
@torch.jit.script
|
|
def fn(x):
|
|
return the_class(x)
|
|
|
|
# This doesn't involve the string frontend, so don't use checkScript
|
|
x = torch.ones(2)
|
|
self.assertEqual(fn(x), the_class(x))
|
|
|
|
def test_bad_types(self):
|
|
@torch.jit.ignore
|
|
def fn(my_arg):
|
|
return my_arg + 10
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
|
|
@torch.jit.script
|
|
def other_fn(x):
|
|
return fn('2')
|
|
|
|
|
|
def test_named_tuple_slice_unpack(self):
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : int
|
|
b : float
|
|
c : List[int]
|
|
|
|
@torch.jit.script
|
|
def foo(a : int, b : float, c : List[int]):
|
|
tup = MyCoolNamedTuple(a, b, c) # noqa
|
|
my_a, my_b, my_c = tup
|
|
return tup[:1], my_a, my_c
|
|
|
|
self.assertEqual(foo(3, 3.5, [6]), ((3,), 3, [6]))
|
|
|
|
def test_named_tuple_lower(self):
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : int
|
|
b : float
|
|
c : List[int]
|
|
|
|
@torch.jit.script
|
|
def foo(a : int):
|
|
tup = MyCoolNamedTuple(a, 3.14, [9]) # noqa
|
|
return tup
|
|
|
|
FileCheck().check('TupleConstruct').run(foo.graph)
|
|
torch._C._jit_pass_lower_all_tuples(foo.graph)
|
|
FileCheck().check_not('TupleConstruct').run(foo.graph)
|
|
|
|
def test_named_tuple_type_annotation(self):
|
|
global MyCoolNamedTuple # see [local resolution in python]
|
|
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : int
|
|
b : float
|
|
c : List[int]
|
|
|
|
@torch.jit.script
|
|
def foo(x : MyCoolNamedTuple) -> MyCoolNamedTuple:
|
|
return x
|
|
|
|
mnt = MyCoolNamedTuple(42, 420.0, [666])
|
|
self.assertEqual(foo(mnt), mnt)
|
|
|
|
def test_named_tuple_wrong_types(self):
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : int
|
|
b : float
|
|
c : List[int]
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int' for argument 'a'"
|
|
" but instead found type 'str'"):
|
|
@torch.jit.script
|
|
def foo():
|
|
tup = MyCoolNamedTuple('foo', 'bar', 'baz') # noqa
|
|
return tup
|
|
|
|
def test_named_tuple_kwarg_construct(self):
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : int
|
|
b : float
|
|
c : List[int]
|
|
|
|
@torch.jit.script
|
|
def foo():
|
|
tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) # noqa
|
|
return tup
|
|
|
|
tup = foo()
|
|
self.assertEqual(tup.a, 9)
|
|
self.assertEqual(tup.b, 3.5)
|
|
self.assertEqual(tup.c, [1, 2, 3])
|
|
|
|
def test_named_tuple_default_error(self):
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : int
|
|
b : float
|
|
c : List[int] = [3, 4, 5]
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Default values are currently not supported'):
|
|
@torch.jit.script
|
|
def foo():
|
|
tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) # noqa
|
|
return tup
|
|
|
|
@unittest.skipIf(True, "broken while these tests were not in CI")
|
|
def test_named_tuple_serialization(self):
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : int
|
|
b : float
|
|
c : List[int]
|
|
|
|
class MyMod(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return MyCoolNamedTuple(3, 3.5, [3, 4, 5])
|
|
|
|
mm = MyMod()
|
|
mm.save('foo.zip')
|
|
torch.testing._internal.jit_utils.clear_class_registry()
|
|
loaded = torch.jit.load('foo.zip')
|
|
|
|
out = mm()
|
|
out_loaded = loaded()
|
|
|
|
for name in ['a', 'b', 'c']:
|
|
self.assertEqual(getattr(out_loaded, name), getattr(out, name))
|
|
|
|
def test_type_annotate_py3(self):
|
|
def fn():
|
|
a : List[int] = []
|
|
b : torch.Tensor = torch.ones(2, 2)
|
|
c : Optional[torch.Tensor] = None
|
|
d : Optional[torch.Tensor] = torch.ones(3, 4)
|
|
for _ in range(10):
|
|
a.append(4)
|
|
c = torch.ones(2, 2)
|
|
d = None
|
|
return a, b, c, d
|
|
|
|
self.checkScript(fn, ())
|
|
|
|
def wrong_type():
|
|
wrong : List[int] = [0.5]
|
|
return wrong
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
|
|
torch.jit.script(wrong_type)
|
|
|
|
def test_parser_bug(self):
|
|
def parser_bug(o: Optional[torch.Tensor]):
|
|
pass
|
|
|
|
def test_mismatched_annotation(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'annotated with type'):
|
|
@torch.jit.script
|
|
def foo():
|
|
x : str = 4
|
|
return x
|
|
|
|
def test_reannotate(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'declare and annotate'):
|
|
@torch.jit.script
|
|
def foo():
|
|
x = 5
|
|
if True:
|
|
x : Optional[int] = 7
|
|
|
|
def test_any_in_class_fails(self):
|
|
class MyCoolNamedTuple(NamedTuple):
|
|
a : Any
|
|
b : float
|
|
c : List[int]
|
|
with self.assertRaisesRegex(RuntimeError, "contains an Any"):
|
|
@torch.jit.script
|
|
def foo():
|
|
return MyCoolNamedTuple(4, 5.5, [3])
|
|
print(foo.graph)
|
|
|
|
def test_export_opnames_interface(self):
|
|
global OneTwoModule
|
|
@torch.jit.interface
|
|
class OneTwoModule(nn.Module):
|
|
def one(self, x, y):
|
|
# type: (Tensor, Tensor) -> Tensor
|
|
pass
|
|
|
|
def two(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
pass
|
|
|
|
def forward(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
pass
|
|
|
|
class FooMod(nn.Module):
|
|
def one(self, x, y):
|
|
# type: (Tensor, Tensor) -> Tensor
|
|
return x + y
|
|
|
|
def two(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
return 2 * x
|
|
|
|
def forward(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
return self.one(self.two(x), x)
|
|
|
|
class BarMod(nn.Module):
|
|
def one(self, x, y):
|
|
# type: (Tensor, Tensor) -> Tensor
|
|
return x * y
|
|
|
|
def two(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
return 2 / x
|
|
|
|
def forward(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
return self.two(self.one(x, x))
|
|
|
|
class M(nn.Module):
|
|
sub : OneTwoModule
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.sub = BarMod()
|
|
|
|
def forward(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
return self.sub.forward(x)
|
|
|
|
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
|
|
return mod_list[0].forward(x) + mod_list[1].forward(x)
|
|
|
|
scripted_M_mod = torch.jit.script(M())
|
|
self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
|
|
['aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'])
|
|
|
|
scripted_M_mod.sub = torch.jit.script(FooMod())
|
|
self.assertEqual(torch.jit.export_opnames(scripted_M_mod),
|
|
['aten::add.Tensor', 'aten::mul.Scalar'])
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand control to PyTorch's shared test entry point.
    run_tests()
|