mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
update get_default_device to also respect torch.device ctx manager (#148621)
Fixes https://github.com/pytorch/pytorch/issues/131328 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148621 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
db491825e0
commit
694028f502
@ -1060,18 +1060,33 @@ class TestDeviceUtils(TestCase):
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
|
||||
def test_get_default_device_more(self):
|
||||
torch.set_default_device("cuda")
|
||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||
torch.set_default_device(None)
|
||||
try:
|
||||
torch.set_default_device("cuda")
|
||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||
torch.set_default_device(None)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.set_device("cuda:1")
|
||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||
torch.set_default_device(None)
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.set_device("cuda:1")
|
||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||
torch.set_default_device(None)
|
||||
|
||||
torch.set_default_device("cuda:1")
|
||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||
torch.set_default_device(None)
|
||||
torch.set_default_device("cuda:1")
|
||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||
torch.set_default_device(None)
|
||||
|
||||
torch.set_default_device("cuda:1")
|
||||
with torch.device("cuda:0"):
|
||||
self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
|
||||
|
||||
torch.set_default_device("cpu")
|
||||
self.assertEqual(torch.get_default_device(), torch.device("cpu"))
|
||||
with torch.device("cuda:0"):
|
||||
self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
|
||||
|
||||
self.assertEqual(torch.get_default_device(), torch.device("cpu"))
|
||||
finally:
|
||||
# Reset the device at the end.
|
||||
torch.set_default_device(None)
|
||||
|
||||
@onlyCPU
|
||||
@ops(op_db)
|
||||
|
||||
@ -1159,14 +1159,32 @@ def get_default_device() -> "torch.device":
|
||||
r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
|
||||
global _GLOBAL_DEVICE_CONTEXT
|
||||
|
||||
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
|
||||
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
|
||||
from torch.overrides import _get_current_function_mode_stack
|
||||
from torch.utils._device import DeviceContext
|
||||
|
||||
def _get_device_with_index(device):
|
||||
if device.index is not None:
|
||||
return device
|
||||
else:
|
||||
# TODO: Call like get_device_index() method corresponding to
|
||||
# each device type
|
||||
return torch.tensor([]).device
|
||||
|
||||
# Get device from any active DeviceContext.
|
||||
device_mode = next(
|
||||
filter(
|
||||
lambda mode: isinstance(mode, DeviceContext),
|
||||
reversed(_get_current_function_mode_stack()),
|
||||
),
|
||||
None,
|
||||
)
|
||||
if device_mode:
|
||||
device = device_mode.device
|
||||
return _get_device_with_index(device)
|
||||
|
||||
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
|
||||
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
|
||||
return _get_device_with_index(device)
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user