mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964 Approved by: https://github.com/justinchuby, https://github.com/albanD
474 lines
16 KiB
Python
474 lines
16 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
# ruff: noqa: F841
|
|
|
|
import copy
|
|
import io
|
|
import os
|
|
import sys
|
|
import unittest
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
|
|
|
|
|
# Make the helper files in test/ importable
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.common_utils import (
|
|
find_library_location,
|
|
IS_FBCODE,
|
|
IS_MACOS,
|
|
IS_SANDCASTLE,
|
|
IS_WINDOWS,
|
|
)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
if __name__ == "__main__":
    # This module only defines test cases; it is driven through
    # test/test_jit.py, so direct execution is rejected with instructions.
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)
|
|
|
|
|
|
@skipIfTorchDynamo("skipping as a precaution")
|
|
class TestTorchbind(JitTestCase):
|
|
def setUp(self):
    # Load the compiled torchbind test library before each test. The tests
    # in this class use torch.classes._TorchScriptTesting and
    # torch.ops._TorchScriptTesting, which are presumably registered by this
    # library (not verifiable from this file).
    #
    # Raises unittest.SkipTest on Sandcastle, macOS and fbcode, where the
    # load_library call below is not portable.
    if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
        raise unittest.SkipTest("non-portable load_library call used in test")
    # Resolve only the library name that applies to the current platform,
    # instead of looking up the .so first and discarding it on Windows.
    lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
    lib_file_path = find_library_location(lib_name)
    torch.ops.load_library(str(lib_file_path))
|
|
|
|
def test_torchbind(self):
    # Smoke test for basic torchbind usage: construct custom classes, call
    # their methods eagerly and under torch.jit.script, and script a module
    # that swaps in a torchbind object through __prepare_scriptable__.
    def check_eager_matches_script(fn, cmp_key):
        # Run fn eagerly and scripted, then require that both results agree
        # under cmp_key. Asserting here (rather than returning the pair)
        # ensures that a divergence actually fails the test.
        eager_result = fn()
        scripted_result = torch.jit.script(fn)()
        self.assertEqual(cmp_key(eager_result), cmp_key(scripted_result))

    def increment_foo():
        val = torch.classes._TorchScriptTesting._Foo(5, 3)
        val.increment(1)
        return val

    # The eager and scripted calls each build their own _Foo, so compare
    # the objects through info() rather than comparing them directly.
    check_eager_matches_script(increment_foo, lambda foo: foo.info())

    # increment() must reject a string argument with an error naming 'int'.
    with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
        val = torch.classes._TorchScriptTesting._Foo(5, 3)
        val.increment("foo")

    def pop_one():
        ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
        return ss.pop()

    check_eager_matches_script(pop_one, lambda x: x)

    def push_between_stacks():
        ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
        ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
        ss1.push(ss2.pop())
        return ss1.pop() + ss2.pop()

    check_eager_matches_script(push_between_stacks, lambda x: x)

    # test nn module with prepare_scriptable function
    class NonJitableClass:
        def __init__(self, int1, int2):
            self.int1 = int1
            self.int2 = int2

        def return_vals(self):
            return self.int1, self.int2

    class CustomWrapper(torch.nn.Module):
        def __init__(self, foo):
            super().__init__()
            self.foo = foo

        def forward(self) -> None:
            self.foo.increment(1)
            return

        def __prepare_scriptable__(self):
            # Replace the plain Python object with a torchbind _Foo built
            # from the same two ints, and wrap it in a fresh module.
            int1, int2 = self.foo.return_vals()
            foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
            return CustomWrapper(foo)

    foo = CustomWrapper(NonJitableClass(1, 2))
    # Scripting only has to succeed; the result is not inspected.
    jit_foo = torch.jit.script(foo)
