pin_memory malloc now uses existing context if available. (#22229)

Summary:
This is achieved by using `cuDevicePrimaryCtxGetState` to check whether a primary context exists on a device. The call is not too slow, as this benchmark of a single call on CUDA 10.1 (Titan Xp, driver 415.27) shows:
```
---------------------------------------------------------------------
Benchmark                              Time           CPU Iterations
---------------------------------------------------------------------
BM_cuDevicePrimaryCtxGetState        301 ns        301 ns    2319746
```
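
For reference, the same query can be sketched from Python via `ctypes` (illustrative only; it assumes `libcuda.so.1` is loadable under that name and calls the driver API directly, not a PyTorch API — `cuInit` by itself does not create a context):
```
import ctypes

# Minimal sketch: query the primary-context state of a device via the CUDA driver API.
libcuda = ctypes.CDLL('libcuda.so.1')  # assumption: the driver library is on the loader path
assert libcuda.cuInit(0) == 0

def has_primary_context(device_index):
    # CUresult cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, int *active)
    flags = ctypes.c_uint()
    active = ctypes.c_int()
    result = libcuda.cuDevicePrimaryCtxGetState(
        ctypes.c_int(device_index), ctypes.byref(flags), ctypes.byref(active))
    assert result == 0, 'cuDevicePrimaryCtxGetState failed'
    return bool(active.value)

print(has_primary_context(0))  # False until something creates a context on device 0
```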

Commits:

1. Add `CUDAHooks::getDeviceWithPrimaryContext`, which returns a device index that has a primary context (if one exists).
    Link `c10/cuda` against `libcuda` for device API calls.
2. Use `getDeviceWithPrimaryContext` to check for an existing primary context in `pin_memory`.
    Fix the `OptionalDeviceGuard` doc.
3. Refactor `test_cuda_primary_ctx.py` to support multiple tests.
    Add a test for this in that file.

Fixes https://github.com/pytorch/pytorch/issues/21081.
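
For illustration, the scenario this addresses roughly looks like the following sketch (hypothetical two-GPU machine, using the `_cuda_hasPrimaryContext` binding added in this PR; before this change, pinning memory would also initialize a primary context on `cuda:0`):
```
import torch

# Create a primary context only on cuda:1.
x = torch.randn(1, device='cuda:1')

# Pinning CPU memory should reuse the existing cuda:1 context rather than
# initializing a new primary context on cuda:0.
y = torch.randn(3).pin_memory()

print(torch._C._cuda_hasPrimaryContext(0))  # expected: False after this change
print(torch._C._cuda_hasPrimaryContext(1))  # expected: True
```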
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22229

Differential Revision: D16170194

Pulled By: zou3519

fbshipit-source-id: 485a45f211b7844c9e69c63f3b3b75194a796c5d
Authored by SsnL, 2019-07-16 10:05:53 -07:00
Committed by Facebook Github Bot
parent 054c7eb0f4
commit 8482efb203
12 changed files with 210 additions and 85 deletions

View File

@@ -5,6 +5,7 @@
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/detail/CUDAHooksInterface.h>
@@ -114,6 +115,15 @@ int64_t CUDAHooks::current_device() const {
return -1;
}
bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
"hasPrimaryContext expects valid device index, but got device_index=", device_index);
unsigned int ctx_flags;
int ctx_is_active;
AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
return ctx_is_active == 1;
}
Allocator* CUDAHooks::getPinnedMemoryAllocator() const {
return at::cuda::getPinnedMemoryAllocator();
}

View File

@@ -18,6 +18,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCuDNN() const override;
const at::cuda::NVRTC& nvrtc() const override;
int64_t current_device() const override;
bool hasPrimaryContext(int64_t device_index) const override;
Allocator* getPinnedMemoryAllocator() const override;
bool compiledWithCuDNN() const override;
bool compiledWithMIOpen() const override;

View File

@@ -57,7 +57,8 @@ namespace at { namespace cuda {
//
// ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both
// NVRTC and driver APIs. While the former is not yet supported for HIP, the
// latter is supported and needed.
// latter is supported and needed (e.g., in CUDAHooks::getDeviceWithPrimaryContext()
// used by tensor.pin_memory()).
//
// The macro below strips out certain unsupported operations on HIP from the full
// list above.

View File

@@ -60,15 +60,15 @@ struct CAFFE2_API CUDAHooksInterface {
// Initialize THCState and, transitively, the CUDA state
virtual std::unique_ptr<THCState, void (*)(THCState*)> initCUDA() const {
AT_ERROR("Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
}
virtual Generator* getDefaultCUDAGenerator(DeviceIndex device_index = -1) const {
AT_ERROR("Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
}
virtual Device getDeviceFromPtr(void* data) const {
AT_ERROR("Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
}
virtual bool hasCUDA() const {
@@ -84,15 +84,19 @@ struct CAFFE2_API CUDAHooksInterface {
}
virtual const at::cuda::NVRTC& nvrtc() const {
AT_ERROR("NVRTC requires CUDA. ", CUDA_HELP);
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
}
virtual int64_t current_device() const {
return -1;
}
virtual bool hasPrimaryContext(int64_t device_index) const {
TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
}
virtual Allocator* getPinnedMemoryAllocator() const {
AT_ERROR("Pinned memory requires CUDA. ", CUDA_HELP);
TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
}
virtual bool compiledWithCuDNN() const {
@@ -112,32 +116,32 @@ struct CAFFE2_API CUDAHooksInterface {
}
virtual long versionCuDNN() const {
AT_ERROR("Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual std::string showConfig() const {
AT_ERROR("Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
}
virtual double batchnormMinEpsilonCuDNN() const {
AT_ERROR(
TORCH_CHECK(false,
"Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
}
virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
virtual void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
virtual int64_t cuFFTGetPlanCacheSize(int64_t device_index) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
virtual void cuFFTClearPlanCache(int64_t device_index) const {
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
}
virtual int getNumGPUs() const {

View File

@@ -1,10 +1,10 @@
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/Storage.h>
#include <ATen/TensorUtils.h>
namespace at {
namespace native {

View File

@@ -1,4 +1,6 @@
#include <THC/THCCachingHostAllocator.h>
#include <ATen/DeviceGuard.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <cuda_runtime_api.h>
@@ -40,6 +42,24 @@ static bool BlockComparator(const BlockSize& a, const BlockSize& b)
return (uintptr_t)a.ptr < (uintptr_t)b.ptr;
}
static int64_t inline get_device_index_with_primary_context() {
const auto& cuda_hooks = at::detail::getCUDAHooks();
// check current device first
int64_t current_device_index = cuda_hooks.current_device();
if (current_device_index >= 0) {
if (cuda_hooks.hasPrimaryContext(current_device_index)) {
return current_device_index;
}
}
for (int64_t device_index = 0; device_index < cuda_hooks.getNumGPUs(); device_index++) {
if (device_index == current_device_index) continue;
if (cuda_hooks.hasPrimaryContext(device_index)) {
return device_index;
}
}
return -1;
}
struct HostAllocator
{
typedef bool (*Comparison)(const BlockSize&, const BlockSize&);
@@ -80,6 +100,17 @@ struct HostAllocator
return cudaSuccess;
}
// Pinned memory pointers allocated by any device can be directly used by any
// other device, regardless of the current device at the time of allocation,
// since we assume unified addressing.
// So we grab any existing primary context, if available.
// See pytorch/pytorch#21081.
at::OptionalDeviceGuard device_guard;
auto primary_ctx_device_index = get_device_index_with_primary_context();
if (primary_ctx_device_index >= 0) {
device_guard.reset_device(at::Device(at::DeviceType::CUDA, primary_ctx_device_index));
}
// note that cudaHostAlloc may not touch pointer if size is 0
*ptr = 0;
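
At the Python level, the fallback in `get_device_index_with_primary_context` above roughly corresponds to this sketch (illustrative only; it uses the `_cuda_hasPrimaryContext` binding added later in this PR, and it omits the current-device preference because `torch.cuda.current_device()` would itself initialize a context):
```
import torch

def device_index_with_primary_context():
    # Return any device index that already has a primary context, else -1.
    for idx in range(torch.cuda.device_count()):
        if torch._C._cuda_hasPrimaryContext(idx):
            return idx
    return -1
```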

View File

@@ -108,7 +108,7 @@ private:
* setDevice(1);
* OptionalDeviceGuard g;
* setDevice(2);
* g.set_device(3); // initializes!
* g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
*
* On destruction, g will reset device to 2, rather than 1.
*
@@ -118,7 +118,7 @@ private:
*/
class OptionalDeviceGuard {
public:
/// Create an uninitialized guard. Set the guard later using set_device.
/// Create an uninitialized guard. Set the guard later using reset_device.
explicit OptionalDeviceGuard() : guard_() {}
/// Initialize the guard, setting the current device to the passed Device.
@@ -159,7 +159,7 @@ public:
}
/// Returns the most recent device that was set using this device guard,
/// either from construction, or via set_device.
/// either from construction, or via reset_device.
optional<Device> current_device() const {
return guard_.current_device();
}

View File

@@ -2,7 +2,7 @@
#include "c10/util/Exception.h"
#include "c10/macros/Macros.h"
#include "cuda.h"
#include <cuda.h>
// Note [CHECK macro]
// ~~~~~~~~~~~~~~~~~~

View File

@@ -18,6 +18,7 @@ import warnings
import random
import contextlib
import socket
import subprocess
import time
from collections import OrderedDict
from contextlib import contextmanager
@@ -46,9 +47,12 @@ torch.backends.cudnn.disable_global_flags()
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--subprocess', action='store_true',
help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
TEST_IN_SUBPROCESS = args.subprocess
SEED = args.seed
if not expecttest.ACCEPT:
expecttest.ACCEPT = args.accept
@@ -56,8 +60,61 @@ UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)
def shell(command, cwd=None):
sys.stdout.flush()
sys.stderr.flush()
# The following cool snippet is copied from the Py3 core library's subprocess.call,
# with only these changes:
# 1. `except KeyboardInterrupt` block added for SIGINT handling.
# 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
# `p.wait()` in a `finally` block for the code to be portable.
#
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
try:
return p.wait()
except KeyboardInterrupt:
# Give `p` a chance to handle KeyboardInterrupt. Without this,
# `pytest` can't print errors it collected so far upon KeyboardInterrupt.
exit_status = p.wait(timeout=5)
if exit_status is not None:
return exit_status
else:
p.kill()
raise
except: # noqa E722, copied from python core library
p.kill()
raise
finally:
# Always call p.wait() to ensure exit
p.wait()
def run_tests(argv=UNITTEST_ARGS):
unittest.main(argv=argv)
if TEST_IN_SUBPROCESS:
suite = unittest.TestLoader().loadTestsFromModule(__main__)
test_cases = []
def add_to_test_cases(suite_or_case):
if isinstance(suite_or_case, unittest.TestCase):
test_cases.append(suite_or_case)
else:
for element in suite_or_case:
add_to_test_cases(element)
add_to_test_cases(suite)
failed_tests = []
for case in test_cases:
test_case_full_name = case.id().split('.', 1)[1]
exitcode = shell([sys.executable] + argv + [test_case_full_name])
if exitcode != 0:
failed_tests.append(test_case_full_name)
assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
len(failed_tests), '\n\t'.join(failed_tests))
else:
unittest.main(argv=argv)
PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)

View File

@@ -14,7 +14,7 @@ import tempfile
import torch
import torch._six
from torch.utils import cpp_extension
from common_utils import TEST_WITH_ROCM
from common_utils import TEST_WITH_ROCM, shell
import torch.distributed as dist
TESTS = [
@@ -98,49 +98,22 @@ def print_to_stderr(message):
print(message, file=sys.stderr)
def shell(command, cwd=None):
sys.stdout.flush()
sys.stderr.flush()
# The following cool snippet is copied from the Py3 core library's subprocess.call,
# with only these changes:
# 1. `except KeyboardInterrupt` block added for SIGINT handling.
# 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
# `p.wait()` in a `finally` block for the code to be portable.
#
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
try:
return p.wait()
except KeyboardInterrupt:
# Give `p` a chance to handle KeyboardInterrupt. Without this,
# `pytest` can't print errors it collected so far upon KeyboardInterrupt.
exit_status = p.wait(timeout=5)
if exit_status is not None:
return exit_status
else:
p.kill()
raise
except: # noqa E722, copied from python core library
p.kill()
raise
finally:
# Always call p.wait() to ensure exit
p.wait()
def run_test(executable, test_module, test_directory, options):
def run_test(executable, test_module, test_directory, options, *extra_unittest_args):
unittest_args = options.additional_unittest_args
if options.verbose:
unittest_args.append('--verbose')
# Can't call `python -m unittest test_*` here because it doesn't run code
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
argv = [test_module + '.py'] + unittest_args
argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)
command = executable + argv
return shell(command, test_directory)
def test_cuda_primary_ctx(executable, test_module, test_directory, options):
return run_test(executable, test_module, test_directory, options, '--subprocess')
def test_cpp_extensions(executable, test_module, test_directory, options):
try:
cpp_extension.verify_ninja_availability()
@@ -226,6 +199,7 @@ def test_distributed(executable, test_module, test_directory, options):
CUSTOM_HANDLERS = {
'cuda_primary_ctx': test_cuda_primary_ctx,
'cpp_extensions': test_cpp_extensions,
'distributed': test_distributed,
}

View File

@@ -1,9 +1,6 @@
import ctypes
import torch
from common_utils import TestCase, run_tests, skipIfRocm
import unittest
import glob
import os
# NOTE: this needs to be run in a brand new process
@@ -19,49 +16,85 @@ if not TEST_CUDA:
TestCase = object # noqa: F811
_caffe2_nvrtc = None
def get_is_primary_context_created(device):
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
global _caffe2_nvrtc
if _caffe2_nvrtc is None:
path = glob.glob('{}/lib/libcaffe2_nvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
_caffe2_nvrtc = ctypes.cdll.LoadLibrary(path)
result = _caffe2_nvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
return bool(active[0])
class TestCudaPrimaryCtx(TestCase):
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@skipIfRocm
def test_cuda_primary_ctx(self):
# Ensure context has not been created beforehand
self.assertFalse(get_is_primary_context_created(0))
self.assertFalse(get_is_primary_context_created(1))
CTX_ALREADY_CREATED_ERR_MSG = (
"Tests defined in test_cuda_primary_ctx.py must be run in a process "
"where CUDA contexts are never created. Use either run_test.py or add "
"--subprocess to run each test in a different subprocess.")
@skipIfRocm
def setUp(self):
for device in range(torch.cuda.device_count()):
# Ensure context has not been created beforehand
self.assertFalse(torch._C._cuda_hasPrimaryContext(device), TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG)
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_str_repr(self):
x = torch.randn(1, device='cuda:1')
# We should have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
print(x)
str(x)
repr(x)
# We should still have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_copy(self):
x = torch.randn(1, device='cuda:1')
# We should have only created context on 'cuda:1'
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
y = torch.randn(1, device='cpu')
y.copy_(x)
# We should still have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
# DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_pin_memory(self):
x = torch.randn(1, device='cuda:1')
# We should have only created context on 'cuda:1'
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
x = torch.randn(3, device='cpu').pin_memory()
# We should still have only created context on 'cuda:1'
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
x = torch.randn(3, device='cpu', pin_memory=True)
# We should still have only created context on 'cuda:1'
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
x = torch.zeros(3, device='cpu', pin_memory=True)
# We should still have only created context on 'cuda:1'
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
x = torch.empty(3, device='cpu', pin_memory=True)
# We should still have only created context on 'cuda:1'
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
x = x.pin_memory()
# We should still have only created context on 'cuda:1'
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
if __name__ == '__main__':
run_tests()

View File

@@ -202,6 +202,19 @@ PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
Py_RETURN_NONE;
}
PyObject * THCPModule_hasPrimaryContext(PyObject *_unused, PyObject *arg)
{
HANDLE_TH_ERRORS
THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to has_primary_context");
int64_t device_index = static_cast<int64_t>(THPUtils_unpackLong(arg));
if (at::detail::getCUDAHooks().hasPrimaryContext(device_index)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_emptyCache(PyObject *_unused)
{
HANDLE_TH_ERRORS
@@ -383,6 +396,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_isDriverSufficient", (PyCFunction)THCPModule_isDriverSufficient, METH_NOARGS, nullptr},
{"_cuda_getDriverVersion", (PyCFunction)THCPModule_getDriverVersion, METH_NOARGS, nullptr},
{"_cuda_getCompiledVersion", (PyCFunction)THCPModule_getCompiledVersion, METH_NOARGS, nullptr},
{"_cuda_hasPrimaryContext", (PyCFunction) THCPModule_hasPrimaryContext, METH_O, nullptr},
{"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache, METH_NOARGS, nullptr},
{"_cuda_memoryAllocated", (PyCFunction) THCPModule_memoryAllocated, METH_O, nullptr},
{"_cuda_maxMemoryAllocated", (PyCFunction) THCPModule_maxMemoryAllocated, METH_O, nullptr},