update get_default_device to also respect torch.device ctx manager (#148621)

Fixes https://github.com/pytorch/pytorch/issues/131328
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148621
Approved by: https://github.com/ezyang
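
In short: torch.get_default_device() previously consulted only the global default
installed via torch.set_default_device and ignored an enclosing torch.device(...)
context manager; with this change the context manager is respected. A minimal
sketch of the new behaviour (assumes a CUDA-enabled build; mirrors the updated
test below):

    import torch

    with torch.device("cuda:0"):
        # Previously this returned the global default (cpu by default);
        # it now reflects the innermost active device context manager.
        assert torch.get_default_device() == torch.device("cuda", 0)

    # Outside the context manager the global default still applies.
    assert torch.get_default_device() == torch.device("cpu")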
Author:    Kshiteej K
Date:      2025-06-07 14:26:17 +00:00
Committer: PyTorch MergeBot
Parent:    db491825e0
Commit:    694028f502

2 changed files with 45 additions and 12 deletions


@@ -1060,18 +1060,33 @@ class TestDeviceUtils(TestCase):
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_get_default_device_more(self):
-        torch.set_default_device("cuda")
-        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
-        torch.set_default_device(None)
+        try:
+            torch.set_default_device("cuda")
+            self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
+            torch.set_default_device(None)
 
-        torch.set_default_device("cuda")
-        torch.cuda.set_device("cuda:1")
-        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
-        torch.set_default_device(None)
+            torch.set_default_device("cuda")
+            torch.cuda.set_device("cuda:1")
+            self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
+            torch.set_default_device(None)
 
-        torch.set_default_device("cuda:1")
-        self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
-        torch.set_default_device(None)
+            torch.set_default_device("cuda:1")
+            self.assertEqual(torch.get_default_device(), torch.tensor([]).device)
+            torch.set_default_device(None)
+
+            torch.set_default_device("cuda:1")
+            with torch.device("cuda:0"):
+                self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
+
+            torch.set_default_device("cpu")
+            self.assertEqual(torch.get_default_device(), torch.device("cpu"))
+            with torch.device("cuda:0"):
+                self.assertEqual(torch.get_default_device(), torch.device("cuda", 0))
+
+            self.assertEqual(torch.get_default_device(), torch.device("cpu"))
+        finally:
+            # Reset the device at the end.
+            torch.set_default_device(None)
 
     @onlyCPU
     @ops(op_db)


@@ -1159,14 +1159,32 @@ def get_default_device() -> "torch.device":
     r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
     global _GLOBAL_DEVICE_CONTEXT
 
-    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
-        device = _GLOBAL_DEVICE_CONTEXT.device_context.device
+    from torch.overrides import _get_current_function_mode_stack
+    from torch.utils._device import DeviceContext
+
+    def _get_device_with_index(device):
         if device.index is not None:
             return device
         else:
             # TODO: Call like get_device_index() method corresponding to
             # each device type
             return torch.tensor([]).device
+
+    # Get device from any active DeviceContext.
+    device_mode = next(
+        filter(
+            lambda mode: isinstance(mode, DeviceContext),
+            reversed(_get_current_function_mode_stack()),
+        ),
+        None,
+    )
+    if device_mode:
+        device = device_mode.device
+        return _get_device_with_index(device)
+
+    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
+        device = _GLOBAL_DEVICE_CONTEXT.device_context.device
+        return _get_device_with_index(device)
     else:
         return torch.device("cpu")
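
In effect, the lookup order in get_default_device() is now: any active
DeviceContext on the torch-function mode stack (the mechanism behind the
torch.device(...) context manager), innermost first; then the global default
recorded by torch.set_default_device; and finally cpu. A short usage sketch
mirroring the test above (assumes a CUDA build with at least two GPUs):

    import torch

    torch.set_default_device("cuda:1")
    assert torch.get_default_device() == torch.device("cuda", 1)

    with torch.device("cuda:0"):
        # The context manager takes precedence over the global default.
        assert torch.get_default_device() == torch.device("cuda", 0)

    torch.set_default_device(None)  # restore the built-in default
    assert torch.get_default_device() == torch.device("cpu")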