Replace _device_t with torch.types.Device in torch/cpu/__init__.py (#161031)

Fixes #152952

Replace `_device_t` with `torch.types.Device` in `torch/cpu/__init__.py`. Did basic smoke test by running tests that `import torch.cpu` including `test/distributed/test_c10d_functional_native.py` and `test/test_decomp.py`.

Based this PR off of #152935 which is referenced in the main issue.

(also, this is my first contribution but I followed the contributing guide closely)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161031
Approved by: https://github.com/janeyx99
This commit is contained in:
Grant
2025-08-21 00:22:40 +00:00
committed by PyTorch MergeBot
parent be87f22dfb
commit 54c2b66592

View File

@ -27,8 +27,6 @@ __all__ = [
"Event",
]
_device_t = Union[_device, str, int, None]
def _is_avx2_supported() -> bool:
r"""Returns a bool indicating if CPU supports AVX2."""
@ -75,7 +73,7 @@ def is_available() -> bool:
return True
def synchronize(device: _device_t = None) -> None:
def synchronize(device: torch.types.Device = None) -> None:
r"""Waits for all kernels in all streams on the CPU device to complete.
Args:
@ -121,7 +119,7 @@ _default_cpu_stream = Stream()
_current_stream = _default_cpu_stream
def current_stream(device: _device_t = None) -> Stream:
def current_stream(device: torch.types.Device = None) -> Stream:
r"""Returns the currently selected :class:`Stream` for a given device.
Args:
@ -181,7 +179,7 @@ def device_count() -> int:
return 1
def set_device(device: _device_t) -> None:
def set_device(device: torch.types.Device) -> None:
r"""Sets the current device, in CPU we do nothing.
N.B. This function only exists to facilitate device-agnostic code