mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
update mps note with more details (#78669)
Follow up to the comments in https://github.com/pytorch/pytorch/pull/77767#pullrequestreview-978807521 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78669 Approved by: https://github.com/kulinseth, https://github.com/anjali411
This commit is contained in:
@ -14,27 +14,31 @@ capabilities to setup and run operations on GPU.
|
|||||||
|
|
||||||
To get started, simply move your Tensor and Module to the ``mps`` device:
|
To get started, simply move your Tensor and Module to the ``mps`` device:
|
||||||
|
|
||||||
.. code::
|
.. code:: python
|
||||||
|
|
||||||
# Make sure the current PyTorch binary was built with MPS enabled
|
# Check that MPS is available
|
||||||
print(torch.backends.mps.is_built())
|
if not torch.backends.mps.is_available():
|
||||||
# And that the current hardware and MacOS version are sufficient to
|
if not torch.backends.mps.is_built():
|
||||||
# be able to use MPS
|
print("MPS not available because the current PyTorch install was not "
|
||||||
print(torch.backends.mps.is_available())
|
"built with MPS enabled.")
|
||||||
|
else:
|
||||||
|
print("MPS not available because the current MacOS version is not 12.3+ "
|
||||||
|
"and/or you do not have an MPS-enabled device on this machine.")
|
||||||
|
|
||||||
mps_device = torch.device("mps")
|
else:
|
||||||
|
mps_device = torch.device("mps")
|
||||||
|
|
||||||
# Create a Tensor directly on the mps device
|
# Create a Tensor directly on the mps device
|
||||||
x = torch.ones(5, device=mps_device)
|
x = torch.ones(5, device=mps_device)
|
||||||
# Or
|
# Or
|
||||||
x = torch.ones(5, device="mps")
|
x = torch.ones(5, device="mps")
|
||||||
|
|
||||||
# Any operation happens on the GPU
|
# Any operation happens on the GPU
|
||||||
y = x * 2
|
y = x * 2
|
||||||
|
|
||||||
# Move your model to mps just like any other device
|
# Move your model to mps just like any other device
|
||||||
model = YourFavoriteNet()
|
model = YourFavoriteNet()
|
||||||
model.to(mps_device)
|
model.to(mps_device)
|
||||||
|
|
||||||
# Now every call runs on the GPU
|
# Now every call runs on the GPU
|
||||||
pred = model(x)
|
pred = model(x)
|
||||||
|
Reference in New Issue
Block a user