mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
[JIT] Add out-of-source-tree to_backend tests (#41145)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41145 **Summary** This commit adds out-of-source-tree tests for `to_backend`. These tests check that a Module can be lowered to a backend, exported, loaded (in both Python and C++) and executed. **Fixes** This commit fixes #40067. Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D22510076 Pulled By: SplitInfinity fbshipit-source-id: f65964ef3092a095740f06636ed5b1eb0884492d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0e7b9d4ff8
commit
4972cf06a2
54
test/custom_backend/test_custom_backend.py
Normal file
54
test/custom_backend/test_custom_backend.py
Normal file
@ -0,0 +1,54 @@
|
||||
import os
|
||||
import tempfile
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from backend import Model, to_custom_backend, get_custom_backend_library_path
|
||||
|
||||
|
||||
class TestCustomBackend(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Load the library containing the custom backend.
|
||||
self.library_path = get_custom_backend_library_path()
|
||||
torch.ops.load_library(self.library_path)
|
||||
# Create an instance of the test Module and lower it for
|
||||
# the custom backend.
|
||||
self.model = to_custom_backend(torch.jit.script(Model()))
|
||||
|
||||
def test_execute(self):
|
||||
"""
|
||||
Test execution using the custom backend.
|
||||
"""
|
||||
a = torch.randn(4)
|
||||
b = torch.randn(4)
|
||||
# The custom backend is hardcoded to compute f(a, b) = (a + b, a - b).
|
||||
expected = (a + b, a - b)
|
||||
out = self.model(a, b)
|
||||
self.assertTrue(expected[0].allclose(out[0]))
|
||||
self.assertTrue(expected[1].allclose(out[1]))
|
||||
|
||||
def test_save_load(self):
|
||||
"""
|
||||
Test that a lowered module can be executed correctly
|
||||
after saving and loading.
|
||||
"""
|
||||
# Test execution before saving and loading to make sure
|
||||
# the lowered module works in the first place.
|
||||
self.test_execute()
|
||||
|
||||
# Save and load.
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
f.close()
|
||||
torch.jit.save(self.model, f.name)
|
||||
loaded = torch.jit.load(f.name)
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
self.model = loaded
|
||||
|
||||
# Test execution again.
|
||||
self.test_execute()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user