Document CUDA best practices (#3227)
Committed by: Adam Paszke
Parent: 837f933cac
Commit: a7c5be1d45
@@ -44,6 +44,82 @@ Below you can find a small example showcasing this::

Best practices
--------------

Device-agnostic code
^^^^^^^^^^^^^^^^^^^^

Due to the structure of PyTorch, you may need to explicitly write
device-agnostic (CPU or GPU) code; an example is creating a new tensor as
the initial hidden state of a recurrent neural network.

The first step is to determine whether the GPU should be used or not. A common
pattern is to use Python's `argparse` module to read in user arguments, and
have a flag that can be used to disable CUDA, in combination with
`torch.cuda.is_available()`. In the following, `args.cuda` results in a flag
that can be used to cast tensors and modules to CUDA if desired::

    import argparse
    import torch

    parser = argparse.ArgumentParser(description='PyTorch Example')
    parser.add_argument('--disable-cuda', action='store_true',
                        help='Disable CUDA')
    args = parser.parse_args()
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
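
For a quick check that the flag behaves as intended, `parse_args` also
accepts an explicit argument list (a hypothetical usage, not part of the
original snippet)::

    # Emulate `python main.py --disable-cuda` by passing the
    # argument list explicitly.
    args = parser.parse_args(['--disable-cuda'])
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    print(args.cuda)  # False, regardless of whether a GPU is present
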
If modules or tensors need to be sent to the GPU, `args.cuda` can be used as
follows::

    x = torch.Tensor(8, 42)
    net = Network()  # Network is a user-defined nn.Module
    if args.cuda:
        x = x.cuda()
        net.cuda()

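Note the asymmetry above: calling `.cuda()` on a tensor returns a new CUDA
tensor (hence the reassignment to `x`), whereas calling `.cuda()` on a module
moves the module's parameters and buffers in place, so `net` does not need to
be reassigned.
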
When creating tensors, an alternative to the if statement is to define a
default datatype and cast all tensors using it. An example when using a
dataloader would be as follows::

    # Variable comes from torch.autograd; train_loader is a DataLoader.
    dtype = torch.cuda.FloatTensor
    for i, x in enumerate(train_loader):
        x = Variable(x.type(dtype))

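To keep this pattern device-agnostic, the default type itself can be derived
from the `args.cuda` flag computed earlier (a sketch, assuming the argparse
example above)::

    dtype = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
    for i, x in enumerate(train_loader):
        x = Variable(x.type(dtype))
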
When working with multiple GPUs on a system, you can use the
`CUDA_VISIBLE_DEVICES` environment variable to manage which GPUs are available
to PyTorch; for example, launching a script with `CUDA_VISIBLE_DEVICES=1`
exposes only the second GPU, which PyTorch then sees as device 0. To manually
control which GPU a tensor is created on, the best practice is to use the
`torch.cuda.device()` context manager::

print("Outside device is 0") # On device 0 (default in most scenarios)
|
||||
with torch.cuda.device(1):
|
||||
print("Inside device is 1") # On device 1
|
||||
print("Outside device is still 0") # On device 0
|
||||
|
||||
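To see the effect on actual allocations, `get_device()` reports where a CUDA
tensor lives (a minimal sketch, assuming a machine with at least two GPUs)::

    x = torch.cuda.FloatTensor(1)
    print(x.get_device())  # 0: created on the current (default) device
    with torch.cuda.device(1):
        y = torch.cuda.FloatTensor(1)
        print(y.get_device())  # 1: created on device 1
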
If you have a tensor and would like to create a new tensor of the same type on
the same device, then you can use the `.new()` method, which acts the same as
a normal tensor constructor. While the previously mentioned methods depend on
the current GPU context, `new()` preserves the device of the original tensor.

This is the recommended practice when creating modules in which new
tensors/variables need to be created internally during the forward pass::

    x_cpu = torch.FloatTensor(1)
    x_gpu = torch.cuda.FloatTensor(1)
    x_cpu_long = torch.LongTensor(1)

    y_cpu = x_cpu.new(8, 10, 10).fill_(0.3)
    y_gpu = x_gpu.new(x_gpu.size()).fill_(-5)
    y_cpu_long = x_cpu_long.new([[1, 2, 3]])

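As a concrete sketch of that practice (the module and its name are
hypothetical, not part of the original diff), the initial hidden state of a
recurrent network can be built from the input itself, so the module runs
unchanged on CPU or GPU::

    import torch.nn as nn
    from torch.autograd import Variable

    class RNNWrapper(nn.Module):
        def __init__(self, input_size, hidden_size):
            super(RNNWrapper, self).__init__()
            self.rnn = nn.RNN(input_size, hidden_size)
            self.hidden_size = hidden_size

        def forward(self, x):
            # x: (seq_len, batch, input_size). Create h0 with the same
            # type and on the same device as x via `.new()`.
            h0 = Variable(x.data.new(1, x.size(1), self.hidden_size).zero_())
            return self.rnn(x, h0)
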
If you want to create a tensor of the same type and size as another tensor,
and fill it with either ones or zeros, `torch.ones_like()` or
`torch.zeros_like()` are provided as more convenient functions (which also
preserve the device)::

    x_cpu = torch.FloatTensor(1)
    x_gpu = torch.cuda.FloatTensor(1)

    y_cpu = torch.ones_like(x_cpu)
    y_gpu = torch.zeros_like(x_gpu)

Use pinned memory buffers
^^^^^^^^^^^^^^^^^^^^^^^^^