Conform torch.mps to device module interface (#124676)

Right now `torch.fork_rng()` doesn't support MPS because MPS's device module functions don't line up with the other backends'. One step of `fork_rng` calls `device_count()`:

302d7e9a6e/torch/random.py (L146)

The MPS device count is simple to determine: 1 if MPS is built and available, otherwise 0.

Also:

302d7e9a6e/torch/random.py (L168)

302d7e9a6e/torch/random.py (L175)

`get_rng_state` and `set_rng_state` are expected to be able to accept a `device` parameter.

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124676
Approved by: https://github.com/ezyang
This commit is contained in:
Matthew Hoffman
2024-04-23 18:38:44 +00:00
committed by PyTorch MergeBot
parent 4e66aaa010
commit 1d3a13d3d1
2 changed files with 22 additions and 4 deletions

View File

@ -7,6 +7,7 @@ torch.mps
:toctree: generated
:nosignatures:
device_count
synchronize
get_rng_state
set_rng_state

View File

@ -4,6 +4,8 @@ Metal is Apple's API for programming metal GPU (graphics processor unit). Using
performance can be achieved, by running work on the metal GPU(s).
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
"""
from typing import Union
import torch
from .. import Tensor
@ -19,21 +21,35 @@ def _get_default_mps_generator() -> torch._C.Generator:
return _default_mps_generator
def device_count() -> int:
r"""Returns the number of available MPS devices."""
return int(torch._C._has_mps and torch._C._mps_is_available())
def synchronize() -> None:
    r"""Wait for all kernels in all streams on the MPS device to complete.

    Thin wrapper over the C binding; blocks until outstanding MPS work is done.
    """
    return torch._C._mps_deviceSynchronize()
def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor:
    r"""Return the random number generator state as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG
            state of. Default: ``'mps'`` (i.e., ``torch.device('mps')``, the
            current MPS device).

    .. note:: The ``device`` argument is accepted for conformity with the
        device-module interface (e.g. ``torch.fork_rng``); there is only one
        MPS device, so it is not otherwise used here.
    """
    return _get_default_mps_generator().get_state()
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "mps"
) -> None:
    r"""Set the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device or int, optional): The device to set the RNG
            state of. Default: ``'mps'`` (i.e., ``torch.device('mps')``, the
            current MPS device).

    .. note:: The ``device`` argument is accepted for conformity with the
        device-module interface (e.g. ``torch.fork_rng``); there is only one
        MPS device, so it is not otherwise used here.
    """
    # Clone defensively so later caller-side mutation of ``new_state`` cannot
    # corrupt the generator; a contiguous copy is required by ``set_state``.
    new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    _get_default_mps_generator().set_state(new_state_copy)
@ -116,6 +132,7 @@ from . import profiler
from .event import Event
__all__ = [
"device_count",
"get_rng_state",
"manual_seed",
"seed",