[Typing] Improve device typing for torch.set_default_device() (#153028)

Part of: #152952

Here is the definition of `torch.types.Device`:

ab997d9ff5/torch/types.py (L74)

So `_Optional[_Union["torch.device", str, builtins.int]]` is equivalent to it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153028
Approved by: https://github.com/Skylion007
This commit is contained in:
Yuanhao Ji
2025-05-07 19:31:39 +00:00
committed by PyTorch MergeBot
parent dd7d231ed3
commit f5f8f637a5

View File

@ -36,7 +36,7 @@ from typing_extensions import ParamSpec as _ParamSpec
if TYPE_CHECKING:
from .types import IntLikeType
from .types import Device, IntLikeType
# multipy/deploy is setting this import before importing torch, this is the most
@ -1154,9 +1154,7 @@ def get_default_device() -> "torch.device":
return torch.device("cpu")
def set_default_device(
device: _Optional[_Union["torch.device", str, builtins.int]],
) -> None:
def set_default_device(device: "Device") -> None:
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
does not affect factory function calls which are called with an explicit
``device`` argument. Factory calls will be performed as if they