mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fixes to allow more consistent build tests
This commit is contained in:
@ -15,6 +15,7 @@ TYPED_TEST_CASE(TensorGPUTest, TensorTypes);
|
||||
TYPED_TEST_CASE(TensorGPUDeathTest, TensorTypes);
|
||||
|
||||
TYPED_TEST(TensorGPUTest, TensorInitializedEmpty) {
|
||||
if (!caffe2::HasCudaGPU()) return;
|
||||
Tensor<TypeParam, CUDAContext> tensor;
|
||||
EXPECT_EQ(tensor.ndim(), 0);
|
||||
vector<int> dims(3);
|
||||
@ -31,6 +32,7 @@ TYPED_TEST(TensorGPUTest, TensorInitializedEmpty) {
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) {
|
||||
if (!HasCudaGPU()) return;
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
@ -57,6 +59,7 @@ TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) {
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorGPUTest, TensorShareData) {
|
||||
if (!HasCudaGPU()) return;
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
@ -71,6 +74,7 @@ TYPED_TEST(TensorGPUTest, TensorShareData) {
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
|
||||
if (!HasCudaGPU()) return;
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
@ -89,6 +93,7 @@ TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorGPUTest, NoLongerSharesAfterReshape) {
|
||||
if (!HasCudaGPU()) return;
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
@ -108,6 +113,7 @@ TYPED_TEST(TensorGPUTest, NoLongerSharesAfterReshape) {
|
||||
|
||||
|
||||
TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
|
||||
if (!HasCudaGPU()) return;
|
||||
::testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||
Tensor<TypeParam, CUDAContext> tensor;
|
||||
EXPECT_EQ(tensor.ndim(), 0);
|
||||
|
@ -5,6 +5,21 @@
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
bool HasCudaGPU() {
|
||||
int count;
|
||||
auto err = cudaGetDeviceCount(&count);
|
||||
if (err == cudaErrorNoDevice || err == cudaErrorInsufficientDriver) {
|
||||
return false;
|
||||
}
|
||||
// cudaGetDeviceCount() should only return the above two errors. If
|
||||
// there are other kinds of errors, maybe you have called some other
|
||||
// cuda functions before HasCudaGPU().
|
||||
CHECK(err == cudaSuccess)
|
||||
<< "Unexpected error from cudaGetDeviceCount(). Did you run some "
|
||||
"cuda functions before calling HasCudaGPU()?";
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
int gDefaultGPUID = 0;
|
||||
}
|
||||
@ -137,6 +152,11 @@ const char* curandGetErrorString(curandStatus_t error) {
|
||||
}
|
||||
|
||||
bool Caffe2EnableCudaPeerAccess() {
|
||||
// If the current run does not have any cuda devices, do nothing.
|
||||
if (!HasCudaGPU()) {
|
||||
LOG(INFO) << "No cuda gpu present. Skipping.";
|
||||
return true;
|
||||
}
|
||||
int device_count;
|
||||
CUDA_CHECK(cudaGetDeviceCount(&device_count));
|
||||
int init_device;
|
||||
|
@ -14,6 +14,10 @@
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
// Check if the current running session has a cuda gpu present. Note that this
|
||||
// is different from having caffe2 built with cuda - it is possible that
|
||||
// caffe2 is built with cuda but there is no cuda hardware available.
|
||||
bool HasCudaGPU();
|
||||
// Sets and gets the default GPU id. If the function is not called, we will use
|
||||
// GPU 0 ast he default gpu id. If there is an operator that says it runs on the
|
||||
// GPU but did not specify which GPU, this default gpuid is going to be used.
|
||||
|
@ -16,6 +16,7 @@ class MemoryPoolTest : public ::testing::Test {
|
||||
// should define it if you need to initialize the varaibles.
|
||||
// Otherwise, this can be skipped.
|
||||
void SetUp() override {
|
||||
if (!HasCudaGPU()) return;
|
||||
int device_count_;
|
||||
CUDA_CHECK(cudaGetDeviceCount(&device_count_));
|
||||
// If we test with the memory pool, initialize the memory pool.
|
||||
@ -29,6 +30,7 @@ class MemoryPoolTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
if (!HasCudaGPU()) return;
|
||||
if (UsePoolOrNot::value) {
|
||||
CHECK(CudaMemoryPool::FinalizeMemoryPool());
|
||||
}
|
||||
@ -47,6 +49,7 @@ TYPED_TEST(MemoryPoolTest, InitializeAndFinalizeWorks) {
|
||||
}
|
||||
|
||||
TYPED_TEST(MemoryPoolTest, AllocateAndDeallocate) {
|
||||
if (!HasCudaGPU()) return;
|
||||
const int nbytes = 1048576;
|
||||
for (int i = 0; i < this->device_count_; ++i) {
|
||||
LOG(INFO) << "Device " << i << " of " << this->device_count_;
|
||||
|
@ -94,7 +94,6 @@ cc_test(
|
||||
deps = [
|
||||
":core_ops",
|
||||
":core_ops_gpu",
|
||||
":core_ops_cudnn",
|
||||
"//data/mnist:mnist_minidb",
|
||||
"//gtest:caffe2_gtest_main",
|
||||
]
|
||||
|
@ -450,8 +450,11 @@ PyObject* FeedBlob(PyObject* self, PyObject* args) {
|
||||
// Here are functions that are purely GPU-based functions to be filled.
|
||||
|
||||
PyObject* NumberOfGPUs(PyObject* self, PyObject* args) {
|
||||
int num_devices;
|
||||
if (cudaGetDeviceCount(&num_devices) != cudaSuccess) {
|
||||
int num_devices = 0;
|
||||
auto err = cudaGetDeviceCount(&num_devices);
|
||||
if (err == cudaErrorNoDevice || err == cudaErrorInsufficientDriver) {
|
||||
return Py_BuildValue("i", 0);
|
||||
} else if (err != cudaSuccess) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Runtime CUDA error.");
|
||||
return NULL;
|
||||
}
|
||||
|
@ -64,5 +64,7 @@ class TestMuji(unittest.TestCase):
|
||||
if __name__ == '__main__':
|
||||
if not workspace.has_gpu_support:
|
||||
print 'No GPU support. skipping muji test.'
|
||||
elif workspace.NumberOfGPUs() == 0:
|
||||
print 'No GPU device. Skipping gpu test.'
|
||||
else:
|
||||
unittest.main()
|
||||
|
@ -36,6 +36,8 @@ class TestWorkspaceGPU(unittest.TestCase):
|
||||
|
||||
if __name__ == '__main__':
|
||||
if not workspace.has_gpu_support:
|
||||
print 'No GPU support. skipping gpu test.'
|
||||
print 'No GPU support. Skipping gpu test.'
|
||||
elif workspace.NumberOfGPUs() == 0:
|
||||
print 'No GPU device. Skipping gpu test.'
|
||||
else:
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user