mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Conform torch.mps to device module interface (#124676)
Right now `torch.fork_rng()` doesn't support MPS. MPS' device module functions don't line up with the others'. There is a step of `fork_rng` to call `device_count()`:302d7e9a6e/torch/random.py (L146)
It is pretty simple to know the MPS device count, based on whether it is built and available. Also:302d7e9a6e/torch/random.py (L168)
302d7e9a6e/torch/random.py (L175)
`get_rng_state` and `set_rng_state` are expected to be able to accept a `device` parameter. @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/124676 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
4e66aaa010
commit
1d3a13d3d1
@ -7,6 +7,7 @@ torch.mps
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
device_count
|
||||
synchronize
|
||||
get_rng_state
|
||||
set_rng_state
|
||||
|
@ -4,6 +4,8 @@ Metal is Apple's API for programming metal GPU (graphics processor unit). Using
|
||||
performance can be achieved, by running work on the metal GPU(s).
|
||||
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
|
||||
"""
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from .. import Tensor
|
||||
|
||||
@ -19,21 +21,35 @@ def _get_default_mps_generator() -> torch._C.Generator:
|
||||
return _default_mps_generator
|
||||
|
||||
|
||||
def device_count() -> int:
|
||||
r"""Returns the number of available MPS devices."""
|
||||
return int(torch._C._has_mps and torch._C._mps_is_available())
|
||||
|
||||
|
||||
def synchronize() -> None:
|
||||
r"""Waits for all kernels in all streams on a MPS device to complete."""
|
||||
return torch._C._mps_deviceSynchronize()
|
||||
|
||||
|
||||
def get_rng_state() -> Tensor:
|
||||
r"""Returns the random number generator state as a ByteTensor."""
|
||||
def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor:
|
||||
r"""Returns the random number generator state as a ByteTensor.
|
||||
|
||||
Args:
|
||||
device (torch.device or int, optional): The device to return the RNG state of.
|
||||
Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
|
||||
"""
|
||||
return _get_default_mps_generator().get_state()
|
||||
|
||||
|
||||
def set_rng_state(new_state: Tensor) -> None:
|
||||
def set_rng_state(
|
||||
new_state: Tensor, device: Union[int, str, torch.device] = "mps"
|
||||
) -> None:
|
||||
r"""Sets the random number generator state.
|
||||
|
||||
Args:
|
||||
new_state (torch.ByteTensor): The desired state
|
||||
device (torch.device or int, optional): The device to set the RNG state.
|
||||
Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
|
||||
"""
|
||||
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
|
||||
_get_default_mps_generator().set_state(new_state_copy)
|
||||
@ -116,6 +132,7 @@ from . import profiler
|
||||
from .event import Event
|
||||
|
||||
__all__ = [
|
||||
"device_count",
|
||||
"get_rng_state",
|
||||
"manual_seed",
|
||||
"seed",
|
||||
|
Reference in New Issue
Block a user