mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 12:44:37 +08:00
In a trunk failure today, we saw the same test running on both trunk and slow shards. The reason is that this test didn't invoke `super().setUp()`, so test features like slow and disabled tests didn't apply to it. I used Claude to find all test classes whose `setUp()` method didn't call `super().setUp()` and patched all of them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167163 Approved by: https://github.com/malfet
59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
# Owner(s): ["module: unknown"]
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
from backend import get_custom_backend_library_path, Model, to_custom_backend
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class TestCustomBackend(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
# Load the library containing the custom backend.
|
|
self.library_path = get_custom_backend_library_path()
|
|
torch.ops.load_library(self.library_path)
|
|
# Create an instance of the test Module and lower it for
|
|
# the custom backend.
|
|
self.model = to_custom_backend(torch.jit.script(Model()))
|
|
|
|
def test_execute(self):
|
|
"""
|
|
Test execution using the custom backend.
|
|
"""
|
|
a = torch.randn(4)
|
|
b = torch.randn(4)
|
|
# The custom backend is hardcoded to compute f(a, b) = (a + b, a - b).
|
|
expected = (a + b, a - b)
|
|
out = self.model(a, b)
|
|
self.assertTrue(expected[0].allclose(out[0]))
|
|
self.assertTrue(expected[1].allclose(out[1]))
|
|
|
|
def test_save_load(self):
|
|
"""
|
|
Test that a lowered module can be executed correctly
|
|
after saving and loading.
|
|
"""
|
|
# Test execution before saving and loading to make sure
|
|
# the lowered module works in the first place.
|
|
self.test_execute()
|
|
|
|
# Save and load.
|
|
f = tempfile.NamedTemporaryFile(delete=False)
|
|
try:
|
|
f.close()
|
|
torch.jit.save(self.model, f.name)
|
|
loaded = torch.jit.load(f.name)
|
|
finally:
|
|
os.unlink(f.name)
|
|
self.model = loaded
|
|
|
|
# Test execution again.
|
|
self.test_execute()
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to PyTorch's common test runner.
    run_tests()
|