mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Replace _device_t with torch.types.Device in torch/cpu/__init__.py (#161031)
Fixes #152952 Replace `_device_t` with `torch.types.Device` in `torch/cpu/__init__.py`. Did basic smoke test by running tests that `import torch.cpu` including `test/distributed/test_c10d_functional_native.py` and `test/test_decomp.py`. Based this PR off of #152935 which is referenced in the main issue. (also, this is my first contribution but I followed the contributing guide closely) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161031 Approved by: https://github.com/janeyx99
This commit is contained in:
@ -27,8 +27,6 @@ __all__ = [
|
||||
"Event",
|
||||
]
|
||||
|
||||
_device_t = Union[_device, str, int, None]
|
||||
|
||||
|
||||
def _is_avx2_supported() -> bool:
|
||||
r"""Returns a bool indicating if CPU supports AVX2."""
|
||||
@ -75,7 +73,7 @@ def is_available() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def synchronize(device: _device_t = None) -> None:
|
||||
def synchronize(device: torch.types.Device = None) -> None:
|
||||
r"""Waits for all kernels in all streams on the CPU device to complete.
|
||||
|
||||
Args:
|
||||
@ -121,7 +119,7 @@ _default_cpu_stream = Stream()
|
||||
_current_stream = _default_cpu_stream
|
||||
|
||||
|
||||
def current_stream(device: _device_t = None) -> Stream:
|
||||
def current_stream(device: torch.types.Device = None) -> Stream:
|
||||
r"""Returns the currently selected :class:`Stream` for a given device.
|
||||
|
||||
Args:
|
||||
@ -181,7 +179,7 @@ def device_count() -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def set_device(device: _device_t) -> None:
|
||||
def set_device(device: torch.types.Device) -> None:
|
||||
r"""Sets the current device, in CPU we do nothing.
|
||||
|
||||
N.B. This function only exists to facilitate device-agnostic code
|
||||
|
Reference in New Issue
Block a user