mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
update get_default_device to also respect torch.device ctx manager (#148621)
Fixes https://github.com/pytorch/pytorch/issues/131328 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148621 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
db491825e0
commit
694028f502
@ -1060,18 +1060,33 @@ class TestDeviceUtils(TestCase):
|
|||||||
|
|
||||||
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
|
||||||
def test_get_default_device_more(self):
|
def test_get_default_device_more(self):
|
||||||
torch.set_default_device("cuda")
|
try:
|
||||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_device(None)
|
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||||
|
torch.set_default_device(None)
|
||||||
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.cuda.set_device("cuda:1")
|
torch.cuda.set_device("cuda:1")
|
||||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||||
torch.set_default_device(None)
|
torch.set_default_device(None)
|
||||||
|
|
||||||
torch.set_default_device("cuda:1")
|
torch.set_default_device("cuda:1")
|
||||||
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
|
||||||
torch.set_default_device(None)
|
torch.set_default_device(None)
|
||||||
|
|
||||||
|
torch.set_default_device("cuda:1")
|
||||||
|
with torch.device("cuda:0"):
|
||||||
|
self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
|
||||||
|
|
||||||
|
torch.set_default_device("cpu")
|
||||||
|
self.assertEqual(torch.get_default_device(), torch.device("cpu"))
|
||||||
|
with torch.device("cuda:0"):
|
||||||
|
self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
|
||||||
|
|
||||||
|
self.assertEqual(torch.get_default_device(), torch.device("cpu"))
|
||||||
|
finally:
|
||||||
|
# Reset the device at the end.
|
||||||
|
torch.set_default_device(None)
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
@ops(op_db)
|
@ops(op_db)
|
||||||
|
@ -1159,14 +1159,32 @@ def get_default_device() -> "torch.device":
|
|||||||
r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
|
r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
|
||||||
global _GLOBAL_DEVICE_CONTEXT
|
global _GLOBAL_DEVICE_CONTEXT
|
||||||
|
|
||||||
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
|
from torch.overrides import _get_current_function_mode_stack
|
||||||
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
|
from torch.utils._device import DeviceContext
|
||||||
|
|
||||||
|
def _get_device_with_index(device):
|
||||||
if device.index is not None:
|
if device.index is not None:
|
||||||
return device
|
return device
|
||||||
else:
|
else:
|
||||||
# TODO: Call like get_device_index() method corresponding to
|
# TODO: Call like get_device_index() method corresponding to
|
||||||
# each device type
|
# each device type
|
||||||
return torch.tensor([]).device
|
return torch.tensor([]).device
|
||||||
|
|
||||||
|
# Get device from any active DeviceContext.
|
||||||
|
device_mode = next(
|
||||||
|
filter(
|
||||||
|
lambda mode: isinstance(mode, DeviceContext),
|
||||||
|
reversed(_get_current_function_mode_stack()),
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if device_mode:
|
||||||
|
device = device_mode.device
|
||||||
|
return _get_device_with_index(device)
|
||||||
|
|
||||||
|
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
|
||||||
|
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
|
||||||
|
return _get_device_with_index(device)
|
||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user