diff --git a/docs/source/notes/mps.rst b/docs/source/notes/mps.rst index 6ad44ba97714..cc2a781b590e 100644 --- a/docs/source/notes/mps.rst +++ b/docs/source/notes/mps.rst @@ -14,27 +14,31 @@ capabilities to setup and run operations on GPU. To get started, simply move your Tensor and Module to the ``mps`` device: -.. code:: +.. code:: python - # Make sure the current PyTorch binary was built with MPS enabled - print(torch.backends.mps.is_built()) - # And that the current hardware and MacOS version are sufficient to - # be able to use MPS - print(torch.backends.mps.is_available()) + # Check that MPS is available + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + print("MPS not available because the current PyTorch install was not " + "built with MPS enabled.") + else: + print("MPS not available because the current MacOS version is not 12.3+ " + "and/or you do not have an MPS-enabled device on this machine.") - mps_device = torch.device("mps") + else: + mps_device = torch.device("mps") - # Create a Tensor directly on the mps device - x = torch.ones(5, device=mps_device) - # Or - x = torch.ones(5, device="mps") + # Create a Tensor directly on the mps device + x = torch.ones(5, device=mps_device) + # Or + x = torch.ones(5, device="mps") - # Any operation happens on the GPU - y = x * 2 + # Any operation happens on the GPU + y = x * 2 - # Move your model to mps just like any other device - model = YourFavoriteNet() - model.to(mps_device) + # Move your model to mps just like any other device + model = YourFavoriteNet() + model.to(mps_device) - # Now every call runs on the GPU - pred = model(x) + # Now every call runs on the GPU + pred = model(x)