|
|
|
|
def test_torchbind_take_as_arg(self):
    # A scripted function takes a torchbind instance as its argument,
    # pushes onto it, and returns it to Python.
    global StackString  # see [local resolution in python]
    StackString = torch.classes._TorchScriptTesting._StackString

    def push_lel(stack):
        # type: (StackString)
        stack.push("lel")
        return stack

    scripted_push = torch.jit.script(push_lel)
    empty_stack = torch.classes._TorchScriptTesting._StackString([])
    returned_stack = scripted_push(empty_stack)
    self.assertEqual(returned_stack.pop(), "lel")
|
|
|
|
def test_torchbind_return_instance(self):
    # A scripted function constructs a torchbind object and returns it.
    def make_stack():
        stack = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
        return stack

    scripted_make = torch.jit.script(make_stack)
    # Ensure we are creating the object and calling __init__
    # rather than calling the __init__wrapper nonsense
    graph_text = str(scripted_make.graph)
    checker = FileCheck()
    checker = checker.check("prim::CreateObject()")
    checker = checker.check('prim::CallMethod[name="__init__"]')
    checker.run(graph_text)

    returned_stack = scripted_make()
    # The strings must pop in reverse order of the constructor list.
    for expected in ("mom", "hi"):
        self.assertEqual(returned_stack.pop(), expected)
|
|
|
|
def test_torchbind_return_instance_from_method(self):
    # clone() returns a separate instance: popping from the source after
    # cloning must not affect the clone's contents.
    def clone_then_pop():
        source = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
        duplicate = source.clone()
        source.pop()
        return source, duplicate

    scripted_fn = torch.jit.script(clone_then_pop)
    popped_source, untouched_clone = scripted_fn()
    self.assertEqual(popped_source.pop(), "hi")
    for expected in ("mom", "hi"):
        self.assertEqual(untouched_clone.pop(), expected)
|
|
|
|
def test_torchbind_def_property_getter_setter(self):
    # Properties with both a getter and a setter must behave the same in
    # eager mode and under scripting (checkScript compares the two).
    def foo_getter_setter_full():
        target = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
        # getX method intentionally adds 2 to x
        before = target.x
        # setX method intentionally adds 2 to x
        target.x = before + 4
        after = target.x
        return before, after

    self.checkScript(foo_getter_setter_full, ())

    def foo_getter_setter_lambda():
        target = torch.classes._TorchScriptTesting._FooGetterSetterLambda(5)
        before = target.x
        target.x = before + 4
        after = target.x
        return before, after

    self.checkScript(foo_getter_setter_lambda, ())
|
|
|
|
def test_torchbind_def_property_just_getter(self):
    # A getter-only property can be read but not assigned, both from
    # Python and from scripted code.
    def foo_just_getter():
        instance = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
        # getY method intentionally adds 4 to x
        return instance, instance.y

    scripted_getter = torch.jit.script(foo_just_getter)
    returned_obj, y_value = scripted_getter()
    self.assertEqual(y_value, 10)

    # Assigning the read-only property from Python must fail.
    with self.assertRaisesRegex(RuntimeError, "can't set attribute"):
        returned_obj.y = 5

    # The assignment line below is matched verbatim by the highlight check,
    # so its text (including the variable name) must not change.
    def foo_not_setter():
        fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
        old = fooGetterSetter.y
        fooGetterSetter.y = old + 4
        # getY method intentionally adds 4 to x
        return fooGetterSetter.y

    expected_message = "Tried to set read-only attribute: y"
    expected_highlight = "fooGetterSetter.y = old + 4"
    with self.assertRaisesRegexWithHighlight(
        RuntimeError, expected_message, expected_highlight
    ):
        torch.jit.script(foo_not_setter)
|
|
|
|
def test_torchbind_def_property_readwrite(self):
    # On _FooReadWrite, x can be read and written, while y is read-only.
    def foo_readwrite():
        instance = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
        previous_x = instance.x
        instance.x = previous_x + 4
        return instance.x, instance.y

    self.checkScript(foo_readwrite, ())

    # The assignment line below is matched verbatim by the highlight check,
    # so its text (including the variable name) must not change.
    def foo_readwrite_error():
        fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
        fooReadWrite.y = 5
        return fooReadWrite

    expected_message = "Tried to set read-only attribute: y"
    expected_highlight = "fooReadWrite.y = 5"
    with self.assertRaisesRegexWithHighlight(
        RuntimeError, expected_message, expected_highlight
    ):
        torch.jit.script(foo_readwrite_error)
