mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Only involve tensor device in CUDA -> CPU copy, not current device. (#11592)
Summary: This also unifies the device usage between the async and sync case. Fixes https://github.com/pytorch/pytorch/issues/10832. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11592 Differential Revision: D9797355 Pulled By: gchanan fbshipit-source-id: e496cd371111cfaf9a6c664167967b395e3d72e9
This commit is contained in:
committed by
Facebook Github Bot
parent
4672280b55
commit
85ff72348d
@ -58,6 +58,13 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src)
|
|||||||
|
|
||||||
{
|
{
|
||||||
THTensor *selfc = THTensor_(newContiguous)(self);
|
THTensor *selfc = THTensor_(newContiguous)(self);
|
||||||
|
int tensorDevice = THCTensor_(getDevice)(state, src);
|
||||||
|
int currentDevice;
|
||||||
|
THCudaCheck(cudaGetDevice(¤tDevice));
|
||||||
|
|
||||||
|
if (currentDevice != tensorDevice) {
|
||||||
|
THCudaCheck(cudaSetDevice(tensorDevice));
|
||||||
|
}
|
||||||
src = THCTensor_(newContiguous)(state, src);
|
src = THCTensor_(newContiguous)(state, src);
|
||||||
|
|
||||||
cudaStream_t stream = THCState_getCurrentStream(state);
|
cudaStream_t stream = THCState_getCurrentStream(state);
|
||||||
@ -68,6 +75,10 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src)
|
|||||||
stream));
|
stream));
|
||||||
THCudaCheck(cudaStreamSynchronize(stream));
|
THCudaCheck(cudaStreamSynchronize(stream));
|
||||||
|
|
||||||
|
if (currentDevice != tensorDevice) {
|
||||||
|
THCudaCheck(cudaSetDevice(currentDevice));
|
||||||
|
}
|
||||||
|
|
||||||
THCTensor_(free)(state, src);
|
THCTensor_(free)(state, src);
|
||||||
THTensor_(freeCopyTo)(selfc, self);
|
THTensor_(freeCopyTo)(selfc, self);
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,7 @@ TESTS = [
|
|||||||
'cpp_extensions',
|
'cpp_extensions',
|
||||||
'c10d',
|
'c10d',
|
||||||
'cuda',
|
'cuda',
|
||||||
|
'cuda_primary_ctx',
|
||||||
'dataloader',
|
'dataloader',
|
||||||
'distributed',
|
'distributed',
|
||||||
'distributions',
|
'distributions',
|
||||||
|
58
test/test_cuda_primary_ctx.py
Normal file
58
test/test_cuda_primary_ctx.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import ctypes
|
||||||
|
import torch
|
||||||
|
from common import TestCase, run_tests, skipIfRocm
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
# NOTE: this needs to be run in a brand new process
|
||||||
|
|
||||||
|
# We cannot import TEST_CUDA and TEST_MULTIGPU from common_cuda here,
|
||||||
|
# because if we do that, the TEST_CUDNN line from common_cuda will be executed
|
||||||
|
# multiple times as well during the execution of this test suite, and it will
|
||||||
|
# cause CUDA OOM error on Windows.
|
||||||
|
TEST_CUDA = torch.cuda.is_available()
|
||||||
|
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
|
||||||
|
|
||||||
|
if not TEST_CUDA:
|
||||||
|
print('CUDA not available, skipping tests')
|
||||||
|
TestCase = object # noqa: F811
|
||||||
|
|
||||||
|
|
||||||
|
def get_is_primary_context_created(device):
|
||||||
|
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
|
||||||
|
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
|
||||||
|
result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
|
||||||
|
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
|
||||||
|
return bool(active[0])
|
||||||
|
|
||||||
|
|
||||||
|
class TestCudaPrimaryCtx(TestCase):
|
||||||
|
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
||||||
|
@skipIfRocm
|
||||||
|
def test_cuda_primary_ctx(self):
|
||||||
|
# Ensure context has not been created beforehand
|
||||||
|
self.assertFalse(get_is_primary_context_created(0))
|
||||||
|
self.assertFalse(get_is_primary_context_created(1))
|
||||||
|
|
||||||
|
x = torch.randn(1, device='cuda:1')
|
||||||
|
|
||||||
|
# We should have only created context on 'cuda:1'
|
||||||
|
self.assertFalse(get_is_primary_context_created(0))
|
||||||
|
self.assertTrue(get_is_primary_context_created(1))
|
||||||
|
|
||||||
|
print(x)
|
||||||
|
|
||||||
|
# We should still have only created context on 'cuda:1'
|
||||||
|
self.assertFalse(get_is_primary_context_created(0))
|
||||||
|
self.assertTrue(get_is_primary_context_created(1))
|
||||||
|
|
||||||
|
y = torch.randn(1, device='cpu')
|
||||||
|
y.copy_(x)
|
||||||
|
|
||||||
|
# We should still have only created context on 'cuda:1'
|
||||||
|
self.assertFalse(get_is_primary_context_created(0))
|
||||||
|
self.assertTrue(get_is_primary_context_created(1))
|
||||||
|
|
||||||
|
# DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run_tests()
|
Reference in New Issue
Block a user