mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
update get_default_device to also respect torch.device ctx manager (#148621)
Fixes https://github.com/pytorch/pytorch/issues/131328 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148621 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
db491825e0
commit
694028f502
@ -1159,14 +1159,32 @@ def get_default_device() -> "torch.device":
|
||||
r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
|
||||
global _GLOBAL_DEVICE_CONTEXT
|
||||
|
||||
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
|
||||
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
|
||||
from torch.overrides import _get_current_function_mode_stack
|
||||
from torch.utils._device import DeviceContext
|
||||
|
||||
def _get_device_with_index(device):
|
||||
if device.index is not None:
|
||||
return device
|
||||
else:
|
||||
# TODO: Call like get_device_index() method corresponding to
|
||||
# each device type
|
||||
return torch.tensor([]).device
|
||||
|
||||
# Get device from any active DeviceContext.
|
||||
device_mode = next(
|
||||
filter(
|
||||
lambda mode: isinstance(mode, DeviceContext),
|
||||
reversed(_get_current_function_mode_stack()),
|
||||
),
|
||||
None,
|
||||
)
|
||||
if device_mode:
|
||||
device = device_mode.device
|
||||
return _get_device_with_index(device)
|
||||
|
||||
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
|
||||
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
|
||||
return _get_device_with_index(device)
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
Reference in New Issue
Block a user