Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 and adds a test.

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether or not `super().setUpClass()` / `super().tearDownClass()` are called.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`. Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo` but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) One reason this matters is that it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated `TestFooCUDA` does not derive from `TestFoo`, but instead derives from the empty class described above, this syntax does not work; in fact, there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests: 070f389745/test/test_ops.py (L218-L221)
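For illustration, here is a minimal, self-contained sketch of the scheme this PR lands on (described further below); it is *not* PyTorch's actual `instantiate_device_type_tests()` implementation, and `make_device_variant` is a made-up name. The point is that with the per-device class deriving directly from the generic class, a zero-argument `super().setUpClass()` written in the generic class resolves correctly from every variant:

```python
import unittest


class TestFoo(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Custom setup that chains to the parent -- the pattern from (1).
        super().setUpClass()
        cls.common_state = "shared-setup"

    def test_bar(self):
        self.assertEqual(self.common_state, "shared-setup")


def make_device_variant(generic_cls, device):
    # Derive the per-device class *directly* from the generic class.
    variant = type(f"{generic_cls.__name__}{device.upper()}", (generic_cls,), {})
    for name in [n for n in dir(generic_cls) if n.startswith("test_")]:
        # Install the device-suffixed variant, e.g. test_bar -> test_bar_cpu.
        setattr(variant, f"{name}_{device}", getattr(generic_cls, name))
        # delattr() the generic test so only the suffixed one is discovered.
        delattr(generic_cls, name)
    return variant


TestFooCPU = make_device_variant(TestFoo, "cpu")

if __name__ == "__main__":
    # Discovers only TestFooCPU.test_bar_cpu; TestFoo.test_bar is gone.
    # (The real logic also deletes TestFoo from the enclosing scope.)
    unittest.main()
```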
(2) Further, there is some precedent within custom `setUpClass()` implementations for storing things on the `cls` object to be accessed at test time. For this to work, `cls` must be the device-specific test class (`TestFooCUDA`) and not `TestFoo`. As an example, the open device registration tests load a module during setup and use it in the test logic: 070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77) and 070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)
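A hedged sketch of pattern (2), with placeholder names (the real test loads a compiled C++ extension; the dict below is just a stand-in): state stashed on `cls` during `setUpClass()` has to land on the device-specific class for the test body to see it.

```python
import unittest


class TestOpenRegistration(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Expensive one-time work, e.g. building/loading an extension module.
        # Stored on cls -- which, post-fix, is the device-specific class.
        cls.module = {"custom_device_count": 2}  # stand-in for a real module

    def test_device_count(self):
        # Read back at test time off the same (device-specific) class.
        self.assertEqual(self.module["custom_device_count"], 2)
```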
To accomplish both (1) and (2) at the same time, I decided to revisit the idea of using a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have `TestFooCPU` / `TestFooCUDA` **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as that was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been set up with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong, of course).

**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here: 4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)
To address this, I moved this block of logic to before the first call to `instantiate_test()`, since that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). The goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work, and the behavior is deterministic now AFAICT. A toy illustration of the ordering hazard follows.
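The sketch below is a toy illustration of the ordering hazard, not the real `common_device_type.py` logic; `probe_primary_device` and `install_custom_setup` are hypothetical names. If something invokes `setUpClass()` before the custom wrapper is swapped in, the wrapper is silently skipped for that call, which is why the fix installs it before `instantiate_test()` runs.

```python
class DeviceTest:
    @classmethod
    def setUpClass(cls):
        print("base setUpClass")


def probe_primary_device(test_cls):
    # Stand-in for primary device identification, which may manually
    # invoke setUpClass().
    test_cls.setUpClass()


def install_custom_setup(test_cls):
    original = test_cls.setUpClass  # already bound to test_cls

    @classmethod
    def wrapped(cls):
        print("custom pre-setup")  # e.g. device-specific bookkeeping
        original()

    test_cls.setUpClass = wrapped


# Wrong order: the probe runs the *original* setUpClass only.
probe_primary_device(DeviceTest)   # prints just "base setUpClass"

# Fixed order: install the wrapper first, then probe.
install_custom_setup(DeviceTest)
probe_primary_device(DeviceTest)   # prints both lines
```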
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet

Committed by: PyTorch MergeBot
Parent: 067a7b1d4a
Commit: ae53510b9e
```diff
@@ -125,12 +125,12 @@ class TestLinalg(TestCase):
         del os.environ["HIPBLASLT_ALLOW_TF32"]
 
     def setUp(self):
-        super(self.__class__, self).setUp()
+        super().setUp()
         torch.backends.cuda.matmul.allow_tf32 = False
 
     def tearDown(self):
         torch.backends.cuda.matmul.allow_tf32 = True
-        super(self.__class__, self).tearDown()
+        super().tearDown()
 
     @contextlib.contextmanager
     def _tunableop_ctx(self):
```
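As context for why the hunk above is needed: `super(self.__class__, self)` pins the starting point of the MRO lookup to the instance's *runtime* class, so once `TestLinalgCPU` / `TestLinalgCUDA` genuinely derive from `TestLinalg`, an inherited `setUp()` re-dispatches to itself forever. A minimal repro of the pitfall (toy class names, not the real test classes):

```python
class Generic:
    def setUp(self):
        # Buggy: self.__class__ is the *most-derived* class, so from a
        # subclass instance this resolves right back to Generic.setUp().
        super(self.__class__, self).setUp()
        print("Generic.setUp")


class DeviceVariant(Generic):
    pass


try:
    DeviceVariant().setUp()
except RecursionError:
    print("infinite recursion: use zero-argument super() instead")
```

Zero-argument `super()` binds to the class where the method is textually defined (`Generic` here, `TestLinalg` in the diff), so the fixed `super().setUp()` terminates no matter which subclass it is called from.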