[JIT] Add out-of-source-tree to_backend tests (#41145)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41145

**Summary**
This commit adds out-of-source-tree tests for `to_backend`. These tests check
that a Module can be lowered to a backend, exported, loaded (in both
Python and C++) and executed.

**Fixes**
This commit fixes #40067.

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D22510076

Pulled By: SplitInfinity

fbshipit-source-id: f65964ef3092a095740f06636ed5b1eb0884492d
This commit is contained in:
Meghan Lele
2020-07-14 10:53:58 -07:00
committed by Facebook GitHub Bot
parent 0e7b9d4ff8
commit 4972cf06a2
12 changed files with 366 additions and 1 deletions

View File

@ -0,0 +1,54 @@
import os
import tempfile
import torch
import unittest
from backend import Model, to_custom_backend, get_custom_backend_library_path
class TestCustomBackend(unittest.TestCase):
def setUp(self):
# Load the library containing the custom backend.
self.library_path = get_custom_backend_library_path()
torch.ops.load_library(self.library_path)
# Create an instance of the test Module and lower it for
# the custom backend.
self.model = to_custom_backend(torch.jit.script(Model()))
def test_execute(self):
"""
Test execution using the custom backend.
"""
a = torch.randn(4)
b = torch.randn(4)
# The custom backend is hardcoded to compute f(a, b) = (a + b, a - b).
expected = (a + b, a - b)
out = self.model(a, b)
self.assertTrue(expected[0].allclose(out[0]))
self.assertTrue(expected[1].allclose(out[1]))
def test_save_load(self):
"""
Test that a lowered module can be executed correctly
after saving and loading.
"""
# Test execution before saving and loading to make sure
# the lowered module works in the first place.
self.test_execute()
# Save and load.
f = tempfile.NamedTemporaryFile(delete=False)
try:
f.close()
torch.jit.save(self.model, f.name)
loaded = torch.jit.load(f.name)
finally:
os.unlink(f.name)
self.model = loaded
# Test execution again.
self.test_execute()
if __name__ == "__main__":
unittest.main()