update mps note with more details (#78669)

Follow up to the comments in https://github.com/pytorch/pytorch/pull/77767#pullrequestreview-978807521
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78669
Approved by: https://github.com/kulinseth, https://github.com/anjali411
This commit is contained in:
albanD
2022-06-02 20:53:19 +00:00
committed by PyTorch MergeBot
parent 3e0f1a8a32
commit b30b1f3dec

View File

@ -14,27 +14,31 @@ capabilities to setup and run operations on GPU.
To get started, simply move your Tensor and Module to the ``mps`` device: To get started, simply move your Tensor and Module to the ``mps`` device:
.. code:: .. code:: python
# Make sure the current PyTorch binary was built with MPS enabled # Check that MPS is available
print(torch.backends.mps.is_built()) if not torch.backends.mps.is_available():
# And that the current hardware and MacOS version are sufficient to if not torch.backends.mps.is_built():
# be able to use MPS print("MPS not available because the current PyTorch install was not "
print(torch.backends.mps.is_available()) "built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine.")
mps_device = torch.device("mps") else:
mps_device = torch.device("mps")
# Create a Tensor directly on the mps device # Create a Tensor directly on the mps device
x = torch.ones(5, device=mps_device) x = torch.ones(5, device=mps_device)
# Or # Or
x = torch.ones(5, device="mps") x = torch.ones(5, device="mps")
# Any operation happens on the GPU # Any operation happens on the GPU
y = x * 2 y = x * 2
# Move your model to mps just like any other device # Move your model to mps just like any other device
model = YourFavoriteNet() model = YourFavoriteNet()
model.to(mps_device) model.to(mps_device)
# Now every call runs on the GPU # Now every call runs on the GPU
pred = model(x) pred = model(x)