mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-18 16:44:39 +08:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 31f0f3a0d7 | |||
| da9a67ad5d | |||
| 5433ef9e2d |
10
src/accelerate/test_utils/scripts/test_notebook.py
Normal file
10
src/accelerate/test_utils/scripts/test_notebook.py
Normal file
@ -0,0 +1,10 @@
|
||||
# Test file to ensure that in general certain situational setups for notebooks work.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
assert not torch.cuda.is_initialized(), "CUDA was initialized before the test script."
|
||||
|
||||
from accelerate import Accelerator # noqa
|
||||
|
||||
assert not torch.cuda.is_initialized(), "CUDA was initialized upon importing the `Accelerator` class."
|
||||
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@ -6,11 +7,12 @@ from unittest.mock import patch
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
import accelerate
|
||||
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights
|
||||
from accelerate.accelerator import Accelerator
|
||||
from accelerate.state import GradientState, PartialState
|
||||
from accelerate.test_utils import require_bnb, require_multi_gpu, slow
|
||||
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda
|
||||
from accelerate.test_utils.testing import AccelerateTestCase, execute_subprocess_async, require_cuda
|
||||
from accelerate.utils import patch_environment
|
||||
|
||||
|
||||
@ -331,10 +333,9 @@ class AcceleratorTester(AccelerateTestCase):
|
||||
_ = accelerator.prepare(sgd)
|
||||
|
||||
@require_cuda
|
||||
def test_a_cuda_initialization_on_import(self):
|
||||
# Everything else is already initialized, now we just need to check that the cuda device is *not* initialized
|
||||
initialized = torch.cuda.is_initialized()
|
||||
self.assertFalse(
|
||||
initialized,
|
||||
"CUDA has been initialized somewhere after importing. Please see the imports in `test_accelerator.py` to try and locate the problem.",
|
||||
)
|
||||
def test_cuda_initialization_on_import(self):
|
||||
mod_file = inspect.getfile(accelerate.test_utils)
|
||||
script = os.path.sep.join(mod_file.split(os.path.sep)[:-1] + ["scripts", "test_notebook.py"])
|
||||
cmd = ["accelerate", "launch", script]
|
||||
with patch_environment(omp_num_threads=1):
|
||||
execute_subprocess_async(cmd, env=os.environ.copy())
|
||||
|
||||
Reference in New Issue
Block a user