Document CUDA best practices (#3227)

Kai Arulkumaran
2017-10-25 21:38:17 +01:00
committed by Adam Paszke
parent 837f933cac
commit a7c5be1d45


@ -44,6 +44,82 @@ Below you can find a small example showcasing this::

Best practices
--------------

Device-agnostic code
^^^^^^^^^^^^^^^^^^^^

Due to the structure of PyTorch, you may need to explicitly write
device-agnostic (CPU or GPU) code; an example is creating a new tensor as
the initial hidden state of a recurrent neural network.

The first step is to determine whether the GPU should be used or not. A common
pattern is to use Python's `argparse` module to read in user arguments, and
have a flag that can be used to disable CUDA, in combination with
`torch.cuda.is_available()`. In the following, `args.cuda` results in a flag
that can be used to cast tensors and modules to CUDA if desired::

    import argparse
    import torch

    parser = argparse.ArgumentParser(description='PyTorch Example')
    parser.add_argument('--disable-cuda', action='store_true',
                        help='Disable CUDA')
    args = parser.parse_args()
    # Use CUDA only if it is available and has not been explicitly disabled
    args.cuda = not args.disable_cuda and torch.cuda.is_available()

If modules or tensors need to be sent to the GPU, `args.cuda` can be used as
follows. Note that `cuda()` moves a module's parameters in place, whereas for
a tensor it returns a new copy that must be reassigned::

    x = torch.Tensor(8, 42)
    net = Network()  # some user-defined nn.Module

    if args.cuda:
        x = x.cuda()
        net.cuda()

When creating tensors, an alternative to the if statement is to define a
default datatype once and cast all tensors using it. An example when using a
dataloader would be as follows::

    from torch.autograd import Variable

    # Pick the datatype based on the args.cuda flag defined above
    dtype = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor

    for i, x in enumerate(train_loader):
        x = Variable(x.type(dtype))

When working with multiple GPUs on a system, you can use the
`CUDA_VISIBLE_DEVICES` environment variable to manage which GPUs are available
to PyTorch. To manually control which GPU a tensor is created on, the best
practice is to use the `torch.cuda.device()` context manager::

    print("Outside device is 0")  # On device 0 (default in most scenarios)

    with torch.cuda.device(1):
        print("Inside device is 1")  # On device 1

    print("Outside device is still 0")  # On device 0

If you have a tensor and would like to create a new tensor of the same type on
the same device, then you can use the `new()` method, which acts like a normal
tensor constructor. Whilst the previously mentioned methods depend on the
current GPU context, `new()` preserves the device of the original tensor.
This is the recommended practice when creating modules in which new
tensors/variables need to be created internally during the forward pass::

    x_cpu = torch.FloatTensor(1)
    x_gpu = torch.cuda.FloatTensor(1)
    x_cpu_long = torch.LongTensor(1)

    y_cpu = x_cpu.new(8, 10, 10).fill_(0.3)   # CPU FloatTensor, like x_cpu
    y_gpu = x_gpu.new(x_gpu.size()).fill_(-5)  # GPU FloatTensor, like x_gpu
    y_cpu_long = x_cpu_long.new([[1, 2, 3]])   # CPU LongTensor, like x_cpu_long
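
Returning to the recurrent network example above, the following is a minimal
sketch (the wrapper module and its names are hypothetical, not a PyTorch API)
of a module that creates its initial hidden state inside `forward()` with
`new()`, so that it runs unchanged on CPU and GPU::

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    class RNNWrapper(nn.Module):
        def __init__(self, input_size, hidden_size):
            super(RNNWrapper, self).__init__()
            self.hidden_size = hidden_size
            self.cell = nn.RNNCell(input_size, hidden_size)

        def forward(self, x):
            # x has size (batch, input_size); new() inherits the type and
            # device of x, so no explicit CPU/GPU branch is needed here
            h = Variable(x.data.new(x.size(0), self.hidden_size).zero_())
            return self.cell(x, h)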

If you want to create a tensor of the same type and size as another tensor,
and fill it with either ones or zeros, `torch.ones_like()` and
`torch.zeros_like()` are provided as more convenient functions (which also
preserve the device of the tensor)::

    x_cpu = torch.FloatTensor(1)
    x_gpu = torch.cuda.FloatTensor(1)

    y_cpu = torch.ones_like(x_cpu)
    y_gpu = torch.zeros_like(x_gpu)

Use pinned memory buffers
^^^^^^^^^^^^^^^^^^^^^^^^^