|
|
|
|
def test_torchbind_take_instance_as_method_arg(self):
    # One torchbind instance is passed as an argument to another's method.
    def merge_stacks():
        base = torch.classes._TorchScriptTesting._StackString(["mom"])
        extra = torch.classes._TorchScriptTesting._StackString(["hi"])
        base.merge(extra)
        return base

    merged = torch.jit.script(merge_stacks)()
    # The merged-in element pops first, then the original one.
    for expected in ("hi", "mom"):
        self.assertEqual(merged.pop(), expected)
|
|
|
|
def test_torchbind_return_tuple(self):
    # A torchbind method can return a tuple through a scripted function.
    def call_return_a_tuple():
        stack = torch.classes._TorchScriptTesting._StackString(["3", "5"])
        return stack.return_a_tuple()

    scripted_fn = torch.jit.script(call_return_a_tuple)
    expected_tuple = (1337.0, 123)
    self.assertEqual(scripted_fn(), expected_tuple)
|
|
|
|
def test_torchbind_save_load(self):
    # A scripted function using torchbind objects must survive the
    # export/import round trip done by getExportImportCopy.
    def merge_stacks():
        base = torch.classes._TorchScriptTesting._StackString(["mom"])
        extra = torch.classes._TorchScriptTesting._StackString(["hi"])
        base.merge(extra)
        return base

    scripted_fn = torch.jit.script(merge_stacks)
    # Only the round trip itself is exercised; the copy is not inspected.
    self.getExportImportCopy(scripted_fn)
|
|
|
|
def test_torchbind_lambda_method(self):
    # top() must be callable from scripted code and return the single
    # element the stack was built with.
    def peek_top():
        stack = torch.classes._TorchScriptTesting._StackString(["mom"])
        return stack.top()

    scripted_peek = torch.jit.script(peek_top)
    top_value = scripted_peek()
    self.assertEqual(top_value, "mom")
|
|
|
|
def test_torchbind_class_attr_recursive(self):
    # A module holding a torchbind object can be rebuilt around a new
    # torchbind object derived from the old one, and then scripted.
    class FooBar(torch.nn.Module):
        def __init__(self, foo_model):
            super().__init__()
            self.foo_mod = foo_model

        def forward(self) -> int:
            return self.foo_mod.info()

        def to_ivalue(self):
            # Seed a fresh _Foo with the current info() value and 1.
            current_info = self.foo_mod.info()
            replacement = torch.classes._TorchScriptTesting._Foo(current_info, 1)
            return FooBar(replacement)

    original_module = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
    rebuilt_module = original_module.to_ivalue()
    scripted_module = torch.jit.script(rebuilt_module)
    self.assertEqual(scripted_module(), 6)
