Mirror of https://github.com/huggingface/accelerate.git, synced 2025-10-20 10:03:46 +08:00

Update quality tools to 2023 (#1046)

* Setup 2023 tooling for quality
* Result of styling
* Simplify inits and remove isort and flake8 from doc
* Puts back isort skip flag
.gitignore (vendored): 5 lines changed

@@ -138,4 +138,7 @@ dmypy.json
 .DS_Store
 
 # More test things
 wandb
+
+# ruff
+.ruff_cache
@@ -152,7 +152,7 @@ Follow these steps to start contributing:
 $ make test
 ```
 
-`accelerate` relies on `black` and `isort` to format its source code
+`accelerate` relies on `black` and `ruff` to format its source code
 consistently. After you make changes, apply automatic style corrections and code verifications
 that can't be automated in one go with:
 
@@ -165,7 +165,7 @@ Follow these steps to start contributing:
 $ make style
 ```
 
-`accelerate` also uses `flake8` and a few custom scripts to check for coding mistakes. Quality
+`accelerate` also uses a few custom scripts to check for coding mistakes. Quality
 control runs in CI, however you can also run the same checks with:
 
 ```bash
Makefile: 5 lines changed

@@ -13,14 +13,13 @@ extra_quality_checks:
 # this target runs checks on all files
 quality:
 	black --check $(check_dirs)
-	isort --check-only $(check_dirs)
-	flake8 $(check_dirs)
+	ruff $(check_dirs)
 	doc-builder style src/accelerate docs/source --max_len 119 --check_only
 
 # Format source code automatically and check is there are any problems left that need manual fixing
 style:
 	black $(check_dirs)
-	isort $(check_dirs)
+	ruff $(check_dirs) --fix
 	doc-builder style src/accelerate docs/source --max_len 119
 
 # Run tests for the library
@@ -16,12 +16,12 @@ import argparse
 import time
 
 import torch
 
 import transformers
-from accelerate.utils import compute_module_sizes
 from measures_util import end_measure, log_measures, start_measure
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 
+from accelerate.utils import compute_module_sizes
 
 
 DEFAULT_MODELS = {
     "gpt-j-6b": {"is_causal": True, "model": "sgugger/sharded-gpt-j-6B", "tokenizer": "EleutherAI/gpt-j-6B"},
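
The repeated import hunks below all follow from the same configuration change. The old isort setup (see the `setup.cfg` hunk near the end of this diff) only listed `numpy`, `torch`, and `torch_xla` as third-party, so packages like `datasets`, `evaluate`, `transformers`, and `accelerate` itself were sorted together into one block. Ruff's isort rules, configured in this commit with `known-first-party = ["accelerate"]` and `lines-after-imports = 2`, instead fold all third-party imports into a single alphabetized block and give `accelerate` its own trailing first-party block. A minimal sketch of the resulting layout (illustrative file, not from this diff):

```python
import argparse  # standard library imports form the first group

import torch  # all third-party packages form one alphabetized group

from accelerate import Accelerator  # first-party group, via known-first-party


# Exactly two blank lines separate imports from code (lines-after-imports = 2).
def main():
    parser = argparse.ArgumentParser()
    print(torch.__version__, Accelerator.__name__, parser.prog)
```
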
@@ -2,9 +2,8 @@ import gc
 import threading
 import time
 
-import torch
-
 import psutil
+import torch
 
 
 class PeakCPUMemory:
@@ -290,6 +290,7 @@ You will implement the `accelerate.utils.AbstractTrainStep` or inherit from thei
 ```python
 from accelerate.utils import MegatronLMDummyScheduler, GPTTrainStep, avg_losses_across_data_parallel_group
 
+
 # Custom loss function for the Megatron model
 class GPTTrainStepWithCustomLoss(GPTTrainStep):
     def __init__(self, megatron_args, **kwargs):
@@ -14,16 +14,16 @@
 import argparse
 import os
 
-import torch
-from torch.optim import AdamW
-from torch.utils.data import DataLoader
-
 # New Code #
 import evaluate
+import torch
+from datasets import load_dataset
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
+
 from accelerate import Accelerator, DistributedType
 from accelerate.utils import find_executable_batch_size
-from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
 
 ########################################################################
@@ -15,15 +15,15 @@
 import argparse
 import os
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
-
-import evaluate
-from accelerate import Accelerator, DistributedType
-from datasets import load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate,
@@ -15,20 +15,20 @@
 import argparse
 from typing import List
 
+import evaluate
 import numpy as np
 import torch
-from torch.optim import AdamW
-from torch.utils.data import DataLoader
-
-import evaluate
-from accelerate import Accelerator, DistributedType
 from datasets import DatasetDict, load_dataset
 
 # New Code #
 # We'll be using StratifiedKFold for this example
 from sklearn.model_selection import StratifiedKFold
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate,
@@ -31,16 +31,12 @@ import random
 from itertools import chain
 from pathlib import Path
 
-import torch
-from torch.utils.data import DataLoader
-
 import datasets
+import torch
 import transformers
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import DummyOptim, DummyScheduler, set_seed
 from datasets import load_dataset
 from huggingface_hub import Repository
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import (
     CONFIG_MAPPING,
@@ -55,6 +51,10 @@ from transformers import (
 from transformers.utils import get_full_repo_name
 from transformers.utils.versions import require_version
 
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
+from accelerate.utils import DummyOptim, DummyScheduler, set_seed
+
 
 logger = get_logger(__name__)
 
@@ -16,14 +16,14 @@ import argparse
 import gc
 import os
 
-import torch
-from torch.utils.data import DataLoader
-
 import evaluate
-from accelerate import Accelerator, DistributedType
+import torch
 from datasets import load_dataset
+from torch.utils.data import DataLoader
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate
@@ -15,15 +15,15 @@
 import argparse
 import os
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
-
-import evaluate
-from accelerate import Accelerator, DistributedType
-from datasets import load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate
@@ -31,16 +31,12 @@ import random
 from itertools import chain
 from pathlib import Path
 
-import torch
-from torch.utils.data import DataLoader
-
 import datasets
+import torch
 import transformers
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import MegatronLMDummyScheduler, set_seed
 from datasets import load_dataset
 from huggingface_hub import Repository
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import (
     CONFIG_MAPPING,
@@ -55,6 +51,10 @@ from transformers import (
 from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
 from transformers.utils.versions import require_version
 
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
+from accelerate.utils import MegatronLMDummyScheduler, set_seed
+
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.23.0.dev0")
@@ -14,16 +14,16 @@
 import argparse
 import os
 
-import torch
-from torch.optim import AdamW
-from torch.utils.data import DataLoader
-
 # New Code #
 import evaluate
+import torch
+from datasets import load_dataset
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
+
 from accelerate import Accelerator, DistributedType
 from accelerate.utils import find_executable_batch_size
-from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
 
 ########################################################################
@@ -15,15 +15,15 @@
 import argparse
 import os
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
-
-import evaluate
-from accelerate import Accelerator, DistributedType
-from datasets import load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate,
@@ -15,15 +15,15 @@
 import argparse
 import os
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
-
-import evaluate
-from accelerate import Accelerator, DistributedType
-from datasets import load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate,
@@ -17,15 +17,15 @@ import os
 import re
 
 import numpy as np
+import PIL
 import torch
+from timm import create_model
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import DataLoader, Dataset
-
-import PIL
-from accelerate import Accelerator
-from timm import create_model
 from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
 
+from accelerate import Accelerator
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate
@@ -15,15 +15,15 @@
 import argparse
 import os
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
-
-import evaluate
-from accelerate import Accelerator, DistributedType
-from datasets import load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate
@@ -17,15 +17,15 @@ import os
 import re
 
 import numpy as np
+import PIL
 import torch
+from timm import create_model
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import DataLoader, Dataset
-
-import PIL
-from accelerate import Accelerator
-from timm import create_model
 from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
 
+from accelerate import Accelerator
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate
@@ -14,15 +14,15 @@
 # limitations under the License.
 import argparse
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
-
-import evaluate
-from accelerate import Accelerator, DistributedType
-from datasets import load_dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
+from accelerate import Accelerator, DistributedType
+
 
 ########################################################################
 # This is a fully working simple example to use Accelerate
@@ -1,3 +1,17 @@
 [tool.black]
 line-length = 119
-target-version = ['py36']
+target-version = ['py37']
+
+[tool.ruff]
+# Never enforce `E501` (line length violations).
+ignore = ["E501", "E741", "W605"]
+select = ["E", "F", "I", "W"]
+line-length = 119
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.isort]
+lines-after-imports = 2
+known-first-party = ["accelerate"]
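
Two of these additions do the heavy lifting. `select = ["E", "F", "I", "W"]` turns on pycodestyle errors/warnings, pyflakes, and isort-style import sorting, which is what produces the many import-reordering hunks in this commit. The `per-file-ignores` table replaces the file-level `# flake8: noqa` banners removed from the `__init__.py` files below: package roots re-export names (F401, "imported but unused") and import after a `__version__` assignment (E402, "import not at top of file"). A minimal sketch of the pattern those ignores permit:

```python
# __init__.py of a package root, mirroring the accelerate hunk below.
__version__ = "0.17.0.dev0"

# Public re-export: without the per-file ignore, ruff would flag F401
# (imported but unused) and E402 (module-level import not at top of file).
from .accelerator import Accelerator
```
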
@@ -4,11 +4,6 @@ ensure_newline_before_comments = True
 force_grid_wrap = 0
 include_trailing_comma = True
 known_first_party = accelerate
-known_third_party =
-    numpy
-    torch
-    torch_xla
-
 line_length = 119
 lines_after_imports = 2
 multi_line_output = 3
setup.py: 2 lines changed

@@ -16,7 +16,7 @@ from setuptools import setup
 from setuptools import find_packages
 
 extras = {}
-extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3", "hf-doc-builder >= 0.3.0"]
+extras["quality"] = ["black ~= 23.1", "ruff >= 0.0.241", "hf-doc-builder >= 0.3.0"]
 extras["docs"] = []
 extras["test_prod"] = ["pytest", "pytest-xdist", "pytest-subtests", "parameterized"]
 extras["test_dev"] = ["datasets", "evaluate", "transformers", "scipy", "scikit-learn", "deepspeed<0.7.0", "tqdm"]
@@ -1,7 +1,3 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all.
-
 __version__ = "0.17.0.dev0"
 
 from .accelerator import Accelerator
@@ -1017,7 +1017,6 @@ class Accelerator:
         return model
 
     def _prepare_deepspeed(self, *args):
-
         deepspeed_plugin = self.state.deepspeed_plugin
 
         if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto":
@@ -1469,7 +1468,7 @@ class Accelerator:
         >>> accelerator = Accelerator(gradient_accumulation_steps=2)
         >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
 
-        >>> for (input, target) in dataloader:
+        >>> for input, target in dataloader:
         ...     optimizer.zero_grad()
         ...     output = model(input)
         ...     loss = loss_func(output, target)
@@ -1504,7 +1503,7 @@ class Accelerator:
         >>> accelerator = Accelerator(gradient_accumulation_steps=2)
         >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
 
-        >>> for (input, target) in dataloader:
+        >>> for input, target in dataloader:
         ...     optimizer.zero_grad()
         ...     output = model(input)
         ...     loss = loss_func(output, target)
|
|||||||
else:
|
else:
|
||||||
# Not at the end of the dataloader, no need to adjust the tensors
|
# Not at the end of the dataloader, no need to adjust the tensors
|
||||||
return tensor
|
return tensor
|
||||||
except:
|
except Exception:
|
||||||
# Dataset had no length or raised an error
|
# Dataset had no length or raised an error
|
||||||
return tensor
|
return tensor
|
||||||
return tensor
|
return tensor
|
||||||
@ -2349,7 +2348,7 @@ class Accelerator:
|
|||||||
>>> accelerator = Accelerator()
|
>>> accelerator = Accelerator()
|
||||||
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
|
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
|
||||||
|
|
||||||
>>> for (input, target) in accelerator.skip_first_batches(dataloader, num_batches=2):
|
>>> for input, target in accelerator.skip_first_batches(dataloader, num_batches=2):
|
||||||
... optimizer.zero_grad()
|
... optimizer.zero_grad()
|
||||||
... output = model(input)
|
... output = model(input)
|
||||||
... loss = loss_func(output, target)
|
... loss = loss_func(output, target)
|
||||||
|
@@ -169,7 +169,7 @@ def load_accelerator_state(
         if is_tpu_available():
             xm.set_rng_state(states["xm_seed"])
         logger.info("All random states loaded successfully")
-    except:
+    except Exception:
         logger.info("Could not load random states")
 
 
@@ -48,7 +48,7 @@ def _ask_field(input_text, convert_value=None, default=None, error_message=None)
         if default is not None and len(result) == 0:
             return default
         return convert_value(result) if convert_value is not None else result
-    except:
+    except Exception:
         if error_message is not None:
             print(error_message)
 
@@ -25,9 +25,9 @@ from ast import literal_eval
 from pathlib import Path
 from typing import Dict, List
 
+import psutil
 import torch
 
-import psutil
 from accelerate.commands.config import default_config_file, load_config_from_file
 from accelerate.commands.config.config_args import SageMakerConfig
 from accelerate.commands.config.config_utils import DYNAMO_BACKENDS
@@ -644,7 +644,7 @@ def multi_gpu_launcher(args):
     with patch_environment(**current_env):
         try:
             distrib_run.run(args)
-        except:
+        except Exception:
             if is_rich_available() and debug:
                 console = get_console()
                 console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
@@ -770,7 +770,7 @@ def deepspeed_launcher(args):
     with patch_environment(**current_env):
         try:
             distrib_run.run(args)
-        except:
+        except Exception:
             if is_rich_available() and debug:
                 console = get_console()
                 console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
@@ -1,5 +1 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all
-
 from .selection_menu import BulletMenu
@@ -18,9 +18,10 @@ import argparse
 import os
 import subprocess
 
-from accelerate.commands.config.config_args import default_config_file, load_config_from_file
 from packaging.version import Version, parse
 
+from accelerate.commands.config.config_args import default_config_file, load_config_from_file
+
 
 _description = "Run commands across TPU VMs for initial setup before running `accelerate launch`."
 
@@ -12,11 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all
-
-
 import warnings
 
 
@@ -25,5 +20,3 @@ warnings.warn(
     "`from accelerate import find_executable_batch_size` to avoid this warning.",
     FutureWarning,
 )
-
-from .utils.memory import find_executable_batch_size
@@ -1,7 +1,3 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all.
-
 from .testing import (
     are_the_same_tensors,
     execute_subprocess_async,
@@ -19,4 +15,4 @@ from .testing import (
 from .training import RegressionDataset, RegressionModel
 
 
-from .scripts import test_script, test_sync  # isort:skip
+from .scripts import test_script, test_sync  # isort: skip
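
The only change in this hunk is the directive's spelling: ruff's isort implementation recognizes `# isort: skip` (with a space), which exempts a single line from import sorting so it can stay below imports it must follow. A tiny runnable sketch of the directive, with arbitrary stdlib modules:

```python
import base64
import json

# The skip comment keeps this import pinned here; without it, the sorter
# would merge it into the alphabetized block above.
import zlib  # isort: skip

print(json.dumps({"crc": zlib.crc32(base64.b64encode(b"data"))}))
```
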
@@ -16,15 +16,15 @@ import argparse
 import json
 import os
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
-import evaluate
 from accelerate import Accelerator, DistributedType
 from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
-from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
 
 MAX_GPU_BATCH_SIZE = 16
@@ -15,17 +15,17 @@
 import math
 from copy import deepcopy
 
-import torch
-from torch.utils.data import DataLoader
-
 import datasets
 import evaluate
+import torch
 import transformers
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
 from accelerate import Accelerator
 from accelerate.test_utils import RegressionDataset, RegressionModel
 from accelerate.utils import is_tpu_available, set_seed
-from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 
 def get_basic_setup(accelerator, num_samples=82, batch_size=16):
@@ -84,7 +84,7 @@ def generate_predictions(model, dataloader, accelerator):
         logit, target = accelerator.gather_for_metrics((logit, target))
         logits_and_targets.append((logit, target))
     logits, targs = [], []
-    for (logit, targ) in logits_and_targets:
+    for logit, targ in logits_and_targets:
         logits.append(logit)
         targs.append(targ)
     logits, targs = torch.cat(logits), torch.cat(targs)
@@ -18,13 +18,13 @@ import json
 import os
 
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
 from accelerate import Accelerator, DistributedType
 from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
-from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
 
 MAX_GPU_BATCH_SIZE = 16
@@ -16,15 +16,15 @@ import argparse
 import json
 import os
 
+import evaluate
 import torch
+from datasets import load_dataset
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
-import evaluate
 from accelerate import Accelerator, DistributedType
 from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
-from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
 
 
 MAX_GPU_BATCH_SIZE = 16
@@ -77,7 +77,6 @@ def verify_dataloader_batch_sizes(
 
 
 def test_default_ensures_even_batch_sizes():
-
     accelerator = create_accelerator()
 
     # without padding, we would expect a different number of batches
@@ -144,7 +143,6 @@ def test_can_join_uneven_inputs():
 
 
 def test_join_raises_warning_for_non_ddp_distributed(accelerator):
-
     with warnings.catch_warnings(record=True) as w:
         with accelerator.join_uneven_inputs([Mock()]):
             pass
@@ -338,7 +338,6 @@ async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=Fals
 
 
 def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
-
     loop = asyncio.get_event_loop()
     result = loop.run_until_complete(
         _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
@@ -1,7 +1,3 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all
-
 from .constants import MODEL_NAME, OPTIMIZER_NAME, RNG_STATE_NAME, SCALER_NAME, SCHEDULER_NAME, TORCH_LAUNCH_PARAMS
 from .dataclasses import (
     ComputeEnvironment,
@@ -40,7 +40,6 @@ class HfDeepSpeedConfig:
     """
 
    def __init__(self, config_file_or_dict):
-
        if isinstance(config_file_or_dict, dict):
            # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
            # modified it, it will not be accepted here again, since `auto` values would have been overridden
@@ -20,7 +20,6 @@ from distutils.util import strtobool
 from functools import lru_cache
 
 import torch
-
 from packaging.version import parse
 
 from .environment import parse_flag_from_env
@@ -58,9 +58,8 @@ if is_megatron_lm_available():
         set_jit_fusion_options,
         write_args_to_tensorboard,
     )
-    from megatron.model import BertModel
+    from megatron.model import BertModel, Float16Module, GPTModel, ModelType, T5Model
     from megatron.model import DistributedDataParallel as LocalDDP
-    from megatron.model import Float16Module, GPTModel, ModelType, T5Model
     from megatron.model.classification import Classification
     from megatron.optimizer import get_megatron_optimizer
     from megatron.schedules import get_forward_backward_func
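
One detail in this hunk is worth calling out: the sorter merges the plain `from megatron.model import ...` statements into a single line but leaves the aliased `DistributedDataParallel as LocalDDP` import as its own statement, preserving the `LocalDDP` name used elsewhere in the module. The same rule applied to stdlib modules, as a runnable sketch:

```python
# Plain names from one module collapse into a single import statement...
from os.path import basename, dirname
# ...while an aliased import keeps its own line.
from os.path import join as path_join

print(path_join(dirname("/tmp/a/b.txt"), basename("/tmp/a/b.txt")))
```
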
@@ -101,7 +100,6 @@ def model_provider_func(pre_process=True, post_process=True, add_encoder=True, a
             post_process=post_process,
         )
     else:
-
         model = Classification(
             num_classes=args.num_labels, num_tokentypes=2, pre_process=pre_process, post_process=post_process
         )
@@ -270,7 +268,6 @@ class MegatronLMDummyDataLoader:
 
         # Data loader only on rank 0 of each model parallel group.
         if mpu.get_tensor_model_parallel_rank() == 0:
-
             # Number of train/valid/test samples.
             if args.train_samples:
                 train_samples = args.train_samples
@@ -22,7 +22,12 @@ from copy import deepcopy
 from pathlib import Path
 
 import torch
+from parameterized import parameterized
 from torch.utils.data import DataLoader
+from transformers import AutoModel, AutoModelForCausalLM, get_scheduler
+from transformers.testing_utils import mockenv_context
+from transformers.trainer_utils import set_seed
+from transformers.utils import is_torch_bf16_available
 
 import accelerate
 from accelerate.accelerator import Accelerator
@@ -47,11 +52,6 @@ from accelerate.utils.deepspeed import (
     DummyScheduler,
 )
 from accelerate.utils.other import patch_environment
-from parameterized import parameterized
-from transformers import AutoModel, AutoModelForCausalLM, get_scheduler
-from transformers.testing_utils import mockenv_context
-from transformers.trainer_utils import set_seed
-from transformers.utils import is_torch_bf16_available
 
 
 set_seed(42)
@@ -133,7 +133,6 @@ class DeepSpeedConfigIntegration(AccelerateTestCase):
 
     @parameterized.expand(stages, name_func=parameterized_custom_name_func)
     def test_deepspeed_plugin(self, stage):
-
         # Test zero3_init_flag will be set to False when ZeRO stage != 3
         deepspeed_plugin = DeepSpeedPlugin(
             gradient_accumulation_steps=1,
@@ -17,6 +17,9 @@ import inspect
 import os
 
 import torch
+from transformers import AutoModel
+from transformers.testing_utils import mockenv_context
+from transformers.trainer_utils import set_seed
 
 import accelerate
 from accelerate.accelerator import Accelerator
@@ -38,9 +41,6 @@ from accelerate.utils.constants import (
 )
 from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
 from accelerate.utils.other import patch_environment
-from transformers import AutoModel
-from transformers.testing_utils import mockenv_context
-from transformers.trainer_utils import set_seed
 
 
 set_seed(42)
@@ -18,6 +18,7 @@ from tempfile import TemporaryDirectory
 
 import torch
 import torch.nn as nn
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from accelerate.big_modeling import (
     cpu_offload,
@@ -31,7 +32,6 @@ from accelerate.big_modeling import (
 from accelerate.hooks import remove_hook_from_submodules
 from accelerate.test_utils import require_cuda, require_mps, require_multi_gpu, require_torch_min_version, slow
 from accelerate.utils import offload_state_dict
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 class ModelForTest(nn.Module):
@@ -2,6 +2,7 @@ import unittest
 from dataclasses import dataclass
 
 import pytest
+
 from accelerate.commands.config.config_args import SageMakerConfig
 from accelerate.commands.launch import _convert_nargs_to_dict
 from accelerate.utils import ComputeEnvironment