pin_memory malloc now uses existing context if available. (#22229)
Summary:
This is achieved by using `cuDevicePrimaryCtxGetState` to check whether a primary context exists on a device. The call is cheap, as shown by this benchmark of a single call to it on CUDA 10.1, Titan Xp, driver 415.27:
```
-----------------------------------------------------------------------
Benchmark                                Time           CPU  Iterations
-----------------------------------------------------------------------
BM_cuDevicePrimaryCtxGetState          301 ns        301 ns     2319746
```
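
For reference, the same primary-context query can be reproduced outside of ATen with a small ctypes sketch against the CUDA driver API. Loading `libcuda` directly is an assumption made only for this sketch; the PR itself routes the call through ATen's NVRTC/driver stub (`caffe2_nvrtc`):
```
# Hypothetical standalone sketch: query a device's primary-context state via
# the CUDA driver API. PyTorch itself goes through ATen's caffe2_nvrtc stub.
import ctypes

libcuda = ctypes.CDLL('libcuda.so.1')  # assumption: the driver library is on the loader path
assert libcuda.cuInit(0) == 0

def has_primary_context(device_index):
    flags = ctypes.c_uint(0)
    active = ctypes.c_int(0)
    # cuDevicePrimaryCtxGetState only inspects state; it never creates a context.
    err = libcuda.cuDevicePrimaryCtxGetState(
        ctypes.c_int(device_index), ctypes.byref(flags), ctypes.byref(active))
    return err == 0 and active.value == 1

print(has_primary_context(0))
```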
Commits:
1. Add `CUDAHooks::getDeviceWithPrimaryContext`, which returns a device index that has a primary context (if one exists).
Link `c10/cuda` against `libcuda` for device API calls.
2. Use `getDeviceWithPrimaryContext` to check for an existing primary context in `pin_memory`.
Fix `OptionalDeviceGuard` doc.
3. Refactor `test_cuda_primary_ctx.py` to support multiple tests.
Add a test for this behavior in that file.
Fixes https://github.com/pytorch/pytorch/issues/21081.
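
As a rough sketch of the user-visible effect on a multi-GPU machine (using the `torch._C._cuda_hasPrimaryContext` binding added in this PR), pinning host memory no longer initializes a context on device 0 when another device already has one:
```
import torch

# Allocating on cuda:1 initializes only that device's primary context.
x = torch.randn(1, device='cuda:1')

# Pinning host memory reuses the existing cuda:1 context instead of
# initializing a new primary context on cuda:0.
y = torch.randn(3).pin_memory()

print(torch._C._cuda_hasPrimaryContext(0))  # expected: False
print(torch._C._cuda_hasPrimaryContext(1))  # expected: True
```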
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22229
Differential Revision: D16170194
Pulled By: zou3519
fbshipit-source-id: 485a45f211b7844c9e69c63f3b3b75194a796c5d
Committed by: Facebook Github Bot
Parent: 054c7eb0f4
Commit: 8482efb203
@@ -5,6 +5,7 @@
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/detail/CUDAHooksInterface.h>
@@ -114,6 +115,15 @@ int64_t CUDAHooks::current_device() const {
  return -1;
}

bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
  TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
              "hasPrimaryContext expects valid device index, but got device_index=", device_index);
  unsigned int ctx_flags;
  int ctx_is_active;
  AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
  return ctx_is_active == 1;
}

Allocator* CUDAHooks::getPinnedMemoryAllocator() const {
  return at::cuda::getPinnedMemoryAllocator();
}

@@ -18,6 +18,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
  bool hasCuDNN() const override;
  const at::cuda::NVRTC& nvrtc() const override;
  int64_t current_device() const override;
  bool hasPrimaryContext(int64_t device_index) const override;
  Allocator* getPinnedMemoryAllocator() const override;
  bool compiledWithCuDNN() const override;
  bool compiledWithMIOpen() const override;

@@ -57,7 +57,8 @@ namespace at { namespace cuda {
//
// ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both
// NVRTC and driver APIs. While the former is not yet suppoted for HIP, the
// later is supported and needed.
// later is supported and needed (e.g., in CUDAHooks::getDeviceWithPrimaryContext()
// used by tensor.pin_memory()).
//
// The macro below strips out certain unsupported operations on HIP from the full
// list above.

@@ -60,15 +60,15 @@ struct CAFFE2_API CUDAHooksInterface {

  // Initialize THCState and, transitively, the CUDA state
  virtual std::unique_ptr<THCState, void (*)(THCState*)> initCUDA() const {
    AT_ERROR("Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
  }

  virtual Generator* getDefaultCUDAGenerator(DeviceIndex device_index = -1) const {
    AT_ERROR("Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP);
  }

  virtual Device getDeviceFromPtr(void* data) const {
    AT_ERROR("Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
  }

  virtual bool hasCUDA() const {
@@ -84,15 +84,19 @@ struct CAFFE2_API CUDAHooksInterface {
  }

  virtual const at::cuda::NVRTC& nvrtc() const {
    AT_ERROR("NVRTC requires CUDA. ", CUDA_HELP);
    TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
  }

  virtual int64_t current_device() const {
    return -1;
  }

  virtual bool hasPrimaryContext(int64_t device_index) const {
    TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
  }

  virtual Allocator* getPinnedMemoryAllocator() const {
    AT_ERROR("Pinned memory requires CUDA. ", CUDA_HELP);
    TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
  }

  virtual bool compiledWithCuDNN() const {
@@ -112,32 +116,32 @@ struct CAFFE2_API CUDAHooksInterface {
  }

  virtual long versionCuDNN() const {
    AT_ERROR("Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual std::string showConfig() const {
    AT_ERROR("Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
  }

  virtual double batchnormMinEpsilonCuDNN() const {
    AT_ERROR(
    TORCH_CHECK(false,
        "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int64_t cuFFTGetPlanCacheSize(int64_t device_index) const {
    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual void cuFFTClearPlanCache(int64_t device_index) const {
    AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
    TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  }

  virtual int getNumGPUs() const {

@@ -1,10 +1,10 @@
#include <ATen/ATen.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/Storage.h>
#include <ATen/TensorUtils.h>

namespace at {
namespace native {

@@ -1,4 +1,6 @@
#include <THC/THCCachingHostAllocator.h>
#include <ATen/DeviceGuard.h>
#include <ATen/detail/CUDAHooksInterface.h>


#include <cuda_runtime_api.h>
@@ -40,6 +42,24 @@ static bool BlockComparator(const BlockSize& a, const BlockSize& b)
  return (uintptr_t)a.ptr < (uintptr_t)b.ptr;
}

static int64_t inline get_device_index_with_primary_context() {
  const auto& cuda_hooks = at::detail::getCUDAHooks();
  // check current device first
  int64_t current_device_index = cuda_hooks.current_device();
  if (current_device_index >= 0) {
    if (cuda_hooks.hasPrimaryContext(current_device_index)) {
      return current_device_index;
    }
  }
  for (int64_t device_index = 0; device_index < cuda_hooks.getNumGPUs(); device_index++) {
    if (device_index == current_device_index) continue;
    if (cuda_hooks.hasPrimaryContext(device_index)) {
      return device_index;
    }
  }
  return -1;
}

struct HostAllocator
{
  typedef bool (*Comparison)(const BlockSize&, const BlockSize&);
@@ -80,6 +100,17 @@ struct HostAllocator
      return cudaSuccess;
    }

    // Pinned memory pointers allocated by any device can be directly used by any
    // other device, regardless of the current device at the time of allocation,
    // since we assume unified addressing.
    // So we grab any existing primary context, if available.
    // See pytorch/pytorch#21081.
    at::OptionalDeviceGuard device_guard;
    auto primary_ctx_device_index = get_device_index_with_primary_context();
    if (primary_ctx_device_index >= 0) {
      device_guard.reset_device(at::Device(at::DeviceType::CUDA, primary_ctx_device_index));
    }

    // note that cudaHostAlloc may not touch pointer if size is 0
    *ptr = 0;

@@ -108,7 +108,7 @@ private:
 * setDevice(1);
 * OptionalDeviceGuard g;
 * setDevice(2);
 * g.set_device(3); // initializes!
 * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
 *
 * On destruction, g will reset device to 2, rather than 1.
 *
@@ -118,7 +118,7 @@ private:
 */
class OptionalDeviceGuard {
public:
  /// Create an uninitialized guard. Set the guard later using set_device.
  /// Create an uninitialized guard. Set the guard later using reset_device.
  explicit OptionalDeviceGuard() : guard_() {}

  /// Initialize the guard, setting the current device to the passed Device.
@@ -159,7 +159,7 @@ public:
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via set_device.
  /// either from construction, or via reset_device.
  optional<Device> current_device() const {
    return guard_.current_device();
  }

@@ -2,7 +2,7 @@

#include "c10/util/Exception.h"
#include "c10/macros/Macros.h"
#include "cuda.h"
#include <cuda.h>

// Note [CHECK macro]
// ~~~~~~~~~~~~~~~~~~

@@ -18,6 +18,7 @@ import warnings
import random
import contextlib
import socket
import subprocess
import time
from collections import OrderedDict
from contextlib import contextmanager
@@ -46,9 +47,12 @@ torch.backends.cudnn.disable_global_flags()


parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--subprocess', action='store_true',
                    help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
TEST_IN_SUBPROCESS = args.subprocess
SEED = args.seed
if not expecttest.ACCEPT:
    expecttest.ACCEPT = args.accept
@@ -56,8 +60,61 @@ UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)


def shell(command, cwd=None):
    sys.stdout.flush()
    sys.stderr.flush()
    # The following cool snippet is copied from Py3 core library subprocess.call
    # only the with
    # 1. `except KeyboardInterrupt` block added for SIGINT handling.
    # 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
    # `p.wait()` in a `final` block for the code to be portable.
    #
    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
    try:
        return p.wait()
    except KeyboardInterrupt:
        # Give `p` a chance to handle KeyboardInterrupt. Without this,
        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
        exit_status = p.wait(timeout=5)
        if exit_status is not None:
            return exit_status
        else:
            p.kill()
            raise
    except: # noqa E722, copied from python core library
        p.kill()
        raise
    finally:
        # Always call p.wait() to ensure exit
        p.wait()


def run_tests(argv=UNITTEST_ARGS):
    unittest.main(argv=argv)
    if TEST_IN_SUBPROCESS:
        suite = unittest.TestLoader().loadTestsFromModule(__main__)
        test_cases = []

        def add_to_test_cases(suite_or_case):
            if isinstance(suite_or_case, unittest.TestCase):
                test_cases.append(suite_or_case)
            else:
                for element in suite_or_case:
                    add_to_test_cases(element)

        add_to_test_cases(suite)
        failed_tests = []
        for case in test_cases:
            test_case_full_name = case.id().split('.', 1)[1]
            exitcode = shell([sys.executable] + argv + [test_case_full_name])
            if exitcode != 0:
                failed_tests.append(test_case_full_name)

        assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
            len(failed_tests), '\n\t'.join(failed_tests))
    else:
        unittest.main(argv=argv)

PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)

@@ -14,7 +14,7 @@ import tempfile
import torch
import torch._six
from torch.utils import cpp_extension
from common_utils import TEST_WITH_ROCM
from common_utils import TEST_WITH_ROCM, shell
import torch.distributed as dist

TESTS = [
@@ -98,49 +98,22 @@ def print_to_stderr(message):
    print(message, file=sys.stderr)


def shell(command, cwd=None):
    sys.stdout.flush()
    sys.stderr.flush()
    # The folloing cool snippet is copied from Py3 core library subprocess.call
    # only the with
    # 1. `except KeyboardInterrupt` block added for SIGINT handling.
    # 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
    # `p.wait()` in a `final` block for the code to be portable.
    #
    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
    try:
        return p.wait()
    except KeyboardInterrupt:
        # Give `p` a chance to handle KeyboardInterrupt. Without this,
        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
        exit_status = p.wait(timeout=5)
        if exit_status is not None:
            return exit_status
        else:
            p.kill()
            raise
    except: # noqa E722, copied from python core library
        p.kill()
        raise
    finally:
        # Always call p.wait() to ensure exit
        p.wait()


def run_test(executable, test_module, test_directory, options):
def run_test(executable, test_module, test_directory, options, *extra_unittest_args):
    unittest_args = options.additional_unittest_args
    if options.verbose:
        unittest_args.append('--verbose')
    # Can't call `python -m unittest test_*` here because it doesn't run code
    # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
    argv = [test_module + '.py'] + unittest_args
    argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)

    command = executable + argv
    return shell(command, test_directory)


def test_cuda_primary_ctx(executable, test_module, test_directory, options):
    return run_test(executable, test_module, test_directory, options, '--subprocess')


def test_cpp_extensions(executable, test_module, test_directory, options):
    try:
        cpp_extension.verify_ninja_availability()
@@ -226,6 +199,7 @@ def test_distributed(executable, test_module, test_directory, options):


CUSTOM_HANDLERS = {
    'cuda_primary_ctx': test_cuda_primary_ctx,
    'cpp_extensions': test_cpp_extensions,
    'distributed': test_distributed,
}

@@ -1,9 +1,6 @@
import ctypes
import torch
from common_utils import TestCase, run_tests, skipIfRocm
import unittest
import glob
import os

# NOTE: this needs to be run in a brand new process

@@ -19,49 +16,85 @@ if not TEST_CUDA:
    TestCase = object # noqa: F811


_caffe2_nvrtc = None


def get_is_primary_context_created(device):
    flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
    active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    global _caffe2_nvrtc
    if _caffe2_nvrtc is None:
        path = glob.glob('{}/lib/libcaffe2_nvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
        _caffe2_nvrtc = ctypes.cdll.LoadLibrary(path)
    result = _caffe2_nvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
    assert result == 0, 'cuDevicePrimaryCtxGetState failed'
    return bool(active[0])


class TestCudaPrimaryCtx(TestCase):
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @skipIfRocm
    def test_cuda_primary_ctx(self):
        # Ensure context has not been created beforehand
        self.assertFalse(get_is_primary_context_created(0))
        self.assertFalse(get_is_primary_context_created(1))
    CTX_ALREADY_CREATED_ERR_MSG = (
        "Tests defined in test_cuda_primary_ctx.py must be run in a process "
        "where CUDA contexts are never created. Use either run_test.py or add "
        "--subprocess to run each test in a different subprocess.")

    @skipIfRocm
    def setUp(self):
        for device in range(torch.cuda.device_count()):
            # Ensure context has not been created beforehand
            self.assertFalse(torch._C._cuda_hasPrimaryContext(device), TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG)

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_str_repr(self):
        x = torch.randn(1, device='cuda:1')

        # We should have only created context on 'cuda:1'
        self.assertFalse(get_is_primary_context_created(0))
        self.assertTrue(get_is_primary_context_created(1))
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

        print(x)
        str(x)
        repr(x)

        # We should still have only created context on 'cuda:1'
        self.assertFalse(get_is_primary_context_created(0))
        self.assertTrue(get_is_primary_context_created(1))
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_copy(self):
        x = torch.randn(1, device='cuda:1')

        # We should have only created context on 'cuda:1'
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

        y = torch.randn(1, device='cpu')
        y.copy_(x)

        # We should still have only created context on 'cuda:1'
        self.assertFalse(get_is_primary_context_created(0))
        self.assertTrue(get_is_primary_context_created(1))
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

    # DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    def test_pin_memory(self):
        x = torch.randn(1, device='cuda:1')

        # We should have only created context on 'cuda:1'
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

        x = torch.randn(3, device='cpu').pin_memory()

        # We should still have only created context on 'cuda:1'
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

        x = torch.randn(3, device='cpu', pin_memory=True)

        # We should still have only created context on 'cuda:1'
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

        x = torch.zeros(3, device='cpu', pin_memory=True)

        # We should still have only created context on 'cuda:1'
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

        x = torch.empty(3, device='cpu', pin_memory=True)

        # We should still have only created context on 'cuda:1'
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

        x = x.pin_memory()

        # We should still have only created context on 'cuda:1'
        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
        self.assertTrue(torch._C._cuda_hasPrimaryContext(1))

if __name__ == '__main__':
    run_tests()

@@ -202,6 +202,19 @@ PyObject * THCPModule_cudaUnlockMutex(PyObject *module)
  Py_RETURN_NONE;
}

PyObject * THCPModule_hasPrimaryContext(PyObject *_unused, PyObject *arg)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to has_primary_context");
  int64_t device_index = static_cast<int64_t>(THPUtils_unpackLong(arg));
  if (at::detail::getCUDAHooks().hasPrimaryContext(device_index)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject * THCPModule_emptyCache(PyObject *_unused)
{
  HANDLE_TH_ERRORS
@@ -383,6 +396,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
  {"_cuda_isDriverSufficient", (PyCFunction)THCPModule_isDriverSufficient, METH_NOARGS, nullptr},
  {"_cuda_getDriverVersion", (PyCFunction)THCPModule_getDriverVersion, METH_NOARGS, nullptr},
  {"_cuda_getCompiledVersion", (PyCFunction)THCPModule_getCompiledVersion, METH_NOARGS, nullptr},
  {"_cuda_hasPrimaryContext", (PyCFunction) THCPModule_hasPrimaryContext, METH_O, nullptr},
  {"_cuda_emptyCache", (PyCFunction) THCPModule_emptyCache, METH_NOARGS, nullptr},
  {"_cuda_memoryAllocated", (PyCFunction) THCPModule_memoryAllocated, METH_O, nullptr},
  {"_cuda_maxMemoryAllocated", (PyCFunction) THCPModule_maxMemoryAllocated, METH_O, nullptr},