mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Only involve tensor device in CUDA -> CPU copy, not current device. (#11592)
Summary: This also unifies the device usage between the async and sync case. Fixes https://github.com/pytorch/pytorch/issues/10832. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11592 Differential Revision: D9797355 Pulled By: gchanan fbshipit-source-id: e496cd371111cfaf9a6c664167967b395e3d72e9
This commit is contained in:
committed by
Facebook Github Bot
parent
4672280b55
commit
85ff72348d
58
test/test_cuda_primary_ctx.py
Normal file
58
test/test_cuda_primary_ctx.py
Normal file
@ -0,0 +1,58 @@
|
||||
import ctypes
|
||||
import torch
|
||||
from common import TestCase, run_tests, skipIfRocm
|
||||
import unittest
|
||||
|
||||
# NOTE: this needs to be run in a brand new process
|
||||
|
||||
# We cannot import TEST_CUDA and TEST_MULTIGPU from common_cuda here,
|
||||
# because if we do that, the TEST_CUDNN line from common_cuda will be executed
|
||||
# multiple times as well during the execution of this test suite, and it will
|
||||
# cause CUDA OOM error on Windows.
|
||||
TEST_CUDA = torch.cuda.is_available()
|
||||
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
|
||||
|
||||
if not TEST_CUDA:
|
||||
print('CUDA not available, skipping tests')
|
||||
TestCase = object # noqa: F811
|
||||
|
||||
|
||||
def get_is_primary_context_created(device):
|
||||
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
|
||||
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
|
||||
result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
|
||||
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
|
||||
return bool(active[0])
|
||||
|
||||
|
||||
class TestCudaPrimaryCtx(TestCase):
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
||||
@skipIfRocm
|
||||
def test_cuda_primary_ctx(self):
|
||||
# Ensure context has not been created beforehand
|
||||
self.assertFalse(get_is_primary_context_created(0))
|
||||
self.assertFalse(get_is_primary_context_created(1))
|
||||
|
||||
x = torch.randn(1, device='cuda:1')
|
||||
|
||||
# We should have only created context on 'cuda:1'
|
||||
self.assertFalse(get_is_primary_context_created(0))
|
||||
self.assertTrue(get_is_primary_context_created(1))
|
||||
|
||||
print(x)
|
||||
|
||||
# We should still have only created context on 'cuda:1'
|
||||
self.assertFalse(get_is_primary_context_created(0))
|
||||
self.assertTrue(get_is_primary_context_created(1))
|
||||
|
||||
y = torch.randn(1, device='cpu')
|
||||
y.copy_(x)
|
||||
|
||||
# We should still have only created context on 'cuda:1'
|
||||
self.assertFalse(get_is_primary_context_created(0))
|
||||
self.assertTrue(get_is_primary_context_created(1))
|
||||
|
||||
# DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
Reference in New Issue
Block a user