|
|
|
|
def test_torchbind_class_attribute(self):
    # A torchbind object stored as a module attribute goes through an
    # export/import round trip. The expectations below show that the
    # reloaded stack holds "i", "was", "deserialized" rather than the
    # original ["3", "4"].
    class FooBar1234(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

        def forward(self):
            return self.f.top()

    inst = FooBar1234()
    scripted = torch.jit.script(inst)
    eic = self.getExportImportCopy(scripted)
    # Use assertEqual instead of bare assert: it is not stripped under -O
    # and reports both values on failure.
    self.assertEqual(eic(), "deserialized")
    for expected in ["deserialized", "was", "i"]:
        self.assertEqual(eic.f.pop(), expected)
|
|
|
|
def test_torchbind_getstate(self):
    # A _PickleTester held as a module attribute goes through an
    # export/import round trip; the reloaded contents are then checked.
    class FooBar4321(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        def forward(self):
            return self.f.top()

    inst = FooBar4321()
    scripted = torch.jit.script(inst)
    eic = self.getExportImportCopy(scripted)
    # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
    # return {1, 3, 3, 7}. I tried to make this actually depend on the
    # values at instantiation in the test with some transformation, but
    # because it seems we serialize/deserialize multiple times, that
    # transformation isn't as you would expect it to be.
    #
    # Use assertEqual instead of bare assert: it is not stripped under -O
    # and reports both values on failure.
    self.assertEqual(eic(), 7)
    for expected in [7, 3, 3, 1]:
        self.assertEqual(eic.f.pop(), expected)
|
|
|
|
def test_torchbind_deepcopy(self):
    # copy.deepcopy of a scripted module holding a _PickleTester yields a
    # copy whose stack pops 7, 3, 3, 1 (not the original [3, 4]).
    class FooBar4321(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        def forward(self):
            return self.f.top()

    inst = FooBar4321()
    scripted = torch.jit.script(inst)
    copied = copy.deepcopy(scripted)
    # Use assertEqual instead of bare assert: it is not stripped under -O
    # and reports both values on failure.
    self.assertEqual(copied.forward(), 7)
    for expected in [7, 3, 3, 1]:
        self.assertEqual(copied.f.pop(), expected)
|
|
|
|
def test_torchbind_python_deepcopy(self):
    # copy.deepcopy of a plain (unscripted) module holding a _PickleTester
    # yields a copy whose stack pops 7, 3, 3, 1 (not the original [3, 4]).
    class FooBar4321(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        def forward(self):
            return self.f.top()

    inst = FooBar4321()
    copied = copy.deepcopy(inst)
    # Use assertEqual instead of bare assert: it is not stripped under -O
    # and reports both values on failure.
    self.assertEqual(copied(), 7)
    for expected in [7, 3, 3, 1]:
        self.assertEqual(copied.f.pop(), expected)
|
|
|
|
def test_torchbind_tracing(self):
    # torch.jit.trace handles a module whose forward passes a torchbind
    # attribute to a custom op.
    class TryTracing(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        def forward(self):
            return torch.ops._TorchScriptTesting.take_an_instance(self.f)

    module = TryTracing()
    traced_module = torch.jit.trace(module, ())
    expected_output = torch.zeros(4, 4)
    self.assertEqual(expected_output, traced_module())
|
|
|
|
def test_torchbind_pass_wrong_type(self):
    # Passing a Tensor where the op takes a torchbind instance must raise
    # an error that names the offending type.
    take_an_instance = torch.ops._TorchScriptTesting.take_an_instance
    expected_error = "but instead found type 'Tensor'"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        take_an_instance(torch.rand(3, 4))
|
|
|
|
def test_torchbind_tracing_nested(self):
    # Same as the plain tracing test, except that the torchbind attribute
    # lives on a nested submodule.
    class TryTracingNest(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

    class TryTracing123(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.nest = TryTracingNest()

        def forward(self):
            return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)

    outer_module = TryTracing123()
    traced_module = torch.jit.trace(outer_module, ())
    expected_output = torch.zeros(4, 4)
    self.assertEqual(expected_output, traced_module())
|
|
|
|
def test_torchbind_pickle_serialization(self):
    # torch.save / torch.load round trip of a bare torchbind object through
    # an in-memory buffer.
    tester = torch.classes._TorchScriptTesting._PickleTester([3, 4])
    buffer = io.BytesIO()
    torch.save(tester, buffer)
    # Rewind so that load reads from the start of what was just written.
    buffer.seek(0)
    # weights_only=False as trying to load ScriptObject
    reloaded = torch.load(buffer, weights_only=False)
    for expected in (7, 3, 3, 1):
        self.assertEqual(reloaded.pop(), expected)
|
|
|
|
def test_torchbind_instantiate_missing_class(self):
    # Instantiating a class that was never registered must raise an error
    # naming the missing class.
    expected_error = (
        "Tried to instantiate class 'foo.IDontExist', but it does not exist!"
    )
    with self.assertRaisesRegex(RuntimeError, expected_error):
        torch.classes.foo.IDontExist(3, 4, 5)
|
|
|
|
def test_torchbind_optional_explicit_attr(self):
    # A module whose torchbind attribute is explicitly annotated as
    # Optional must be scriptable.
    class TorchBindOptionalExplicitAttr(torch.nn.Module):
        foo: Optional[torch.classes._TorchScriptTesting._StackString]

        def __init__(self) -> None:
            super().__init__()
            self.foo = torch.classes._TorchScriptTesting._StackString(["test"])

        def forward(self) -> str:
            # Copy to a local before the None check so the pop() call
            # operates on the checked value.
            maybe_stack = self.foo
            if maybe_stack is None:
                return "<None>"
            else:
                return maybe_stack.pop()

    # Scripting only has to succeed; the scripted module is not called.
    torch.jit.script(TorchBindOptionalExplicitAttr())
|
|
|
|
def test_torchbind_no_init(self):
    # Constructing _NoInit must fail with an error that mentions
    # torch::init.
    no_init_class = torch.classes._TorchScriptTesting._NoInit
    with self.assertRaisesRegex(RuntimeError, "torch::init"):
        no_init_class()
|
|
|
|
def test_profiler_custom_op(self):
    # A custom op taking a torchbind instance must appear in the autograd
    # profiler's recorded events under its qualified name.
    tester = torch.classes._TorchScriptTesting._PickleTester([3, 4])

    with torch.autograd.profiler.profile() as prof:
        torch.ops._TorchScriptTesting.take_an_instance(tester)

    op_name = "_TorchScriptTesting::take_an_instance"
    op_was_recorded = any(event.name == op_name for event in prof.function_events)
    self.assertTrue(op_was_recorded)
|
|
|
|
def test_torchbind_getattr(self):
    # getattr with a default must return that default for a field the
    # torchbind object does not have.
    stack = torch.classes._TorchScriptTesting._StackString(["test"])
    missing_field = getattr(stack, "bar", None)
    self.assertIsNone(missing_field)
|
|
|
|
def test_torchbind_attr_exception(self):
    # Direct access to a missing field must raise AttributeError.
    stack = torch.classes._TorchScriptTesting._StackString(["test"])
    expected_error = "does not have a field"
    with self.assertRaisesRegex(AttributeError, expected_error):
        # The bare attribute access is what triggers the lookup.
        stack.bar
|
|
|
|
def test_lambda_as_constructor(self):
    # _LambdaInit(4, 3, flag): diff() is 1 when the flag is False and -1
    # when it is True.
    lambda_init = torch.classes._TorchScriptTesting._LambdaInit
    for swap_flag, expected_diff in ((False, 1), (True, -1)):
        instance = lambda_init(4, 3, swap_flag)
        self.assertEqual(instance.diff(), expected_diff)
|
|
|
|
def test_staticmethod(self):
    # A static method on a torchbind class must give the same result in
    # eager mode and under scripting (checkScript compares the two).
    def call_static(inp: int) -> int:
        return torch.classes._TorchScriptTesting._StaticMethod.staticMethod(inp)

    script_inputs = (1,)
    self.checkScript(call_static, script_inputs)
|
|
|
|
def test_default_args(self):
    # Call _DefaultArgs methods with and without their optional arguments;
    # checkScript requires eager and scripted results to match.
    def fn() -> int:
        # The constructor is called with no arguments here.
        target = torch.classes._TorchScriptTesting._DefaultArgs()
        target.increment(5)
        target.decrement()
        target.decrement(2)
        target.divide()
        target.scale_add(5)
        target.scale_add(3, 2)
        target.divide(3)
        return target.increment()

    self.checkScript(fn, ())

    def gn() -> int:
        # The constructor is called with an explicit argument here.
        target = torch.classes._TorchScriptTesting._DefaultArgs(5)
        target.increment(3)
        target.increment()
        target.decrement(2)
        target.divide()
        target.scale_add(3)
        target.scale_add(3, 2)
        target.divide(2)
        return target.decrement()

    self.checkScript(gn, ())
|