# Files
# pytorch/test/jit/test_module_containers.py
#
# 757 lines
# 25 KiB
# Python

# Owner(s): ["oncall: jit"]
import os
import sys
from collections import OrderedDict
from typing import Any, List, Tuple
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
# Put the test/ directory (two levels above this file's resolved path) on
# sys.path so that helper modules living directly under test/ can be imported.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # directory holding this file
pytorch_test_dir = os.path.dirname(pytorch_test_dir)  # its parent: test/
sys.path.append(pytorch_test_dir)
class TestModuleContainers(JitTestCase):
    """TorchScript coverage for nn module containers.

    Exercises scripting of ``nn.Sequential``, ``nn.ModuleList``,
    ``nn.ModuleDict``, ``nn.ParameterList`` and ``nn.ParameterDict`` (and user
    subclasses of them), module ``@property`` support, and in-place
    reconstruction of a scripted module.
    """

    def test_sequential_intermediary_types(self):
        """A Sequential may pass a non-Tensor value (a dict) between stages."""

        class A(torch.nn.Module):
            def forward(self, x):
                return x + 3

        class B(torch.nn.Module):
            def forward(self, x):
                return {"1": x}

        class C(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Sequential(A(), B())

            def forward(self, x):
                return self.foo(x)

        self.checkModule(C(), (torch.tensor(1),))

    def test_moduledict(self):
        """Iterate a ModuleDict directly and via items()/values()/keys().

        M2 additionally wraps each iteration in enumerate() and zips the
        dict's values with themselves.
        """

        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class Inner2(torch.nn.Module):
            def forward(self, x):
                return x * 2

        class Inner3(torch.nn.Module):
            def forward(self, x):
                return (x - 4) * 3

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                modules = OrderedDict(
                    [
                        ("one", Inner()),
                        ("two", Inner2()),
                        ("three", Inner3()),
                    ]
                )
                self.moduledict = nn.ModuleDict(modules)

            def forward(self, x, skip_name):
                # type: (Tensor, str)
                names = torch.jit.annotate(List[str], [])
                values = []
                for name in self.moduledict:
                    names.append(name)

                # Every module except the one named `skip_name` is applied.
                for name, mod in self.moduledict.items():
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for mod in self.moduledict.values():
                    x = mod(x)
                    values.append(x)

                for key in self.moduledict.keys():
                    names.append(key)

                return x, names

        class M2(M):
            def forward(self, x, skip_name):
                # type: (Tensor, str)
                names = torch.jit.annotate(List[str], [])
                values = []
                x2 = x
                iter = 0
                for name in self.moduledict:
                    names.append(name)

                for i, (name, mod) in enumerate(self.moduledict.items()):
                    iter += i
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for i, mod in enumerate(self.moduledict.values()):
                    iter += i
                    x = mod(x)
                    values.append(x)

                for i, key in enumerate(self.moduledict.keys()):
                    iter += i
                    names.append(key)

                # `i` below is the value left over from the enumerate() loop
                # above; both zip targets deliberately share the name `mod`.
                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
                    iter += i
                    x2 = mod(mod(x2))

                return x, x2, names, iter

        for name in ["", "one", "two", "three"]:
            inp = torch.tensor(1)
            self.checkModule(M(), (inp, name))
            self.checkModule(M2(), (inp, name))

    def test_custom_container_forward(self):
        """Subclasses of Sequential/ModuleList/ModuleDict may define forward()
        that iterates over `self`."""

        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class CustomSequential(nn.Sequential):
            def __init__(self) -> None:
                super().__init__(nn.ReLU(), Inner())

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomSequential(), (torch.tensor(0.5),))

        class CustomModuleList(nn.ModuleList):
            def __init__(self) -> None:
                super().__init__([nn.ReLU(), Inner()])

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomModuleList(), (torch.tensor(0.5),))

        class CustomModuleDict(nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__(
                    OrderedDict(
                        [
                            ("one", Inner()),
                            ("two", nn.ReLU()),
                            ("three", Inner()),
                        ]
                    )
                )

            def forward(self, x):
                x = x + 3
                names = torch.jit.annotate(List[str], [])
                for name, mod in self.items():
                    x = mod(x)
                    names.append(name)
                return names, x - 5

        self.checkModule(CustomModuleDict(), (torch.tensor(0.5),))

    def test_script_module_list_sequential(self):
        """A ScriptModule can iterate a Sequential passed in as an attribute and
        survive an export/import round trip."""

        class M(torch.jit.ScriptModule):
            def __init__(self, mod_list):
                super().__init__()
                self.mods = mod_list

            @torch.jit.script_method
            def forward(self, v):
                for m in self.mods:
                    v = m(v)
                return v

        with torch.jit.optimized_execution(False):
            m = M(nn.Sequential(nn.ReLU()))
            self.assertExportImportModule(m, (torch.randn(2, 2),))

    def test_script_modulelist_index(self):
        """Literal (including negative) ModuleList indices compile; out-of-range
        and non-literal indices are rejected with a clear error."""

        class Sub(torch.nn.Module):
            def __init__(self, i):
                super().__init__()
                self.i = i

            def forward(self, thing):
                return thing - self.i

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4].forward(v)
                v = self.mods[-1].forward(v)
                v = self.mods[-9].forward(v)
                return v

        x = torch.tensor(1)
        self.checkModule(M(), (x,))

        class MForward(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4](v)
                v = self.mods[-1](v)
                v = self.mods[-9](v)
                return v

        self.checkModule(MForward(), (torch.tensor(1),))

        class M2(M):
            def forward(self, v):
                return self.mods[-11].forward(v)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Index -11 out of range", "self.mods[-11]"
        ):
            torch.jit.script(M2())

        # M3 and M4 have identical bodies; each checks a different fragment of
        # the error raised for a non-literal index.
        class M3(M):
            def forward(self, v):
                i = 3
                return self.mods[i].forward(v)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Enumeration is supported", "self.mods[i]"
        ):
            torch.jit.script(M3())

        class M4(M):
            def forward(self, v):
                i = 3
                return self.mods[i].forward(v)

        with self.assertRaisesRegex(Exception, "will fail because i is not a literal"):
            torch.jit.script(M4())

    def test_module_interface_special_methods(self):
        """Container subclasses that also inherit another Module keep working
        __getitem__/__len__/__iter__/__contains__ and dict views in script."""

        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

        class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.Sequential.__init__(self, modules)

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])
                self.sequential = CustomSequential(self.submod)
                self.moduledict = CustomModuleDict({"submod": self.submod})

            def forward(self, inputs):
                assert self.modulelist[0] is self.submod, (
                    "__getitem__ failing for ModuleList"
                )
                assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
                for module in self.modulelist:
                    assert module is self.submod, "__iter__ failing for ModuleList"

                assert self.sequential[0] is self.submod, (
                    "__getitem__ failing for Sequential"
                )
                assert len(self.sequential) == 1, "__len__ failing for Sequential"
                for module in self.sequential:
                    assert module is self.submod, "__iter__ failing for Sequential"

                assert self.moduledict["submod"] is self.submod, (
                    "__getitem__ failing for ModuleDict"
                )
                assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"

                # note: unable to index moduledict with a string variable currently
                i = 0
                for key in self.moduledict:
                    i += 1
                assert i == len(self.moduledict), "iteration failing for ModuleDict"

                assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"

                for key in self.moduledict.keys():
                    assert key == "submod", "keys() fails for ModuleDict"

                for item in self.moduledict.items():
                    assert item[0] == "submod", "items() fails for ModuleDict"
                    assert item[1] is self.submod, "items() fails for ModuleDict"

                for value in self.moduledict.values():
                    assert value is self.submod, "values() fails for ModuleDict"

                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_special_method_with_override(self):
        """A user-defined __len__ on a ModuleList subclass wins over the one the
        compiler would otherwise synthesize."""

        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

            def __len__(self):
                # this is arbitrary, just to check that the overridden py __len__ from
                # CustomModuleList takes precedence over the automatically generated
                # __len__ added by the jit compiler
                return 2

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])

            def forward(self, inputs):
                assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])
        torch.jit.script(m)

    def test_moduledict_getitem(self):
        """Literal-key ModuleDict lookup returns the same (scripted) submodule."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.jit.script(torch.nn.ReLU())
                self.tanh = torch.jit.script(torch.nn.Tanh())
                self.moduledict = torch.nn.ModuleDict(
                    {"relu": self.relu, "tanh": self.tanh}
                )

            def forward(self, input):
                assert self.moduledict["relu"] is self.relu
                assert self.moduledict["tanh"] is self.tanh
                return input

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_moduledict_keyerror(self):
        """Missing literal keys and non-literal keys into a ModuleDict fail at
        compile time with highlighted errors."""

        class BadModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})

            def forward(self, input):
                assert self.moduledict["blah"] == "blah", "this is a keyerror"

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Key Error, blah", 'self.moduledict["blah"'
        ):
            b = BadModule()
            torch.jit.script(b)

        class AnotherBadModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})

            def forward(self, input):
                idx = "blah"
                assert self.moduledict[idx] == "blah", "this is a string literal error"

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "Unable to extract string literal index. "
            "ModuleDict indexing is only supported with string literals. "
            "For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
            "because i is not a literal.",
            "self.moduledict[idx]",
        ):
            b = AnotherBadModule()
            torch.jit.script(b)

    def test_normal_list_attribute_with_modules_error(self):
        """
        Test that an attempt to script a module with a regular list attribute
        containing other modules fails with a relevant error message.
        """

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [torch.nn.ReLU(), torch.nn.ReLU()]

            def forward(self):
                return len(self.a)

        error_msg = "Could not infer type of list element: Cannot infer concrete type of torch.nn.Module"
        with self.assertRaisesRegexWithHighlight(RuntimeError, error_msg, "self.a"):
            torch.jit.script(Mod())

    def test_empty_dict_override_contains(self):
        """`in` on an empty ModuleDict subclass evaluates to False in script."""

        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.moduledict = CustomModuleDict()

            def forward(self, inputs):
                assert "submod" not in self.moduledict, (
                    "__contains__ fails for ModuleDict"
                )
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_typed_module_dict(self):
        """
        Test that a type annotation can be provided for a ModuleDict that allows
        non-static indexing.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass

        class ImplementsInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                if isinstance(inp, torch.Tensor):
                    return torch.max(inp, dim=0)

                return inp

        class DoesNotImplementInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.max(inp, dim=0)

        # Test annotation of submodule.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                value: ModuleInterface = self.d[key]
                return value.forward(x)

        m = Mod()
        self.checkModule(m, (torch.randn(2, 2), "module"))

        # Test annotation of self.
        class ModDict(torch.nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__({"module": ImplementsInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                submodule: ModuleInterface = self[key]
                return submodule.forward(x)

        m = ModDict()
        self.checkModule(m, (torch.randn(2, 2), "module"))

        # Test error message thrown when annotated attribute does not comply with the
        # annotation.
        class ModWithWrongAnnotation(torch.nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__()
                self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                submodule: ModuleInterface = self.d[key]
                return submodule.forward(x)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"
        ):
            torch.jit.script(ModWithWrongAnnotation())

    def test_typed_module_list(self):
        """
        Test that a type annotation can be provided for a ModuleList that allows
        non-static indexing.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass

        class ImplementsInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                if isinstance(inp, torch.Tensor):
                    return torch.max(inp, dim=0)

                return inp

        class DoesNotImplementInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.max(inp, dim=0)

        # Test annotation of submodule.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.ModuleList([ImplementsInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                value: ModuleInterface = self.l[idx]
                return value.forward(x)

        m = Mod()
        self.checkModule(m, (torch.randn(2, 2), 0))

        # Test annotation of self.
        class ModList(torch.nn.ModuleList):
            def __init__(self) -> None:
                super().__init__([ImplementsInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                submodule: ModuleInterface = self[idx]
                return submodule.forward(x)

        m = ModList()
        self.checkModule(m, (torch.randn(2, 2), 0))

        # Test error message thrown when annotated attribute does not comply with the
        # annotation.
        class ModWithWrongAnnotation(torch.nn.ModuleList):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.ModuleList([DoesNotImplementInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                submodule: ModuleInterface = self.l[idx]
                return submodule.forward(x)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"
        ):
            torch.jit.script(ModWithWrongAnnotation())

    def test_module_properties(self):
        """@property getters/setters compile in script; properties listed in
        __jit_unused_properties__ are absent from the scripted module."""

        class ModuleWithProperties(torch.nn.Module):
            # Checked at the end of this test: accessing this property on the
            # scripted module raises AttributeError.
            __jit_unused_properties__ = ["ignored_attr"]

            def __init__(self, a: int):
                super().__init__()
                self.a = a

            def forward(self, a: int, b: int):
                self.attr = a + b
                return self.attr

            @property
            def attr(self):
                return self.a

            @property
            def ignored_attr(self):
                return sum([self.a])

            @torch.jit.unused
            @property
            def ignored_attr_2(self):
                return sum([self.a])

            @ignored_attr_2.setter
            def ignored_attr_2(self, value):
                self.a = sum([self.a])

            @attr.setter
            def attr(self, a: int):
                if a > 0:
                    self.a = a
                else:
                    self.a = 0

        class ModuleWithNoSetter(torch.nn.Module):
            def __init__(self, a: int):
                super().__init__()
                self.a = a

            def forward(self, a: int, b: int):
                # Return the sum so checkModule compares the getter's value
                # between eager and script (otherwise both sides return None).
                return self.attr + a + b

            @property
            def attr(self):
                return self.a + 1

        self.checkModule(
            ModuleWithProperties(5),
            (
                5,
                6,
            ),
        )
        self.checkModule(
            ModuleWithProperties(5),
            (
                -5,
                -6,
            ),
        )
        self.checkModule(
            ModuleWithNoSetter(5),
            (
                5,
                6,
            ),
        )
        self.checkModule(
            ModuleWithNoSetter(5),
            (
                -5,
                -6,
            ),
        )

        mod = ModuleWithProperties(3)
        scripted_mod = torch.jit.script(mod)

        with self.assertRaisesRegex(AttributeError, "has no attribute"):
            scripted_mod.ignored_attr

    def test_module_inplace_construct(self):
        """_reconstruct() rebuilds a scripted module from another module's C++
        object while leaving its @torch.jit.ignore methods callable."""

        class M(nn.Module):
            def __init__(self, start: int):
                super().__init__()
                self.linear = nn.Linear(3, 3)
                self.attribute = start
                self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float))

            def method(self) -> int:
                return self.attribute

            @torch.jit.unused
            def unused_method(self):
                return self.attribute + self.attribute

            def forward(self, x):
                return self.linear(self.linear(x))

        class N(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 4)

            @torch.jit.ignore
            def ignored_method(self, x):
                return x

            def forward(self, x):
                return self.linear(x)

        m = torch.jit.script(M(3))
        n = torch.jit.script(N())

        n._reconstruct(m._c)

        inp = torch.rand(3)

        # Check that both modules produce the same output.
        with torch.no_grad():
            m_out = m(inp)
            n_out = n(inp)
            self.assertEqual(m_out, n_out)

        # Check that ignored method is still intact.
        self.assertEqual(inp, n.ignored_method(inp))

    def test_parameterlist_script_getitem(self):
        """Literal indexing into ModuleList and ParameterList compiles."""

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
                self.parameter_list = nn.ParameterList(
                    [nn.Parameter(torch.zeros(1)) for _ in range(10)]
                )

            def forward(self, x):
                self.module_list[0]
                self.parameter_list[0]
                return x

        # The trailing comma makes this a 1-tuple of forward() arguments rather
        # than a bare tensor that would be unpacked along its first dimension.
        self.checkModule(MyModule(), (torch.zeros(1),))

    def test_parameterlist_script_iter(self):
        """enumerate() over a ParameterList works in script."""

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
                self.parameter_list = nn.ParameterList(
                    [nn.Parameter(torch.zeros(1)) for _ in range(10)]
                )

            def forward(self, x):
                r = x
                for i, p in enumerate(self.parameter_list):
                    r = r + p + i
                return r

        self.checkModule(MyModule(), (torch.zeros(1),))

    def test_parameterdict_script_getitem(self):
        """Literal-key lookups into a ParameterDict work in script."""

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.parameter_dict = nn.ParameterDict(
                    {k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]}
                )

            def forward(self, x):
                return (
                    self.parameter_dict["a"] * x
                    + self.parameter_dict["b"] * self.parameter_dict["c"]
                )

        self.checkModule(MyModule(), (torch.ones(1),))
# This module is not meant to be executed on its own. Presumably its tests are
# collected via test/test_jit.py (the path passed below) — confirm against
# raise_on_run_directly in torch.testing._internal.common_utils.
if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")