update mps note with more details (#78669)

Follow up to the comments in https://github.com/pytorch/pytorch/pull/77767#pullrequestreview-978807521
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78669
Approved by: https://github.com/kulinseth, https://github.com/anjali411
This commit is contained in:
albanD
2022-06-02 20:53:19 +00:00
committed by PyTorch MergeBot
parent 3e0f1a8a32
commit b30b1f3dec

View File

@ -14,27 +14,31 @@ capabilities to setup and run operations on GPU.
To get started, simply move your Tensor and Module to the ``mps`` device:
.. code::
.. code:: python
# Make sure the current PyTorch binary was built with MPS enabled
print(torch.backends.mps.is_built())
# And that the current hardware and MacOS version are sufficient to
# be able to use MPS
print(torch.backends.mps.is_available())
# Check that MPS is available
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not "
"built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine.")
mps_device = torch.device("mps")
else:
mps_device = torch.device("mps")
# Create a Tensor directly on the mps device
x = torch.ones(5, device=mps_device)
# Or
x = torch.ones(5, device="mps")
# Create a Tensor directly on the mps device
x = torch.ones(5, device=mps_device)
# Or
x = torch.ones(5, device="mps")
# Any operation happens on the GPU
y = x * 2
# Any operation happens on the GPU
y = x * 2
# Move your model to mps just like any other device
model = YourFavoriteNet()
model.to(mps_device)
# Move your model to mps just like any other device
model = YourFavoriteNet()
model.to(mps_device)
# Now every call runs on the GPU
pred = model(x)
# Now every call runs on the GPU
pred = model(x)