mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Currently, the `device` context manager from the device interface is used in inductor, although only in one place. This PR creates an inductor-specific `DeviceGuard` class for use in these cases, which keeps a reference to the `DeviceInterface` class, which is defined and added out of tree. This then offloads the device-specific work to the device interface, instead of having to define this logic on the device class, which isn't strictly necessary for inductor. Ideally I would have used the existing `DeviceGuard` class, but these are defined per device and don't work well with inductor's device-agnostic / out-of-tree-compatible design. With the existing classes in mind, I am happy to take suggestions on the renaming of this class. Whilst I was there, I also took the opportunity to rename `gpu_device` to `device_interface` to clarify this is not necessarily a GPU. Test Plan: None currently, happy to add some. Co-authored-by: Matthew Haddock <matthewha@graphcore.ai> Co-authored-by: Adnan Akhundov <adnan.akhundov@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/123338 Approved by: https://github.com/aakhundov
93 lines
3.0 KiB
Python
93 lines
3.0 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import unittest
|
|
from unittest.mock import Mock
|
|
|
|
import torch
|
|
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
from torch._dynamo.device_interface import CudaInterface, DeviceGuard
|
|
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
|
|
|
|
|
|
class TestDeviceGuard(torch._dynamo.test_case.TestCase):
|
|
"""
|
|
Unit tests for the DeviceGuard class using a mock DeviceInterface.
|
|
"""
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.device_interface = Mock()
|
|
|
|
self.device_interface.exchange_device = Mock(return_value=0)
|
|
self.device_interface.maybe_exchange_device = Mock(return_value=0)
|
|
|
|
def test_device_guard(self):
|
|
device_guard = DeviceGuard(self.device_interface, 1)
|
|
|
|
with device_guard as _:
|
|
self.device_interface.exchange_device.assert_called_once_with(1)
|
|
self.assertEqual(device_guard.prev_idx, 0)
|
|
self.assertEqual(device_guard.idx, 1)
|
|
|
|
self.device_interface.maybe_exchange_device.assert_called_once_with(0)
|
|
self.assertEqual(device_guard.prev_idx, 0)
|
|
self.assertEqual(device_guard.idx, 0)
|
|
|
|
def test_device_guard_no_index(self):
|
|
device_guard = DeviceGuard(self.device_interface, None)
|
|
|
|
with device_guard as _:
|
|
self.device_interface.exchange_device.assert_not_called()
|
|
self.assertEqual(device_guard.prev_idx, -1)
|
|
self.assertEqual(device_guard.idx, None)
|
|
|
|
self.device_interface.maybe_exchange_device.assert_not_called()
|
|
self.assertEqual(device_guard.prev_idx, -1)
|
|
self.assertEqual(device_guard.idx, None)
|
|
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "No CUDA available.")
|
|
class TestCUDADeviceGuard(torch._dynamo.test_case.TestCase):
|
|
"""
|
|
Unit tests for the DeviceGuard class using a CudaInterface.
|
|
"""
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.device_interface = CudaInterface
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU")
|
|
def test_device_guard(self):
|
|
current_device = torch.cuda.current_device()
|
|
|
|
device_guard = DeviceGuard(self.device_interface, 1)
|
|
|
|
with device_guard as _:
|
|
self.assertEqual(torch.cuda.current_device(), 1)
|
|
self.assertEqual(device_guard.prev_idx, 0)
|
|
self.assertEqual(device_guard.idx, 1)
|
|
|
|
self.assertEqual(torch.cuda.current_device(), current_device)
|
|
self.assertEqual(device_guard.prev_idx, 0)
|
|
self.assertEqual(device_guard.idx, 0)
|
|
|
|
def test_device_guard_no_index(self):
|
|
current_device = torch.cuda.current_device()
|
|
|
|
device_guard = DeviceGuard(self.device_interface, None)
|
|
|
|
with device_guard as _:
|
|
self.assertEqual(torch.cuda.current_device(), current_device)
|
|
self.assertEqual(device_guard.prev_idx, -1)
|
|
self.assertEqual(device_guard.idx, None)
|
|
|
|
self.assertEqual(device_guard.prev_idx, -1)
|
|
self.assertEqual(device_guard.idx, None)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand control to the dynamo test runner, reached
    # through the torch._dynamo.test_case module imported at the top of
    # this file.
    torch._dynamo.test_case.run_tests()
|