Only involve tensor device in CUDA -> CPU copy, not current device. (#11592)

Summary:
This also unifies the device usage between the async and sync case.

Fixes https://github.com/pytorch/pytorch/issues/10832.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11592

Differential Revision: D9797355

Pulled By: gchanan

fbshipit-source-id: e496cd371111cfaf9a6c664167967b395e3d72e9
This commit is contained in:
Gregory Chanan
2018-09-13 16:15:56 -07:00
committed by Facebook Github Bot
parent 4672280b55
commit 85ff72348d
3 changed files with 70 additions and 0 deletions

View File

@ -0,0 +1,58 @@
import ctypes
import torch
from common import TestCase, run_tests, skipIfRocm
import unittest
# NOTE: this needs to be run in a brand new process
# We cannot import TEST_CUDA and TEST_MULTIGPU from common_cuda here,
# because if we do that, the TEST_CUDNN line from common_cuda will be executed
# multiple times as well during the execution of this test suite, and it will
# cause CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
if not TEST_CUDA:
print('CUDA not available, skipping tests')
TestCase = object # noqa: F811
def get_is_primary_context_created(device):
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
return bool(active[0])
class TestCudaPrimaryCtx(TestCase):
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@skipIfRocm
def test_cuda_primary_ctx(self):
# Ensure context has not been created beforehand
self.assertFalse(get_is_primary_context_created(0))
self.assertFalse(get_is_primary_context_created(1))
x = torch.randn(1, device='cuda:1')
# We should have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))
print(x)
# We should still have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))
y = torch.randn(1, device='cpu')
y.copy_(x)
# We should still have only created context on 'cuda:1'
self.assertFalse(get_is_primary_context_created(0))
self.assertTrue(get_is_primary_context_created(1))
# DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS
if __name__ == '__main__':
run_tests()