Compare commits

...

3 Commits
v0.26.0 ... fix

Author SHA1 Message Date
31f0f3a0d7 Comment 2023-08-08 13:58:26 -04:00
da9a67ad5d Test 2023-08-08 13:47:48 -04:00
5433ef9e2d Better test 2023-08-08 13:41:40 -04:00
2 changed files with 19 additions and 8 deletions

View File

@ -0,0 +1,10 @@
# Test file to ensure that in general certain situational setups for notebooks work.
import torch
assert not torch.cuda.is_initialized(), "CUDA was initialized before the test script."
from accelerate import Accelerator # noqa
assert not torch.cuda.is_initialized(), "CUDA was initialized upon importing the `Accelerator` class."

View File

@ -1,3 +1,4 @@
import inspect
import json
import os
import tempfile
@ -6,11 +7,12 @@ from unittest.mock import patch
import torch
from torch.utils.data import DataLoader, TensorDataset
import accelerate
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights
from accelerate.accelerator import Accelerator
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import require_bnb, require_multi_gpu, slow
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda
from accelerate.test_utils.testing import AccelerateTestCase, execute_subprocess_async, require_cuda
from accelerate.utils import patch_environment
@ -331,10 +333,9 @@ class AcceleratorTester(AccelerateTestCase):
_ = accelerator.prepare(sgd)
@require_cuda
def test_a_cuda_initialization_on_import(self):
# Everything else is already initialized, now we just need to check that the cuda device is *not* initialized
initialized = torch.cuda.is_initialized()
self.assertFalse(
initialized,
"CUDA has been initialized somewhere after importing. Please see the imports in `test_accelerator.py` to try and locate the problem.",
)
def test_cuda_initialization_on_import(self):
mod_file = inspect.getfile(accelerate.test_utils)
script = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["scripts", "test_notebook.py"])
cmd = ["accelerate", "launch", script]
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd, env=os.environ.copy())