mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Typing] Improve device typing for torch.set_default_device()
(#153028)
Part of: #152952
Here is the definition of `torch.types.Device`:
ab997d9ff5/torch/types.py (L74)
So `_Optional[_Union["torch.device", str, builtins.int]]` is equivalent to it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153028
Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
dd7d231ed3
commit
f5f8f637a5
@ -36,7 +36,7 @@ from typing_extensions import ParamSpec as _ParamSpec
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types import IntLikeType
|
||||
from .types import Device, IntLikeType
|
||||
|
||||
|
||||
# multipy/deploy is setting this import before importing torch, this is the most
|
||||
@ -1154,9 +1154,7 @@ def get_default_device() -> "torch.device":
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def set_default_device(
|
||||
device: _Optional[_Union["torch.device", str, builtins.int]],
|
||||
) -> None:
|
||||
def set_default_device(device: "Device") -> None:
|
||||
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
|
||||
does not affect factory function calls which are called with an explicit
|
||||
``device`` argument. Factory calls will be performed as if they
|
||||
|
Reference in New Issue
Block a user