update get_default_device to also respect torch.device ctx manager (#148621)

Fixes https://github.com/pytorch/pytorch/issues/131328
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148621
Approved by: https://github.com/ezyang
This commit is contained in:
Kshiteej K
2025-06-07 14:26:17 +00:00
committed by PyTorch MergeBot
parent db491825e0
commit 694028f502
2 changed files with 45 additions and 12 deletions

View File

@ -1159,14 +1159,32 @@ def get_default_device() -> "torch.device":
r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
global _GLOBAL_DEVICE_CONTEXT
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
from torch.overrides import _get_current_function_mode_stack
from torch.utils._device import DeviceContext
def _get_device_with_index(device):
if device.index is not None:
return device
else:
# TODO: Call like get_device_index() method corresponding to
# each device type
return torch.tensor([]).device
# Get device from any active DeviceContext.
device_mode = next(
filter(
lambda mode: isinstance(mode, DeviceContext),
reversed(_get_current_function_mode_stack()),
),
None,
)
if device_mode:
device = device_mode.device
return _get_device_with_index(device)
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
return _get_device_with_index(device)
else:
return torch.device("cpu")