pytorch/test/dynamo/test_deviceguard.py
brothergomez 366b24e242 [Inductor] Add a device agnostic DeviceGuard class to inductor (#123338)
Summary: Currently, the `device` context manager from the device interface is used in only one place in inductor. This PR creates an inductor-specific `DeviceGuard` class for these cases, which keeps a reference to the `DeviceInterface` class that is defined and added out of tree. This offloads the device-specific work to the device interface, instead of having to define this logic on the device class, which isn't strictly necessary for inductor.
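
For reference, the guard behaves roughly like the sketch below, reconstructed from the behaviour exercised by the tests in this file (the actual implementation lives in `torch/_dynamo/device_interface.py`); treat it as an illustration, not the exact code:

from typing import Any, Optional, Type


class DeviceGuard:
    """Sketch of a device-agnostic guard: it delegates the actual device
    switch to the DeviceInterface it is given, so out-of-tree backends only
    need to provide exchange_device/maybe_exchange_device."""

    def __init__(self, device_interface: Type[Any], index: Optional[int]):
        self.device_interface = device_interface
        self.idx = index
        self.prev_idx = -1

    def __enter__(self):
        if self.idx is not None:
            # Switch to the requested device and remember the previous one.
            self.prev_idx = self.device_interface.exchange_device(self.idx)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        if self.idx is not None:
            # Restore the previous device; maybe_exchange_device skips the
            # switch if that device is already current.
            self.idx = self.device_interface.maybe_exchange_device(self.prev_idx)
        return False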

Ideally I would have used the existing `DeviceGuard` classes, but these are defined per device and don't work well with inductor's device-agnostic, out-of-tree-compatible design. With the existing classes in mind, I am happy to take suggestions on renaming this class.

Whilst I was there, I also took the opportunity to rename `gpu_device` to `device_interface`, to clarify that this is not necessarily a GPU.
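
For illustration, a hypothetical call site after the rename might look like the following (the specific device string and surrounding code are assumptions, not taken from the PR):

from torch._dynamo.device_interface import DeviceGuard, get_interface_for_device

# `device_interface` (formerly `gpu_device`) can be any registered backend,
# not necessarily a GPU.
device_interface = get_interface_for_device("cuda")
with DeviceGuard(device_interface, 1):
    ...  # device-specific work runs with device index 1 current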

Test Plan: None currently, happy to add some.

Co-authored-by: Matthew Haddock <matthewha@graphcore.ai>
Co-authored-by: Adnan Akhundov <adnan.akhundov@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123338
Approved by: https://github.com/aakhundov
2024-04-12 18:21:27 +00:00

93 lines · 3.0 KiB · Python

# Owner(s): ["module: dynamo"]
import unittest
from unittest.mock import Mock

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.device_interface import CudaInterface, DeviceGuard
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU


class TestDeviceGuard(torch._dynamo.test_case.TestCase):
    """
    Unit tests for the DeviceGuard class using a mock DeviceInterface.
    """

    def setUp(self):
        super().setUp()
        self.device_interface = Mock()

        self.device_interface.exchange_device = Mock(return_value=0)
        self.device_interface.maybe_exchange_device = Mock(return_value=0)

    def test_device_guard(self):
        device_guard = DeviceGuard(self.device_interface, 1)

        with device_guard as _:
            self.device_interface.exchange_device.assert_called_once_with(1)
            self.assertEqual(device_guard.prev_idx, 0)
            self.assertEqual(device_guard.idx, 1)

        self.device_interface.maybe_exchange_device.assert_called_once_with(0)
        self.assertEqual(device_guard.prev_idx, 0)
        self.assertEqual(device_guard.idx, 0)

    def test_device_guard_no_index(self):
        device_guard = DeviceGuard(self.device_interface, None)

        with device_guard as _:
            self.device_interface.exchange_device.assert_not_called()
            self.assertEqual(device_guard.prev_idx, -1)
            self.assertEqual(device_guard.idx, None)

        self.device_interface.maybe_exchange_device.assert_not_called()
        self.assertEqual(device_guard.prev_idx, -1)
        self.assertEqual(device_guard.idx, None)


@unittest.skipIf(not TEST_CUDA, "No CUDA available.")
class TestCUDADeviceGuard(torch._dynamo.test_case.TestCase):
    """
    Unit tests for the DeviceGuard class using a CudaInterface.
    """

    def setUp(self):
        super().setUp()
        self.device_interface = CudaInterface

    @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU")
    def test_device_guard(self):
        current_device = torch.cuda.current_device()

        device_guard = DeviceGuard(self.device_interface, 1)

        with device_guard as _:
            self.assertEqual(torch.cuda.current_device(), 1)
            self.assertEqual(device_guard.prev_idx, 0)
            self.assertEqual(device_guard.idx, 1)

        self.assertEqual(torch.cuda.current_device(), current_device)
        self.assertEqual(device_guard.prev_idx, 0)
        self.assertEqual(device_guard.idx, 0)

    def test_device_guard_no_index(self):
        current_device = torch.cuda.current_device()

        device_guard = DeviceGuard(self.device_interface, None)

        with device_guard as _:
            self.assertEqual(torch.cuda.current_device(), current_device)
            self.assertEqual(device_guard.prev_idx, -1)
            self.assertEqual(device_guard.idx, None)

        self.assertEqual(device_guard.prev_idx, -1)
        self.assertEqual(device_guard.idx, None)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()