Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 18:13:46 +08:00)
.github/workflows/test.yml (2 changes)

@@ -14,7 +14,7 @@ jobs:
       - name: Install Python dependencies
         run: pip install setuptools==59.5.0; pip install -e .[test,test_trackers]
       - name: Run Tests
-        run: make test_cpu
+        run: make test

   test_examples:
     runs-on: ubuntu-latest
Makefile (6 changes)

@@ -24,12 +24,8 @@ style:
 	python utils/style_doc.py src/accelerate docs/source --max_len 119

 # Run tests for the library
-test_cpu:
+test:
 	python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py

-test_cuda:
-	python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/test_scheduler.py --ignore=./tests/test_cpu.py
-	python -m pytest -s -v ./tests/test_cpu.py ./tests/test_scheduler.py
-
 test_examples:
 	python -m pytest -s -v ./tests/test_examples.py
src/accelerate/test_utils/__init__.py

@@ -2,5 +2,13 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.

-from .testing import are_the_same_tensors, execute_subprocess_async, require_cuda, require_multi_gpu, require_tpu, slow
+from .testing import (
+    are_the_same_tensors,
+    execute_subprocess_async,
+    require_cpu,
+    require_cuda,
+    require_multi_gpu,
+    require_tpu,
+    slow,
+)
 from .training import RegressionDataset, RegressionModel
src/accelerate/test_utils/testing.py

@@ -56,6 +56,13 @@ def slow(test_case):
     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


+def require_cpu(test_case):
+    """
+    Decorator marking a test that must only be run on the CPU. These tests are skipped when a GPU is available.
+    """
+    return unittest.skipUnless(not torch.cuda.is_available(), "test requires only a CPU")(test_case)
+
+
 def require_cuda(test_case):
     """
     Decorator marking a test that requires CUDA. These tests are skipped when there are no GPU available.
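For readers unfamiliar with the skipUnless pattern, here is a minimal, self-contained sketch of how the new require_cpu decorator behaves. The decorator body is copied from the hunk above; the test class and method names are illustrative only and not part of this change.

# Illustrative sketch, not part of the diff: require_cpu is re-declared here so the
# snippet runs on its own (it only needs torch and the standard library).
import unittest

import torch


def require_cpu(test_case):
    # Skip the decorated test (or whole TestCase class) whenever a CUDA device is visible.
    return unittest.skipUnless(not torch.cuda.is_available(), "test requires only a CPU")(test_case)


@require_cpu
class ExampleCPUOnlyTester(unittest.TestCase):
    def test_no_gpu_visible(self):
        # Only reached on CPU-only machines, so this assertion always holds when the test runs.
        self.assertFalse(torch.cuda.is_available())


if __name__ == "__main__":
    unittest.main()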
tests/test_cpu.py

@@ -15,9 +15,10 @@
 import unittest

 from accelerate import debug_launcher
-from accelerate.test_utils import test_script
+from accelerate.test_utils import require_cpu, test_script


+@require_cpu
 class MultiCPUTester(unittest.TestCase):
     def test_cpu(self):
         debug_launcher(test_script.main)
tests/test_scheduler.py

@@ -18,6 +18,7 @@ from functools import partial
 import torch

 from accelerate import Accelerator, debug_launcher
+from accelerate.test_utils import require_cpu


 def scheduler_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False):

@@ -46,6 +47,7 @@ def scheduler_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False):
     ), f"Wrong lr found at second step, expected {expected_lr}, got {scheduler.get_last_lr()[0]}"


+@require_cpu
 class SchedulerTester(unittest.TestCase):
     def test_scheduler_steps_with_optimizer_single_process(self):
         debug_launcher(partial(scheduler_test, num_processes=1), num_processes=1)
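Taken together, the two test files follow a single pattern: decorate the TestCase with require_cpu so it is skipped on GPU machines, then drive the actual check through debug_launcher. A hedged sketch of that pattern, with hypothetical function and class names, assuming accelerate and torch are installed as in the diffs above:

# Hypothetical example in the style of tests/test_cpu.py and tests/test_scheduler.py;
# none of the names below are part of the diff.
import unittest

from accelerate import debug_launcher
from accelerate.test_utils import require_cpu


def launched_function():
    # Runs in every process started by debug_launcher; kept trivial for the sketch.
    print("running a CPU-only check")


@require_cpu  # skipped automatically wherever a GPU is visible
class ExamplePatternTester(unittest.TestCase):
    def test_launch_on_two_processes(self):
        debug_launcher(launched_function, num_processes=2)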