Compare commits


1 Commit

86 changed files with 348 additions and 1632 deletions


@ -82,23 +82,3 @@ jobs:
push: true
tags: huggingface/accelerate:gpu-deepspeed-release-${{needs.get-version.outputs.version}}
version-cuda-fp8-transformerengine:
name: "Latest Accelerate GPU FP8 TransformerEngine [version]"
runs-on:
group: aws-g6-4xlarge-plus
needs: get-version
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
file: docker/accelerate-gpu/Dockerfile
push: true
tags: huggingface/accelerate:gpu-fp8-transformerengine-release-${{needs.get-version.outputs.version}}


@ -86,25 +86,3 @@ jobs:
huggingface/accelerate:gpu-deepspeed-nightly
huggingface/accelerate:gpu-deepspeed-nightly-${{ env.date }}
latest-cuda-fp8-transformerengine:
name: "Latest Accelerate GPU FP8 TransformerEngine [dev]"
runs-on:
group: aws-g6-4xlarge-plus
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Get current date
id: date
run: |
echo "date=$(date '+%Y-%m-%d')" >> $GITHUB_ENV
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
file: benchmarks/fp8/Dockerfile
push: true
tags: huggingface/accelerate:gpu-fp8-transformerengine-nightly-${{ env.date }}


@ -123,15 +123,12 @@ Follow these steps to start contributing:
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
```bash
$ pip install -e ".[dev]"
$ pip install -e ".[quality]"
```
This will install all testing and linting/code quality dependencies for the library (see `quality`, `test_dev`,
`test_prod` targets in [`setup.py`](./setup.py)).
(If accelerate was already installed in the virtual environment, remove
it with `pip uninstall accelerate` before reinstalling it in editable
mode with the `-e` flag).
mode with the `-e` flag.)
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).


@ -157,8 +157,6 @@ accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py
To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).
Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/)
## Launching multi-CPU run using MPI
🤗 Here is another way to launch multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
@ -258,7 +256,7 @@ pip install accelerate
- multi-GPU on several nodes (machines)
- TPU
- FP16/BFloat16 mixed precision
- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine)
- DeepSpeed support (Experimental)
- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
- Megatron-LM support (Experimental)
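As a quick illustration of the mixed-precision options listed above, the precision is selected when constructing the `Accelerator` (a minimal sketch; the `bf16` choice and the tiny model are arbitrary placeholders):

```python
import torch
from accelerate import Accelerator

# Pick one of "no", "fp16", "bf16", or "fp8" ("fp8" additionally needs a
# backend such as TransformerEngine installed).
accelerator = Accelerator(mixed_precision="bf16")

model = accelerator.prepare(torch.nn.Linear(8, 8))
x = torch.randn(4, 8, device=accelerator.device)

with accelerator.autocast():
    # The forward pass runs under the selected mixed-precision policy.
    y = model(x)
```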


@ -15,8 +15,6 @@ To run them, it's recommended to use a docker image (see the attached `Dockerfil
## Running:
There are official Docker images located at `huggingface/accelerate:gpu-fp8-transformerengine-nightly` which can be used.
You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.
For single GPU, run it via `python`:


@ -109,8 +109,7 @@ def evaluate_model(model, dataloader, metric, accelerator=None):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
references = batch["labels"]
if accelerator is not None and accelerator.num_processes > 1:
predictions, references = accelerator.gather_for_metrics((predictions, references))
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(predictions=predictions, references=references)
return metric.compute()


@ -33,7 +33,6 @@ huggingface/accelerate:{accelerator}-{nightly,release}
* `cpu`: Comes compiled off of `python:3.9-slim` and is designed for non-CUDA based workloads.
* More to come soon
* `gpu-deepspeed`: Comes compiled off of the `nvidia/cuda` image and includes core parts like `bitsandbytes` as well as the latest `deepspeed` version. Runs off python 3.10.
* `gpu-fp8-transformerengine`: Comes compiled off of `nvcr.io/nvidia/pytorch` and is specifically for running the `benchmarks/fp8` scripts on devices which support FP8 operations using the `TransformerEngine` library (RTX 4090, H100, etc)
## Nightlies vs Releases


@ -219,9 +219,3 @@ During training, you may want to save the current state of the model, optimizer,
To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.
Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function.
<Note>
If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with a `load_state_dict` and `state_dict` function, and makes it so `Accelerator.save_state` and `Accelerator.load_state` also track how far into the training dataset it has read when persisting the model.
</Note>
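To make the flow above concrete, here is a minimal sketch of registering a custom stateful object and saving/restoring it alongside the usual training state (the `StepCounter` class and the `runs/demo` directory are made up for the example):

```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# With automatic_checkpoint_naming enabled, each checkpoint is written to
# {project_dir}/checkpoints/checkpoint_{n}.
project_config = ProjectConfiguration(project_dir="runs/demo", automatic_checkpoint_naming=True)
accelerator = Accelerator(project_config=project_config)


class StepCounter:
    """Any object registered for checkpointing must expose state_dict/load_state_dict."""

    def __init__(self):
        self.step = 0

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = state["step"]


counter = StepCounter()
accelerator.register_for_checkpointing(counter)

accelerator.save_state()  # saves `counter` next to the model/optimizer/RNG states
accelerator.load_state("runs/demo/checkpoints/checkpoint_0")  # restores it again
```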


@ -69,10 +69,4 @@ setting the same seed in the main random number generator in all processes.
</Tip>
<Note>
If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`.
</Note>
For more details about the internals, see the [Internals page](package_reference/torch_wrappers).
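For context, the option the note above refers to is enabled through `DataLoaderConfiguration`, mirroring the snippets elsewhere in this comparison (a minimal sketch that assumes `torchdata>=0.8.0` is installed):

```python
from torch.utils.data import DataLoader
from accelerate import Accelerator, DataLoaderConfiguration

# Back the prepared dataloaders with torchdata's StatefulDataLoader so they
# expose state_dict()/load_state_dict().
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

dataloader = accelerator.prepare(DataLoader(list(range(64)), batch_size=8, shuffle=True))
# accelerator.save_state()/load_state() will now also record how far into
# the dataset this dataloader has iterated.
```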


@ -50,7 +50,7 @@ The `TransformerEngine` can receive many different arguments that customize how
* `margin`: The margin to use for the gradient scaling.
* `interval`: The interval to use for how often the scaling factor is recomputed.
* `fp8_format`: The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training, `E4M3` for evaluation)
* `fp8_format`: The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
* `amax_history_len`: The length of the history to use for the scaling factor computation
* `amax_compute_algo`: The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
* `override_linear_precision`: Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
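For reference, these recipe arguments are typically passed to the `Accelerator` through a kwargs handler (a minimal sketch; the exact argument names may vary slightly between Accelerate versions):

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Mirror the TransformerEngine recipe options described above.
fp8_kwargs = FP8RecipeKwargs(
    backend="TE",
    fp8_format="HYBRID",        # generally HYBRID for training, E4M3 for evaluation
    margin=0,
    interval=1,
    amax_history_len=1024,
    amax_compute_algo="max",
)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_kwargs])
```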


@ -53,8 +53,6 @@ accelerate launch path_to_script.py --args_for_the_script
To learn more, check out the [Launch distributed code](basic_tutorials/launch) tutorial for more information about launching your scripts.
We also have a [configuration zoo](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates) which showcases a number of premade **minimal** example configurations for a variety of setups you can run.
## Adapt training code
The next main feature of Accelerate is the [`Accelerator`] class which adapts your PyTorch code to run on different distributed setups.
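A minimal sketch of what that adaptation looks like in practice (the tiny model and random dataset below are stand-ins for your own objects):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Stand-in model, optimizer, and data purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)

# prepare() moves everything onto the right device(s) and wires up whatever
# distributed setup was chosen at `accelerate launch` time.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

loss_fn = torch.nn.CrossEntropyLoss()
for inputs, targets in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
```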


@ -56,7 +56,7 @@ fp8_config:
amax_compute_algorithm: max
amax_history_length: 1024
backend: TE
fp8_format: HYBRID
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: false
@ -117,7 +117,7 @@ fp8_config:
amax_compute_algorithm: max
amax_history_length: 1024
backend: TE
fp8_format: HYBRID
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: false


@ -208,13 +208,23 @@ To run it in each of these various modes, use the following commands:
- [huggan project](https://github.com/huggingface/community-events/tree/main/huggan)
### Using AWS SageMaker integration
- [Examples showcasing AWS SageMaker integration of 🤗 Accelerate.](https://github.com/pacman100/accelerate-aws-sagemaker)
## Configuration zoo
In [/config_yaml_templates](./config_yaml_templates/) we have a variety of *minimal* `config.yaml` templates and examples to help you learn
how to create your own configuration files depending on the scenario.
## Simple Multi-GPU Hardware Launcher
[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate
on multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can
easily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then
run the script to automatically launch multi GPU training on remote hardware.
This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
with `pip install runhouse`, and you can refer to
[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
for hardware setup instructions, or this
[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.
## SLURM Scripts
In [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) and [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we present two scripts for running the examples on a machine with [SLURM](https://slurm.schedmd.com/documentation.html) workload manager.
@ -241,20 +251,6 @@ export PYTHONPATH=/home/nct01/nct01328/transformers-in-supercomputers:$PYTHONPAT
export GPUS_PER_NODE=4
```
## Simple Multi-GPU Hardware Launcher (using an external platform)
[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate
on multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can
easily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then
run the script to automatically launch multi GPU training on remote hardware.
This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
with `pip install runhouse`, and you can refer to
[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
for hardware setup instructions, or this
[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.
## Finer Examples
While the first two scripts are extremely barebones when it comes to what you can do with accelerate, more advanced features are documented in two other locations.


@ -217,7 +217,6 @@ def training_function(config, args):
# And call it at the end with no arguments
# Note: You could also refactor this outside of your training loop function
inner_training_loop()
accelerator.end_training()
def main():


@ -19,10 +19,9 @@ import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate.utils import set_seed
from accelerate import Accelerator, DistributedType
########################################################################
@ -126,8 +125,7 @@ def training_function(config, args):
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
config["num_epochs"] = 2
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
@ -219,11 +217,8 @@ def training_function(config, args):
model.train()
# New Code #
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@ -253,6 +248,7 @@ def training_function(config, args):
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
model.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
@ -265,6 +261,7 @@ def training_function(config, args):
predictions=predictions,
references=references,
)
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
@ -279,7 +276,6 @@ def training_function(config, args):
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
accelerator.end_training()
def main():
@ -312,11 +308,6 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
training_function(config, args)


@ -255,7 +255,6 @@ def training_function(config, args):
preds = torch.stack(test_predictions, dim=0).sum(dim=0).div(int(args.num_folds)).argmax(dim=-1)
test_metric = metric.compute(predictions=preds, references=test_references)
accelerator.print("Average test metrics from all folds:", test_metric)
accelerator.end_training()
def main():


@ -192,7 +192,6 @@ def training_function(config, args):
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -716,7 +716,6 @@ def main():
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump({"perplexity": perplexity, "eval_loss": eval_loss.item()}, f)
accelerator.end_training()
if __name__ == "__main__":


@ -222,7 +222,6 @@ def training_function(config, args):
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -399,7 +399,8 @@ def training_function(config, args):
step=epoch,
)
accelerator.end_training()
if args.with_tracking:
accelerator.end_training()
def main():


@ -197,7 +197,6 @@ def training_function(config, args):
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -202,7 +202,6 @@ def training_function(config, args):
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -703,7 +703,6 @@ def main():
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump({"perplexity": perplexity}, f)
accelerator.end_training()
if __name__ == "__main__":


@ -210,7 +210,6 @@ def training_function(config, args):
# And call it at the end with no arguments
# Note: You could also refactor this outside of your training loop function
inner_training_loop()
accelerator.end_training()
def main():


@ -214,7 +214,6 @@ def training_function(config, args):
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -203,7 +203,6 @@ def training_function(config, args):
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -202,7 +202,6 @@ def training_function(config, args):
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -236,7 +236,11 @@ def training_function(config, args):
step=epoch,
)
accelerator.end_training()
# New Code #
# When a run is finished, you should call `accelerator.end_training()`
# to close all of the open trackers
if args.with_tracking:
accelerator.end_training()
def main():


@ -23,7 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate import Accelerator
########################################################################
@ -72,19 +72,12 @@ class PetsDataset(Dataset):
def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu,
mixed_precision=args.mixed_precision,
log_with="all",
project_dir=args.project_dir,
dataloader_config=dataloader_config,
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
)
else:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
@ -269,7 +262,8 @@ def training_function(config, args):
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
accelerator.end_training()
if args.with_tracking:
accelerator.end_training()
def main():
@ -304,11 +298,6 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",


@ -21,7 +21,7 @@ from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate import Accelerator, DistributedType
########################################################################
@ -49,19 +49,12 @@ EVAL_BATCH_SIZE = 32
def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu,
mixed_precision=args.mixed_precision,
dataloader_config=dataloader_config,
log_with="all",
project_dir=args.project_dir,
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
)
else:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
if hasattr(args.checkpointing_steps, "isdigit"):
if args.checkpointing_steps == "epoch":
@ -201,10 +194,7 @@ def training_function(config, args):
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@ -266,7 +256,8 @@ def training_function(config, args):
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
accelerator.end_training()
if args.with_tracking:
accelerator.end_training()
def main():
@ -293,11 +284,6 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",


@ -1,10 +0,0 @@
# Config Zoo
This folder contains a variety of minimal configurations for `Accelerate` achieving certain goals. You can use these
config YAMLs directly, or build off of them for your own YAMLs.
These are highly annotated versions, aiming to teach you what each section does.
Each config can be run via `accelerate launch --config_file {file} run_me.py`
`run_me.py` will then print out how the current environment is set up (the contents of the `AcceleratorState`)


@ -1,15 +0,0 @@
# Similar to FSDP, we set the distributed type as DEEPSPEED
distributed_type: DEEPSPEED
# With DeepSpeed, we utilize a deepspeed config file for the entire configuration
deepspeed_config:
# Can also be any of the config JSONs in accelerate/examples/deepspeed_config_templates
deepspeed_config_file: ../deepspeed_config_templates/zero_stage1_config.json
# If using ZeRO-3 and wanting to load big models in, this should be set to `true` so
# `transformers` uses the right `init` function
zero3_init_flag: false # true
# Finally we need to specify the number of GPUs to use
num_processes: 2
# Optionally we can set the mixed precision now instead of in the deepspeed config file,
# however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file
# mixed_precision: "bf16"


@ -1,18 +0,0 @@
# This config template simply sets up the TransformerEngine config (and a config for a single GPU),
# this can interop with the other configs in this folder
distributed_type: "NO"
mixed_precision: "fp8"
# Then we specify the fp8 configuration:
fp8_config:
backend: TE # Can be TE | MS-AMP
# The following are TE specific arguments.
# See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details
amax_history_length: 1024
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: false
# Generally this should always be set to `false` to have the most realistic fp8 eval performance
use_autocast_during_eval: false
# If using MS-AMP, we ignore all of the prior settings and set an opt_level
#opt_level: O1


@ -1,18 +0,0 @@
# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
distributed_type: FSDP
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)
mixed_precision: 'bf16'
# Specify the number of GPUs to use
num_processes: 2
# Then we can specify the FSDP config
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true


@ -1,6 +0,0 @@
# Specify distributed_type as `MULTI_GPU` for DDP
distributed_type: "MULTI_GPU"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
# Specify the number of GPUs to use
num_processes: 2


@ -1,16 +0,0 @@
# This config template is for a multi-node setup. This assumes DDP, but can be interop'd with the other configs in this folder
# Generally it's recommended to look at the SLURM config template for a more robust multi-node setup
distributed_type: MULTI_GPU
# We need to specify the current machine's rank
machine_rank: 0
# We then need to specify the IP address and port of the main process
main_process_ip: '1234'
main_process_port: 9999
# We need to specify the number of machines
num_machines: 2
# We need to specify the *total* number of processes
num_processes: 8
# And then we need to specify how rdzv comms will be handled
rdzv_backend: static # or c10d
# Whether the compute nodes are on the same network (for cloud setups this will more than likely be false)
same_network: false


@ -1,26 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A base script which outputs the accelerate config for the given environment
"""
from accelerate import Accelerator
accelerator = Accelerator()
accelerator.print(f"Accelerator state from the current environment:\n{accelerator.state}")
if accelerator.fp8_recipe_handler is not None:
accelerator.print(f"FP8 config:\n{accelerator.fp8_recipe_handler}")
accelerator.end_training()


@ -1,4 +0,0 @@
# Since this is single GPU, we don't need distributed training
distributed_type: "NO"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"


@ -180,7 +180,6 @@ def training_function(config, args):
eval_metric = accurate.item() / num_elems
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
accelerator.end_training()
def main():


@ -32,7 +32,7 @@ model.eval()
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(1, 512), # bs x seq_len
size=(2, 512), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
@ -49,16 +49,6 @@ model = prepare_pippy(model, split_points="auto", example_args=(input,))
# available on all GPUs
# model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
# Create new inputs of the expected size (n_processes)
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(2, 512), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
)
# Move the inputs to the first device
input = input.to("cuda:0")
@ -86,4 +76,3 @@ if PartialState().is_last_process:
output = torch.stack(tuple(output[0]))
print(f"Time of first pass: {first_batch}")
print(f"Average time per batch: {(end_time - start_time) / 5}")
PartialState().destroy_process_group()


@ -32,7 +32,7 @@ model.eval()
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(1, 1024), # bs x seq_len
size=(2, 1024), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
@ -48,16 +48,6 @@ model = prepare_pippy(model, split_points="auto", example_args=(input,))
# available on all GPUs
# model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
# Create new inputs of the expected size (n_processes)
input = torch.randint(
low=0,
high=model.config.vocab_size,
size=(2, 1024), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
)
# Move the inputs to the first device
input = input.to("cuda:0")
@ -85,4 +75,3 @@ if PartialState().is_last_process:
output = torch.stack(tuple(output[0]))
print(f"Time of first pass: {first_batch}")
print(f"Average time per batch: {(end_time - start_time) / 5}")
PartialState().destroy_process_group()


@ -27,7 +27,7 @@ model.eval()
# Input configs
# Create example inputs for the model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompts = ("I would like to", "I really like to") # bs = 2, sending 2 per process
prompts = ("I would like to", "I really like to", "The weather is pretty") # bs = 3
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
@ -43,8 +43,6 @@ model = prepare_pippy(model, split_points="auto", example_kwargs=inputs)
# currently we don't support `model.generate`
# output = model.generate(**inputs, max_new_tokens=1)
prompts = ("I would like to", "I really like to", "The weather is pretty") # bs = 3
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
inputs = inputs.to(0)
with torch.no_grad():
output = model(**inputs)
@ -54,4 +52,3 @@ if PartialState().is_last_process:
next_token_logits = output[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
print(tokenizer.batch_decode(next_token))
PartialState().destroy_process_group()


@ -14,21 +14,12 @@
import time
import torch
from packaging import version
from transformers import AutoModelForSeq2SeqLM
from accelerate import PartialState, prepare_pippy
from accelerate import __version__ as accelerate_version
from accelerate.utils import set_seed
if version.parse(accelerate_version) > version.parse("0.33.0"):
raise RuntimeError(
"Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. "
"Please use a lower accelerate version and `torchpippy`, which this example uses."
)
# Set the random seed to have reproducible outputs
set_seed(42)
@ -96,4 +87,3 @@ if PartialState().is_last_process:
output = torch.stack(tuple(output[0]))
print(f"Time of first pass: {first_batch}")
print(f"Average time per batch: {(end_time - start_time) / 5}")
PartialState().destroy_process_group()


@ -185,7 +185,6 @@ def training_function(config, args):
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
accelerator.end_training()
def main():


@ -1,12 +0,0 @@
distributed_type: FSDP
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true


@ -1,43 +0,0 @@
#!/bin/bash
#SBATCH --job-name=multinode
#SBATCH -D .
#SBATCH --output=O-%x.%j
#SBATCH --error=E-%x.%j
#SBATCH --nodes=4 # number of nodes
#SBATCH --ntasks-per-node=1 # number of MP tasks
#SBATCH --gres=gpu:4 # number of GPUs per node
#SBATCH --cpus-per-task=160 # number of cores per tasks
#SBATCH --time=01:59:00 # maximum execution time (HH:MM:SS)
######################
### Set environment ###
######################
source activateEnvironment.sh
export GPUS_PER_NODE=4
######################
######################
#### Set network #####
######################
head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
######################
export ACCELERATE_DIR="${ACCELERATE_DIR:-/accelerate}"
export LAUNCHER="accelerate launch \
--config ${ACCELERATE_DIR}/examples/slurm/fsdp_config.yaml \
--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \
--num_machines $SLURM_NNODES \
--rdzv_backend c10d \
--main_process_ip $head_node_ip \
--main_process_port 29500 \
"
export SCRIPT="${ACCELERATE_DIR}/examples/complete_nlp_example.py"
export SCRIPT_ARGS=" \
--mixed_precision fp16 \
--output_dir ${ACCELERATE_DIR}/examples/output \
"
# This step is necessary because accelerate launch does not handle multiline arguments properly
export CMD="$LAUNCHER $SCRIPT $SCRIPT_ARGS"
srun $CMD


@ -27,7 +27,6 @@ extras["test_dev"] = [
"datasets",
"diffusers",
"evaluate",
"torchdata>=0.8.0",
"torchpippy>=0.2.0",
"transformers",
"scipy",
@ -49,7 +48,7 @@ extras["sagemaker"] = [
setup(
name="accelerate",
version="0.34.2",
version="0.34.0.dev0",
description="Accelerate",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
@ -71,7 +70,7 @@ setup(
},
python_requires=">=3.8.0",
install_requires=[
"numpy>=1.17,<3.0.0",
"numpy>=1.17,<2.0.0",
"packaging>=20.0",
"psutil",
"pyyaml",


@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.34.2"
__version__ = "0.34.0.dev0"
from .accelerator import Accelerator
from .big_modeling import (


@ -166,8 +166,8 @@ class Accelerator:
Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,
etc...).
mixed_precision (`str`, *optional*):
Whether or not to use mixed precision training. Choose from 'no','fp16','bf16' or 'fp8'. Will default to
the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the
Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. Will default to the
value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the
accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp8'
requires the installation of transformers-engine.
gradient_accumulation_steps (`int`, *optional*, default to 1):
@ -310,7 +310,7 @@ class Accelerator:
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
fsdp_plugin, FullyShardedDataParallelPlugin
):
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
if is_torch_version("<", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
if fsdp_plugin is None: # init from env variables
@ -583,12 +583,6 @@ class Accelerator:
def non_blocking(self):
return self.dataloader_config.non_blocking
@property
def use_stateful_dataloader(self):
if hasattr(self.dataloader_config, "use_stateful_dataloader"):
return self.dataloader_config.use_stateful_dataloader
return False
@property
def project_dir(self):
return self.project_configuration.project_dir
@ -2074,7 +2068,6 @@ class Accelerator:
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
@ -2734,7 +2727,9 @@ class Accelerator:
for tracker in self.trackers:
tracker.finish()
self.state.destroy_process_group()
if torch.distributed.is_initialized():
# needed when using torch.distributed.init_process_group
torch.distributed.destroy_process_group()
def save(self, obj, f, safe_serialization=False):
"""


@ -127,11 +127,6 @@ def save_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
state_dict = dataloader.state_dict()
torch.save(state_dict, output_dataloader_state_dict_file)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
# GradScaler state
@ -246,12 +241,6 @@ def load_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
sampler = dataloader.set_sampler(torch.load(input_sampler_file))
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
if input_dataloader_state_dict_file.exists():
state_dict = torch.load(input_dataloader_state_dict_file)
dataloader.load_state_dict(state_dict)
logger.info("All dataloader sampler states loaded successfully")
# GradScaler state


@ -735,8 +735,8 @@ def get_cluster_input():
)
fp8_config["fp8_format"] = _ask_options(
"Which weight format should be used?",
["HYBRID", "E4M3"],
lambda x: "HYBRID" if x == 0 else "E4M3",
["E4M3", "HYBRID"],
lambda x: "E4M3" if x == 0 else "HYBRID",
default=0,
)
fp8_config["amax_history_length"] = _ask_field(


@ -99,17 +99,13 @@ class BaseConfig:
result = {k: v for k, v in result.items() if v is not None}
return result
@staticmethod
def process_config(config_dict):
"""
Processes `config_dict` and sets default values for any missing keys
"""
@classmethod
def from_json_file(cls, json_file=None):
json_file = default_json_config_file if json_file is None else json_file
with open(json_file, encoding="utf-8") as f:
config_dict = json.load(f)
if "compute_environment" not in config_dict:
config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
if "distributed_type" not in config_dict:
raise ValueError("A `distributed_type` must be specified in the config file.")
if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
config_dict["num_processes"] = 1
if "mixed_precision" not in config_dict:
config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
if "fp16" in config_dict: # Convert the config to the new format.
@ -123,14 +119,6 @@ class BaseConfig:
config_dict["debug"] = False
if "enable_cpu_affinity" not in config_dict:
config_dict["enable_cpu_affinity"] = False
return config_dict
@classmethod
def from_json_file(cls, json_file=None):
json_file = default_json_config_file if json_file is None else json_file
with open(json_file, encoding="utf-8") as f:
config_dict = json.load(f)
config_dict = cls.process_config(config_dict)
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
if len(extra_keys) > 0:
raise ValueError(
@ -150,7 +138,23 @@ class BaseConfig:
yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
with open(yaml_file, encoding="utf-8") as f:
config_dict = yaml.safe_load(f)
config_dict = cls.process_config(config_dict)
if "compute_environment" not in config_dict:
config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
if "mixed_precision" not in config_dict:
config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
if isinstance(config_dict["mixed_precision"], bool) and not config_dict["mixed_precision"]:
config_dict["mixed_precision"] = "no"
if "fp16" in config_dict: # Convert the config to the new format.
del config_dict["fp16"]
if "dynamo_backend" in config_dict: # Convert the config to the new format.
dynamo_backend = config_dict.pop("dynamo_backend")
config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
if "use_cpu" not in config_dict:
config_dict["use_cpu"] = False
if "debug" not in config_dict:
config_dict["debug"] = False
if "enable_cpu_affinity" not in config_dict:
config_dict["enable_cpu_affinity"] = False
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
if len(extra_keys) > 0:
raise ValueError(
@ -177,7 +181,7 @@ class BaseConfig:
@dataclass
class ClusterConfig(BaseConfig):
num_processes: int = -1 # For instance if we use SLURM and the user manually passes it in
num_processes: int
machine_rank: int = 0
num_machines: int = 1
gpu_ids: Optional[str] = None


@ -1074,8 +1074,6 @@ def _validate_launch_command(args):
# Silently set the default here
if args.dynamo_backend is None:
args.dynamo_backend = "no"
if args.num_processes == -1:
raise ValueError("You need to manually pass in `--num_processes` using this config yaml.")
else:
if args.num_processes is None:
if args.use_xpu and is_xpu_available():


@ -20,7 +20,7 @@ import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
from .logging import get_logger
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
from .state import AcceleratorState, DistributedType, GradientState, PartialState, is_torch_xla_available
from .utils import (
RNGType,
broadcast,
@ -30,7 +30,6 @@ from .utils import (
get_data_structure,
initialize_tensors,
is_torch_version,
is_torchdata_stateful_dataloader_available,
send_to_device,
slice_tensors,
synchronize_rng_states,
@ -365,13 +364,6 @@ class DataLoaderStateMixin:
- **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
batch size
<Tip warning={true}>
Inheriters of this class should ensure that the class creates a `GradientState()` instance, stored in
`self.gradient_state`.
</Tip>
"""
def __init_subclass__(cls, **kwargs):
@ -396,94 +388,9 @@ class DataLoaderStateMixin:
self.gradient_state._remove_dataloader(self)
class DataLoaderAdapter:
class DataLoaderShard(DataLoader, DataLoaderStateMixin):
"""
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
"""
def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
self.use_stateful_dataloader = use_stateful_dataloader
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader
if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
raise ImportError(
"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
)
if use_stateful_dataloader:
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()
def __getattr__(self, name):
# Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
if name == "base_dataloader":
raise AttributeError()
# Delegate attribute access to the internal dataloader
return getattr(self.base_dataloader, name)
def state_dict(self):
return self.dl_state_dict
def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)
@property
def __class__(self):
"""
In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
object.
"""
return self.base_dataloader.__class__
def __len__(self):
return len(self.base_dataloader)
def adjust_state_dict_for_prefetch(self):
"""
Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
`self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
overridden.
This should modify `self.dl_state_dict` directly
"""
# The state dict will be off by a factor of `n-1` batch too many during DDP,
# so we need to adjust it here
if PartialState().distributed_type != DistributedType.NO:
factor = PartialState().num_processes - 1
if self.dl_state_dict["_sampler_iter_yielded"] > 0:
self.dl_state_dict["_sampler_iter_yielded"] -= factor
if self.dl_state_dict["_num_yielded"] > 0:
self.dl_state_dict["_num_yielded"] -= factor
if self.dl_state_dict["_index_sampler_state"] is not None:
if (
"samples_yielded" in self.dl_state_dict["_index_sampler_state"]
and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
):
self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
def _update_state_dict(self):
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
# E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
# what it wants to yield.
#
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()
# Potentially modify the state_dict to adjust for prefetching
self.adjust_state_dict_for_prefetch()
# Then tag if we are at the end of the dataloader
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
Args:
dataset (`torch.utils.data.dataset.Dataset`):
@ -502,8 +409,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
A random number generator to keep synchronized across processes.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**kwargs (additional keyword arguments, *optional*):
All other keyword arguments to pass to the regular `DataLoader` initialization.
@ -523,12 +428,11 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
rng_types=None,
synchronized_generator=None,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
**kwargs,
):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
super().__init__(dataset, **kwargs)
self.device = device
self.rng_types = rng_types
self.synchronized_generator = synchronized_generator
@ -544,7 +448,7 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
self.begin()
self.set_epoch(self.iteration)
dataloader_iter = self.base_dataloader.__iter__()
dataloader_iter = super().__iter__()
# We iterate one batch ahead to check when we are at the end
try:
current_batch = next(dataloader_iter)
@ -557,7 +461,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
# But we still move it to the device so it is done before `StopIteration` is reached
if self.device is not None:
current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
self._update_state_dict()
next_batch = next(dataloader_iter)
if batch_index >= self.skip_batches:
yield current_batch
@ -565,7 +468,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
current_batch = next_batch
except StopIteration:
self.end_of_dataloader = True
self._update_state_dict()
if batch_index >= self.skip_batches:
yield current_batch
break
@ -573,15 +475,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
self.iteration += 1
self.end()
def __reduce__(self):
"""
Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
`__class__` member.
"""
args = super().__reduce__()
return (DataLoaderShard, *args[1:])
def set_epoch(self, epoch: int):
# In case it is manually passed in, the user can set it to what they like
if self.iteration != epoch:
@ -654,10 +547,6 @@ if is_torch_xla_available():
return super().__iter__()
def set_epoch(self, epoch: int):
if hasattr(self.dataloader, "set_epoch"):
self.dataloader.set_epoch(epoch)
@property
def total_batch_size(self):
return self._loader.total_batch_size
@ -675,10 +564,10 @@ if is_torch_xla_available():
return self._loader
class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
"""
Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
their part of the batch.
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
process their part of the batch.
Args:
split_batches (`bool`, *optional*, defaults to `False`):
@ -690,8 +579,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
size of the `dataloader` is a round multiple of `batch_size`.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning of an iteration.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**Available attributes:**
@ -707,7 +594,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
dataset,
split_batches: bool = False,
skip_batches=0,
use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
@ -720,13 +606,13 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
super().__init__(dataset, **kwargs)
self.split_batches = split_batches
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
self.gradient_state = GradientState()
self.state = PartialState()
self.state = AcceleratorState()
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
@ -741,14 +627,12 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
self._update_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
@ -789,9 +673,9 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
# NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
# shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
# But, we only iterate through the DataLoader on process 0.
main_iterator = self.base_dataloader.__iter__()
main_iterator = super().__iter__()
elif self.state.process_index == 0:
main_iterator = self.base_dataloader.__iter__()
main_iterator = super().__iter__()
stop_iteration = False
self._stop_iteration = False
first_batch = None
@ -849,7 +733,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
if stop_iteration:
self.end_of_dataloader = True
self._update_state_dict()
self.remainder = observed_batch_size
if batch_index >= self.skip_batches:
yield batch
@ -861,13 +744,13 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
# In case it is manually passed in, the user can set it to what they like
if self.iteration != epoch:
self.iteration = epoch
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
if hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(epoch)
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
def __len__(self):
whole_length = len(self.base_dataloader)
whole_length = super().__len__()
if self.split_batches:
return whole_length
elif self._drop_last:
@ -875,15 +758,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
else:
return math.ceil(whole_length / self.state.num_processes)
def __reduce__(self):
"""
Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
`__class__` member.
"""
args = super().__reduce__()
return (DataLoaderDispatcher, *args[1:])
@property
def total_batch_size(self):
return (
@ -938,7 +812,6 @@ def prepare_data_loader(
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@ -952,9 +825,10 @@ def prepare_data_loader(
device (`torch.device`):
The target device for the returned `DataLoader`.
num_processes (`int`, *optional*):
The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
The number of processes running concurrently. Will default to the value given by
[`~state.AcceleratorState`].
process_index (`int`, *optional*):
The index of the current process. Will default to the value given by [`~state.PartialState`].
The index of the current process. Will default to the value given by [`~state.AcceleratorState`].
split_batches (`bool`, *optional*, defaults to `False`):
Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
@ -999,10 +873,6 @@ def prepare_data_loader(
non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
If set to true, the dataloader prepared by the Accelerator will be backed by
[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
Returns:
@ -1023,8 +893,8 @@ def prepare_data_loader(
if dispatch_batches and not put_on_device:
raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
# Grab defaults from PartialState
state = PartialState()
# Grab defaults from AcceleratorState
state = AcceleratorState()
if num_processes is None:
num_processes = state.num_processes
if process_index is None:
@ -1136,7 +1006,6 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
elif sampler_is_batch_sampler:
@ -1149,7 +1018,6 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
synchronized_generator=synchronized_generator,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
else:
@ -1161,7 +1029,6 @@ def prepare_data_loader(
synchronized_generator=synchronized_generator,
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
@ -1175,7 +1042,6 @@ def prepare_data_loader(
class SkipBatchSampler(BatchSampler):
"""
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
Should not be used if the original dataloader is a `StatefulDataLoader`.
"""
def __init__(self, batch_sampler, skip_batches=0):
@ -1195,10 +1061,9 @@ class SkipBatchSampler(BatchSampler):
return len(self.batch_sampler) - self.skip_batches
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
class SkipDataLoader(DataLoader):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
`skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
Subclass of a PyTorch `DataLoader` that will skip the first batches.
Args:
dataset (`torch.utils.data.dataset.Dataset`):
@ -1209,36 +1074,19 @@ class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
All other keyword arguments to pass to the regular `DataLoader` initialization.
"""
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
def __init__(self, dataset, skip_batches=0, **kwargs):
super().__init__(dataset, **kwargs)
self.skip_batches = skip_batches
self.gradient_state = GradientState()
def __iter__(self):
self.begin()
for index, batch in enumerate(self.base_dataloader.__iter__()):
for index, batch in enumerate(super().__iter__()):
if index >= self.skip_batches:
self._update_state_dict()
yield batch
self.end()
def __len__(self):
return len(self.base_dataloader) - self.skip_batches
def __reduce__(self):
"""
Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
`__class__` member.
"""
args = super().__reduce__()
return (SkipDataLoader, *args[1:])
def skip_first_batches(dataloader, num_batches=0):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
the original dataloader is a `StatefulDataLoader`.
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
"""
state = PartialState()
if state.distributed_type == DistributedType.XLA:
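For context, a minimal sketch (not part of this diff) of how `skip_first_batches` from this file is typically used to resume a partially consumed epoch, assuming a standard single-process Accelerate setup:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.data_loader import skip_first_batches

accelerator = Accelerator()
dataset = TensorDataset(torch.arange(16).float())
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4))

# Suppose two batches were already consumed before a checkpoint was taken.
resumed = skip_first_batches(dataloader, num_batches=2)
for batch in resumed:
    pass  # the loop resumes from the third batch of the epoch
```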

View File

@ -79,21 +79,22 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
`AcceleratorState.num_processes`
"""
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage
# We need to annotate the split points in the model for PiPPy
state = PartialState()
split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
pipe = pipeline(
model,
mb_args=args,
mb_kwargs=kwargs,
split_spec=split_spec,
)
stage = pipe.build_stage(state.local_process_index, device=state.device)
schedule = ScheduleGPipe(stage, num_chunks)
annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
found_batch_size = find_pippy_batch_size(args, kwargs)
if found_batch_size != num_chunks:
if args is not None:
args = pad_input_tensors(args, found_batch_size, num_chunks)
if kwargs is not None:
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
stage = PipelineStage(pipe, state.local_process_index, device=state.device)
return schedule
return stage
def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
@ -142,12 +143,11 @@ def prepare_pippy(
no_split_module_classes (`List[str]`):
A list of class names for layers we don't want to be split.
example_args (tuple of model inputs):
The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
this method if possible.
The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
example_kwargs (dict of model inputs):
The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
*highly* limiting structure that requires the same keys be present at *all* inference calls. Not
recommended unless the prior condition is true for all cases.
The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
is true for all cases.
num_chunks (`int`, defaults to the number of available GPUs):
The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
this can be tuned and played with. In general one should have num_chunks >= num_gpus.
@ -155,7 +155,10 @@ def prepare_pippy(
If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
"""
if not is_pippy_available():
raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
raise ImportError(
"`pippy` was not found to be installed on your system. Please "
"install using `pip install torchpippy` or ensure you have at least version 0.2.0"
)
state = PartialState()
example_args = send_to_device(example_args, "cpu")
example_kwargs = send_to_device(example_kwargs, "cpu")
@ -174,7 +177,7 @@ def prepare_pippy(
model.hf_split_points = split_points
def forward(*args, **kwargs):
return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
# Note: creates an infinite recursion loop with `generate`
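As a usage sketch of `prepare_pippy` (condensed from the test script further down in this diff, not a verbatim excerpt): it assumes at least two GPUs, a pipeline-capable install matching whichever side of this diff is checked out (`torchpippy` or PyTorch >= 2.4), and a launch such as `accelerate launch --num_processes 2 script.py`:

```python
import torch
from transformers import BertConfig, BertForMaskedLM

from accelerate import PartialState
from accelerate.inference import prepare_pippy

state = PartialState()
config = BertConfig()
model = BertForMaskedLM(config)

# Trace with a single example, then run a batch sized to the number of processes.
trace_input = torch.randint(0, config.vocab_size, (1, 512), dtype=torch.int64)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)

inputs = torch.randint(0, config.vocab_size, (state.num_processes, 512), dtype=torch.int64).to(state.device)
with torch.no_grad():
    output = model(inputs)  # only the last process receives a non-None output
```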

View File

@ -125,15 +125,13 @@ class AcceleratedOptimizer(torch.optim.Optimizer):
"""
Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
"""
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()
return self.optimizer.train()
def eval(self):
"""
Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
"""
if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
self.optimizer.eval()
return self.optimizer.eval()
def step(self, closure=None):
if is_lomo_available():
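A small sketch of how the `train()`/`eval()` passthrough above is used from user code; the explicit `hasattr` guard is ours and keeps the sketch safe with a plain torch optimizer on either side of this diff:

```python
import torch

from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# Only optimizers such as schedule_free define train()/eval(); guard the calls
# so this also runs with AdamW regardless of which variant above is installed.
if hasattr(optimizer.optimizer, "train") and callable(optimizer.optimizer.train):
    optimizer.train()   # before the training loop
if hasattr(optimizer.optimizer, "eval") and callable(optimizer.optimizer.eval):
    optimizer.eval()    # before evaluation
```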

View File

@ -695,14 +695,12 @@ class PartialState:
return torch.device("mlu")
elif is_musa_available():
return torch.device("musa")
# NPU should be checked before CUDA when using `transfer_to_npu`
# See issue #3020: https://github.com/huggingface/accelerate/issues/3020
elif is_npu_available():
return torch.device("npu")
elif torch.cuda.is_available():
return torch.device("cuda")
elif is_xpu_available():
return torch.device("xpu:0")
elif is_npu_available():
return torch.device("npu")
else:
return torch.device("cpu")
@ -726,15 +724,13 @@ class PartialState:
elif is_musa_available():
backend = "mccl"
distributed_type = DistributedType.MULTI_MUSA
# NPU should be checked before CUDA when using `transfer_to_npu`
# See issue #3020: https://github.com/huggingface/accelerate/issues/3020
elif is_npu_available():
backend = "hccl"
distributed_type = DistributedType.MULTI_NPU
elif torch.cuda.is_available():
if backend is None:
backend = "nccl"
distributed_type = DistributedType.MULTI_GPU
elif is_npu_available():
backend = "hccl"
distributed_type = DistributedType.MULTI_NPU
if distributed_type is None and (
int(os.environ.get("LOCAL_RANK", -1)) != -1
@ -789,16 +785,6 @@ class PartialState:
self.device = torch.device(device, device_index)
device_module.set_device(self.device)
def destroy_process_group(self, group=None):
"""
Destroys the process group. If one is not specified, the default process group is destroyed.
"""
if self.fork_launched and group is None:
return
# needed when using torch.distributed.init_process_group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group(group)
def __getattr__(self, name: str):
# By this point we know that no attributes of `self` contain `name`,
# so we just modify the error message
@ -993,18 +979,6 @@ class AcceleratorState:
if reset_partial_state:
PartialState._reset_state()
def destroy_process_group(self, group=None):
"""
Destroys the process group. If one is not specified, the default process group is destroyed.
If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
"""
PartialState().destroy_process_group(group)
@property
def fork_launched(self):
return PartialState().fork_launched
@property
def use_distributed(self):
"""

View File

@ -15,7 +15,6 @@ from .testing import (
DEFAULT_LAUNCH_COMMAND,
are_the_same_tensors,
assert_exception,
capture_call_output,
device_count,
execute_subprocess_async,
get_launch_command,

View File

@ -223,7 +223,6 @@ def training_function(config, args):
if accelerator.is_main_process:
with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f:
json.dump(state, f)
accelerator.end_training()
def main():

View File

@ -294,7 +294,6 @@ def main():
if accelerator.is_local_main_process:
print("**Test that `drop_last` is taken into account**")
test_gather_for_metrics_drop_last()
accelerator.end_training()
accelerator.state._reset_state()

View File

@ -240,7 +240,6 @@ def training_function(config, args):
if accelerator.is_main_process:
with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f:
json.dump(train_total_peak_memory, f)
accelerator.end_training()
def main():

View File

@ -205,7 +205,6 @@ def training_function(config, args):
if accelerator.is_main_process:
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump(performance_metric, f)
accelerator.end_training()
def main():

View File

@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.models import resnet34
from transformers import (
BertConfig,
BertForMaskedLM,
GPT2Config,
GPT2ForSequenceClassification,
T5Config,
T5ForConditionalGeneration,
)
from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.utils import DistributedType, set_seed
from accelerate.utils import DistributedType, send_to_device, set_seed
model_to_config = {
"t5": (T5ForConditionalGeneration, T5Config, 1024),
"bert": (BertForMaskedLM, BertConfig, 512),
"gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}
@ -38,35 +42,23 @@ def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
# config_args["pad_token_id"] = 0
model_config = config(**config_args)
model = initializer(model_config)
kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False)
trace_input = torch.randint(size=(1, seq_len), **kwargs)
inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs)
return model, trace_input, inference_inputs
def test_bert(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = inference_inputs.to("cuda")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"
return model, torch.randint(
low=0,
high=model_config.vocab_size,
size=(num_processes, seq_len),
device=device,
dtype=torch.int64,
requires_grad=False,
)
def test_gpt2(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = inference_inputs.to("cuda")
inputs = inputs.to("cuda")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
@ -76,41 +68,62 @@ def test_gpt2(batch_size: int = 2):
assert output is not None, "Output was not generated in the last process!"
# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
# def test_resnet(batch_size: int = 2):
# set_seed(42)
# state = PartialState()
# model = resnet34()
# input_tensor = torch.rand(1, 3, 224, 224)
# model = prepare_pippy(
# model,
# example_args=(input_tensor,),
# )
# inference_inputs = torch.rand(batch_size, 3, 224, 224)
# inputs = send_to_device(inference_inputs, "cuda:0")
# with torch.no_grad():
# output = model(inputs)
# # Zach: Check that we just grab the real outputs we need at the end
# if not state.is_last_process:
# assert output is None, "Output was not generated on just the last process!"
# else:
# assert output is not None, "Output was not generated in the last process!"
def test_t5(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size)
example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
model = prepare_pippy(
model,
no_split_module_classes=model._no_split_modules,
example_kwargs=example_inputs,
)
# For inference args need to be a tuple
inputs = send_to_device(example_inputs, "cuda:0")
with torch.no_grad():
output = model(*inputs.values())
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"
def test_resnet(batch_size: int = 2):
set_seed(42)
state = PartialState()
model = resnet34()
input_tensor = torch.rand(batch_size, 3, 224, 224)
model = prepare_pippy(
model,
example_args=(input_tensor,),
)
inputs = send_to_device(input_tensor, "cuda:0")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"
if __name__ == "__main__":
state = PartialState()
state.print("Testing pippy integration...")
try:
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing BERT...")
test_bert()
else:
print("Less than two GPUs found, not running tests!")
finally:
state.destroy_process_group()
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing T5...")
test_t5()
test_t5(1)
test_t5(3)
state.print("Testing CV model...")
test_resnet()
test_resnet(3)
else:
print("Less than two GPUs found, not running tests!")

View File

@ -13,7 +13,7 @@
# limitations under the License.
import torch
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
class MockModel(torch.nn.Module):
@ -71,7 +71,6 @@ def main():
]:
print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
PartialState().destroy_process_group()
if __name__ == "__main__":
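A minimal sketch of wiring a DDP communication hook through `Accelerator`, along the lines of the test above; it assumes a multi-GPU launch (e.g. `accelerate launch --num_processes 2 script.py`), since the hook is only registered once the model is actually wrapped in DDP, and the `comm_hook` keyword is taken from the handler's fields rather than shown verbatim in this diff:

```python
import torch

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = torch.nn.Linear(4, 4)
model = accelerator.prepare(model)  # the FP16 compression hook is attached to the DDP-wrapped model
```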

View File

@ -14,8 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import tempfile
import warnings
from typing import List
from unittest.mock import Mock
@ -78,17 +77,12 @@ def create_accelerator(even_batches=True):
return accelerator
def create_dataloader(
accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
def create_dataloader(accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False):
"""
Create a simple DataLoader to use during the test cases
"""
values = torch.as_tensor(range(dataset_size))
if shuffle:
values = values[torch.randperm(values.size(0))]
if iterable:
dataset = DummyIterableDataset(values)
dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size)))
else:
dataset = TensorDataset(torch.as_tensor(range(dataset_size)))
@ -248,16 +242,6 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
assert "only supported for map-style datasets" in str(w[-1].message)
def test_pickle_accelerator():
accelerator = create_accelerator()
data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
_ = accelerator.prepare(data_loader)
pickled_accelerator = pickle.dumps(accelerator)
unpickled_accelerator = pickle.loads(pickled_accelerator)
# TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__
def test_data_loader(data_loader, accelerator):
# Prepare the DataLoader
data_loader = accelerator.prepare(data_loader)
@ -276,81 +260,6 @@ def test_data_loader(data_loader, accelerator):
), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
def test_stateful_dataloader(accelerator):
"""
Tests that a stateful dataloader can be iterated over, its state saved after a few batches using `state_dict`, and
then resumed from that saved state via `load_state_dict`.
The result should be the same as the rest of the data that iterated over after saving.
"""
old_dataloader_config = accelerator.dataloader_config
try:
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
prepared_dl = create_dataloader(
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
)
untrained_batches = []
# Calculate what step that will be
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
last_batch_num = total_batches - 1
for step, batch in enumerate(prepared_dl):
# Step just before
if step == last_batch_num - 1:
state_dict = prepared_dl.state_dict()
if step >= last_batch_num:
# Otherwise grab the "unseen" batches
untrained_batches.append(batch)
not_skipped_batches = accelerator.gather(untrained_batches)
prepared_dl.load_state_dict(state_dict)
resumed_batches = []
for batch in prepared_dl:
resumed_batches.append(batch)
resumed_batches = accelerator.gather(resumed_batches)
for b1, b2 in zip(not_skipped_batches, resumed_batches):
for v1, v2 in zip(b1, b2):
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
finally:
accelerator.dataloader_config = old_dataloader_config
def test_stateful_dataloader_save_state(accelerator):
"""
Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`,
and then resumed from the saved state.
The result should be the same as the rest of the data that iterated over after saving.
"""
old_dataloader_config = accelerator.dataloader_config
try:
with tempfile.TemporaryDirectory() as tmpdir:
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
prepared_dl = create_dataloader(
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
)
untrained_batches = []
# Calculate what step that will be
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
last_batch_num = total_batches - 1
for step, batch in enumerate(prepared_dl):
# Step just before
if step == last_batch_num - 1:
accelerator.save_state(tmpdir)
if step >= last_batch_num:
# Otherwise grab the "unseen" batches
untrained_batches.append(batch)
not_skipped_batches = accelerator.gather(untrained_batches)
accelerator.load_state(tmpdir)
resumed_batches = []
for batch in prepared_dl:
resumed_batches.append(batch)
resumed_batches = accelerator.gather(resumed_batches)
for b1, b2 in zip(not_skipped_batches, resumed_batches):
for v1, v2 in zip(b1, b2):
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
finally:
accelerator.dataloader_config = old_dataloader_config
def main():
accelerator = create_accelerator()
torch.manual_seed(accelerator.process_index)
@ -379,9 +288,6 @@ def main():
test_join_raises_warning_for_non_ddp_distributed(accelerator)
accelerator.state.distributed_type = original_state
accelerator.print("Test pickling an accelerator")
test_pickle_accelerator()
dataset = DummyDataset()
# Conventional Dataloader with shuffle=False
loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
@ -400,10 +306,6 @@ def main():
sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
test_data_loader(loader, accelerator)
test_stateful_dataloader(accelerator)
test_stateful_dataloader_save_state(accelerator)
accelerator.end_training()
if __name__ == "__main__":
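Condensed from `test_stateful_dataloader` above, a sketch of the snapshot-and-resume pattern it exercises; it assumes a build of Accelerate where `DataLoaderConfiguration` exposes `use_stateful_dataloader` (the feature this diff adds or removes) and `torchdata` >= 0.8.0:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
dataset = TensorDataset(torch.arange(32).float())
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=4))

snapshot = None
for step, batch in enumerate(dataloader):
    if step == 3:
        snapshot = dataloader.state_dict()   # capture the position mid-epoch

dataloader.load_state_dict(snapshot)         # rewind to the captured position
remaining = [batch for batch in dataloader]  # only the batches after the snapshot are yielded
```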

View File

@ -158,4 +158,3 @@ if __name__ == "__main__":
if accelerator.is_main_process:
shutil.rmtree(out_path)
accelerator.wait_for_everyone()
accelerator.end_training()

View File

@ -110,8 +110,6 @@ def main():
if is_bnb_available():
print("Test problematic imports (bnb)")
test_problematic_imports()
if NUM_PROCESSES > 1:
PartialState().destroy_process_group()
if __name__ == "__main__":

View File

@ -173,7 +173,6 @@ def main():
test_op_checker(state)
state.print("testing sending tensors across devices")
test_copy_tensor_to_devices(state)
state.destroy_process_group()
if __name__ == "__main__":

View File

@ -822,8 +822,6 @@ def main():
print("\n**Test reinstantiated state**")
test_reinstantiated_state()
state.destroy_process_group()
if __name__ == "__main__":
main()

View File

@ -20,7 +20,7 @@ from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin
from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
from accelerate.state import GradientState
from accelerate.test_utils import RegressionDataset, RegressionModel
from accelerate.utils import DistributedType, set_seed
@ -249,9 +249,9 @@ def test_gradient_accumulation_with_opt_and_scheduler(
split_batches=False, dispatch_batches=False, sync_each_batch=False
):
gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
accelerator = Accelerator(
dataloader_config=dataloader_config,
split_batches=split_batches,
dispatch_batches=dispatch_batches,
gradient_accumulation_plugin=gradient_accumulation_plugin,
)
# Test that context manager behaves properly
@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(
def test_dataloader_break():
accelerator = Accelerator()
first_dset = RegressionDataset(length=80)
first_dataloader = DataLoader(first_dset, batch_size=16)
second_dset = RegressionDataset(length=96)
second_dataloader = DataLoader(second_dset, batch_size=16)
first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)
assert accelerator.gradient_state.active_dataloader is None
for iteration, _ in enumerate(first_dataloader):
assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
@ -392,7 +392,6 @@ def main():
f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
)
test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
state.destroy_process_group()
def _mp_fn(index):
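A condensed sketch of the accumulation path exercised above, using the `DataLoaderConfiguration`-based constructor from one side of this diff; the `accumulate` context manager only synchronizes gradients every `num_steps` batches:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin

plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=False)
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(split_batches=False, dispatch_batches=False),
    gradient_accumulation_plugin=plugin,
)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(8, 1), torch.randn(8, 1))
model, optimizer, dataloader = accelerator.prepare(model, optimizer, DataLoader(dataset, batch_size=2))

for x, y in dataloader:
    with accelerator.accumulate(model):  # gradients only sync every `num_steps` batches
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```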

View File

@ -14,7 +14,6 @@
import asyncio
import inspect
import io
import os
import shutil
import subprocess
@ -53,7 +52,6 @@ from ..utils import (
is_timm_available,
is_torch_version,
is_torch_xla_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
@ -431,18 +429,6 @@ def require_trackers(test_case):
)(test_case)
def require_torchdata_stateful_dataloader(test_case):
"""
Decorator marking a test that requires torchdata.stateful_dataloader.
These tests are skipped when torchdata with stateful_dataloader module isn't installed.
"""
return unittest.skipUnless(
is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader"
)(test_case)
class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
@ -671,19 +657,3 @@ def assert_exception(exception_class: Exception, msg: str = None) -> bool:
assert msg in str(e), f"Expected message '{msg}' to be in exception but got '{str(e)}'"
if was_ran:
raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.")
def capture_call_output(func, *args, **kwargs):
"""
Takes in a `func` with `args` and `kwargs` and returns the captured stdout as a string
"""
captured_output = io.StringIO()
original_stdout = sys.stdout
try:
sys.stdout = captured_output
func(*args, **kwargs)
except Exception as e:
raise e
finally:
sys.stdout = original_stdout
return captured_output.getvalue()
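A tiny usage sketch of the `capture_call_output` helper defined above (it only exists on the side of this diff that ships it): it temporarily redirects `sys.stdout` while the callable runs and returns everything that was printed. The `greet` function is just a toy stand-in:

```python
from accelerate.test_utils.testing import capture_call_output


def greet(name):
    print(f"hello {name}")


output = capture_call_output(greet, "accelerate")
assert output == "hello accelerate\n"
```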

View File

@ -107,8 +107,6 @@ from .imports import (
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
is_torchdata_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
@ -177,7 +175,7 @@ from .operations import (
send_to_device,
slice_tensors,
)
from .versions import compare_versions, is_torch_version
from .versions import compare_versions, is_torch_version, parse
if is_deepspeed_available():

View File

@ -37,9 +37,7 @@ FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHA
FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]
FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"]
FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
FSDP_PYTORCH_VERSION = (
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
)
FSDP_PYTORCH_VERSION = "2.1.0"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]

View File

@ -313,9 +313,8 @@ class FP8RecipeKwargs(KwargsHandler):
The margin to use for the gradient scaling.
interval (`int`, *optional*, default to 1):
The interval to use for how often the scaling factor is recomputed.
fp8_format (`str`, *optional*, default to "HYBRID"):
The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training,
`E4M3` for evaluation)
fp8_format (`str`, *optional*, default to "E4M3"):
The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
amax_history_len (`int`, *optional*, default to 1024):
The length of the history to use for the scaling factor computation
amax_compute_algo (`str`, *optional*, default to "most_recent"):
@ -365,7 +364,7 @@ class FP8RecipeKwargs(KwargsHandler):
if self.interval is None:
self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
if self.fp8_format is None:
self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID")
self.fp8_format = os.environ.get(env_prefix + "FORMAT", "E4M3")
self.fp8_format = self.fp8_format.upper()
if self.fp8_format not in get_args(FP8Format):
raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
@ -749,7 +748,7 @@ class DataLoaderConfiguration:
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
" underlying dataset is an `IterableDataset`, `False` otherwise."
" underlying dataset is an `IterableDataslet`, `False` otherwise."
},
)
even_batches: bool = field(
@ -777,13 +776,6 @@ class DataLoaderConfiguration:
" prepared dataloader has `pin_memory` set to `True` to work properly."
},
)
use_stateful_dataloader: bool = field(
default=False,
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
},
)
@dataclass
@ -1345,11 +1337,10 @@ class FullyShardedDataParallelPlugin:
},
)
sync_module_states: bool = field(
default=None,
default=False,
metadata={
"help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 "
"to ensure they are the same across all ranks after initialization. Defaults to `False` unless "
"`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled."
"to ensure they are the same across all ranks after initialization. Defaults to `True`"
},
)
forward_prefetch: bool = field(
@ -1498,9 +1489,9 @@ class FullyShardedDataParallelPlugin:
# when using `sync_module_states`
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
def set_state_dict_type(self, state_dict_type=None):
def set_state_dict_type(self):
"""
Set the state dict config based on the `StateDictType`.
Set the state dict config based on the `StateDictType.
"""
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
@ -1510,11 +1501,6 @@ class FullyShardedDataParallelPlugin:
StateDictType,
)
# Override the state_dict_type if provided, typical use case:
# user trains with sharded, but final save is with full
if state_dict_type is not None:
self.state_dict_type = state_dict_type
if self.state_dict_type is None:
self.state_dict_type = os.environ.get("FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT")
if isinstance(self.state_dict_type, str):
@ -1543,7 +1529,9 @@ class FullyShardedDataParallelPlugin:
# First base off of `_no_split_modules`
no_split_modules = getattr(model, "_no_split_modules", None)
default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []
default_transformer_cls_names_to_wrap = (
",".join(model._no_split_modules) if no_split_modules is not None else ""
)
if self.auto_wrap_policy == transformer_auto_wrap_policy:
if self.transformer_cls_names_to_wrap is None:
self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap
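A short sketch of the `FullyShardedDataParallelPlugin` knobs this section touches; it assumes a distributed launch where FSDP applies (e.g. `accelerate launch --num_processes 2 ...`), and the late `set_state_dict_type(...)` override only exists on the side of the diff that adds an argument to that method:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_type="FULL_STATE_DICT", sync_module_states=True)
# On the variant that accepts an argument, the type can be changed later, e.g. to
# train with sharded checkpoints but save a full state dict at the end:
# fsdp_plugin.set_state_dict_type("SHARDED_STATE_DICT")
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```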

View File

@ -19,11 +19,9 @@ import warnings
from functools import lru_cache
import torch
from packaging import version
from packaging.version import parse
from .environment import parse_flag_from_env, str_to_bool
from .versions import compare_versions, is_torch_version
from .versions import compare_versions, is_torch_version, parse
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
@ -178,7 +176,11 @@ def is_deepspeed_available():
def is_pippy_available():
return is_torch_version(">=", "2.4.0")
package_exists = _is_package_available("pippy", "torchpippy")
if package_exists:
pippy_version = parse(importlib.metadata.version("torchpippy"))
return compare_versions(pippy_version, ">", "0.1.1")
return False
def is_bf16_available(ignore_tpu=False):
@ -195,7 +197,7 @@ def is_bf16_available(ignore_tpu=False):
def is_4bit_bnb_available():
package_exists = _is_package_available("bitsandbytes")
if package_exists:
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
return compare_versions(bnb_version, ">=", "0.39.0")
return False
@ -203,7 +205,7 @@ def is_4bit_bnb_available():
def is_8bit_bnb_available():
package_exists = _is_package_available("bitsandbytes")
if package_exists:
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
return compare_versions(bnb_version, ">=", "0.37.2")
return False
@ -251,7 +253,7 @@ def is_triton_available():
def is_aim_available():
package_exists = _is_package_available("aim")
if package_exists:
aim_version = version.parse(importlib.metadata.version("aim"))
aim_version = parse(importlib.metadata.version("aim"))
return compare_versions(aim_version, "<", "4.0.0")
return False
@ -320,7 +322,7 @@ def is_mps_available(min_version="1.12"):
def is_ipex_available():
def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
return str(parse(full_version).major) + "." + str(parse(full_version).minor)
_torch_version = importlib.metadata.version("torch")
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
@ -427,16 +429,3 @@ def is_xpu_available(check_device=False):
def is_dvclive_available():
return _is_package_available("dvclive")
def is_torchdata_available():
return _is_package_available("torchdata")
# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
def is_torchdata_stateful_dataloader_available():
package_exists = _is_package_available("torchdata")
if package_exists:
torchdata_version = version.parse(importlib.metadata.version("torchdata"))
return compare_versions(torchdata_version, ">=", "0.8.0")
return False
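The helpers in this file all follow the same pattern: check that a package is importable, then gate on its installed version. Below is a generic sketch of that pattern written against the standard library and `packaging` only; the helper name `is_min_version_available` is ours, not Accelerate's, and it assumes the import name matches the distribution name:

```python
import importlib.metadata
import importlib.util

from packaging.version import parse


def is_min_version_available(package: str, min_version: str) -> bool:
    """Return True if `package` is importable and at least `min_version`."""
    if importlib.util.find_spec(package) is None:
        return False
    return parse(importlib.metadata.version(package)) >= parse(min_version)


print(is_min_version_available("torch", "2.0.0"))
```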

View File

@ -20,7 +20,8 @@ from .imports import is_fp8_available
from .operations import GatheredParameters
# Do not import `transformer_engine` at package level to avoid potential issues
if is_fp8_available():
import transformer_engine.pytorch as te
def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
@ -29,8 +30,6 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
"""
if not is_fp8_available():
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
import transformer_engine.pytorch as te
for name, module in model.named_children():
if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
has_bias = module.bias is not None
@ -88,8 +87,6 @@ def has_transformer_engine_layers(model):
"""
if not is_fp8_available():
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
import transformer_engine.pytorch as te
for m in model.modules():
if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
return True
@ -101,8 +98,6 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
"""
if not is_fp8_available():
raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
from transformer_engine.pytorch import fp8_autocast
def forward(self, *args, **kwargs):
@ -120,8 +115,7 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
"""
Applies FP8 context manager to the model's forward method
"""
if not is_fp8_available():
raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
# Import here to keep base imports fast
import transformer_engine.common.recipe as te_recipe
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
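A guarded sketch (not part of this diff) of the conversion utilities in this file; it only does anything when `is_fp8_available()` is `True`, i.e. on an NVIDIA GPU with TransformerEngine installed:

```python
import torch

from accelerate.utils.imports import is_fp8_available
from accelerate.utils.transformer_engine import convert_model, has_transformer_engine_layers

if is_fp8_available():
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))
    convert_model(model)  # swaps nn.Linear / nn.LayerNorm for their transformer_engine counterparts in place
    assert has_transformer_engine_layers(model)
```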

View File

@ -15,11 +15,20 @@
import importlib.metadata
from typing import Union
from packaging.version import Version, parse
from packaging.version import Version
from packaging.version import parse as _parse
from .constants import STR_OPERATION_TO_FUNC
def parse(version: str):
"""
Same as `packaging.version.parse`, but grabs strictly the base version.
"""
version = _parse(version)
return _parse(version.base_version)
torch_version = parse(importlib.metadata.version("torch"))
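A quick illustration of why `parse` strips down to the base version: dev/local tags on builds such as nightly torch wheels would otherwise make simple `>=` comparisons fail:

```python
from packaging.version import parse as _parse

nightly = "2.1.0.dev20230901+cu118"
print(_parse(nightly) >= _parse("2.1.0"))                       # False: dev releases sort before the final release
print(_parse(_parse(nightly).base_version) >= _parse("2.1.0"))  # True once the tag is stripped, as `parse` above does
```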

View File

@ -49,7 +49,6 @@ from accelerate.utils.other import patch_environment
set_seed(42)
BERT_BASE_CASED = "bert-base-cased"
LLAMA_TESTING = "hf-internal-testing/tiny-random-LlamaForCausalLM"
FP16 = "fp16"
BF16 = "bf16"
dtypes = [FP16, BF16]
@ -136,49 +135,39 @@ class FSDPPluginIntegration(AccelerateTestCase):
assert fsdp_plugin.state_dict_config.offload_to_cpu
assert fsdp_plugin.state_dict_config.rank0_only
# We can also override the state_dict_type,
# typical case: user trains with sharded, but final save is with full
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_type="FULL_STATE_DICT")
fsdp_plugin.set_state_dict_type("SHARDED_STATE_DICT")
assert fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT
def test_auto_wrap_policy(self):
for model_name in [LLAMA_TESTING, BERT_BASE_CASED]:
model = AutoModel.from_pretrained(model_name)
layer_to_wrap = "LlamaDecoderLayer" if model_name == LLAMA_TESTING else "BertLayer"
for policy in FSDP_AUTO_WRAP_POLICY:
env = self.fsdp_env.copy()
env["FSDP_AUTO_WRAP_POLICY"] = policy
transformer_cls_to_wrap = None
min_num_params = None
env.pop("FSDP_TRANSFORMER_CLS_TO_WRAP", None)
env.pop("FSDP_MIN_NUM_PARAMS", None)
if policy == "TRANSFORMER_BASED_WRAP":
env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = layer_to_wrap
transformer_cls_to_wrap = layer_to_wrap
elif policy == "SIZE_BASED_WRAP":
env["FSDP_MIN_NUM_PARAMS"] = "2000"
min_num_params = 2000
# First test via env
with mockenv_context(**env):
fsdp_plugin = FullyShardedDataParallelPlugin()
fsdp_plugin.set_auto_wrap_policy(model)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
# Then manually set the policy
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=policy,
transformer_cls_names_to_wrap=transformer_cls_to_wrap,
min_num_params=min_num_params,
)
model = AutoModel.from_pretrained(BERT_BASE_CASED)
for policy in FSDP_AUTO_WRAP_POLICY:
env = self.fsdp_env.copy()
env["FSDP_AUTO_WRAP_POLICY"] = policy
transformer_cls_to_wrap = None
min_num_params = None
if policy == "TRANSFORMER_BASED_WRAP":
env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer"
transformer_cls_to_wrap = "BertLayer"
elif policy == "SIZE_BASED_WRAP":
env["FSDP_MIN_NUM_PARAMS"] = "2000"
min_num_params = 2000
# First test via env
with mockenv_context(**env):
fsdp_plugin = FullyShardedDataParallelPlugin()
fsdp_plugin.set_auto_wrap_policy(model)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
# Then manually set the policy
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=policy,
transformer_cls_names_to_wrap=transformer_cls_to_wrap,
min_num_params=min_num_params,
)
fsdp_plugin.set_auto_wrap_policy(model)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
env = self.fsdp_env.copy()
env["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import json
import os
import pickle
@ -27,7 +26,6 @@ from torch.utils.data import DataLoader, TensorDataset
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from accelerate.accelerator import Accelerator
from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import (
require_bnb,
@ -37,20 +35,9 @@ from accelerate.test_utils import (
slow,
torch_device,
)
from accelerate.test_utils.testing import (
AccelerateTestCase,
require_cuda,
require_non_torch_xla,
require_torchdata_stateful_dataloader,
)
from accelerate.utils import FP8RecipeKwargs, is_torchdata_stateful_dataloader_available, patch_environment
from accelerate.utils.dataclasses import DataLoaderConfiguration
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
from accelerate.utils import FP8RecipeKwargs, patch_environment
from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model
from accelerate.utils.random import set_seed
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader
class ModelWithTiedWeights(torch.nn.Module):
@ -71,6 +58,7 @@ def create_components(tied_weights=False):
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)
train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3])))
valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6])))
return model, optimizer, scheduler, train_dl, valid_dl
@ -85,21 +73,6 @@ class ModelForTest(torch.nn.Module):
return self.linear2(self.batchnorm(self.linear1(x)))
def create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0):
"Generates a tuple of dummy DataLoaders to test with"
def get_dataset(n_batches):
x = torch.randn(batch_size * n_batches, 3)
y = torch.randn(batch_size * n_batches, 5)
return TensorDataset(x, y)
train_dataset = get_dataset(n_train_batches)
valid_dataset = get_dataset(n_valid_batches)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers)
return (train_dataloader, valid_dataloader)
def get_signature(model):
return sum(param.abs().sum().item() for param in model.parameters())
@ -116,12 +89,7 @@ def parameterized_custom_name_func(func, param_num, param):
# customize the test name generator function as we want both params to appear in the sub-test
# name, as by default it shows only the first param
param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch"
if len(param.args) > 1:
param_based_name += "_tied_weights" if param.args[1] is True else ""
if len(param.args) > 2:
param_based_name += f"_num_workers_{param.args[2]}"
if len(param.args) > 3:
param_based_name += "_dispatch_batches" if param.args[3] is True else "_no_dispatch_batches"
param_based_name += "_tied_weights" if (len(param.args) == 2 and param.args[1] is True) else ""
return f"{func.__name__}_{param_based_name}"
@ -647,156 +615,3 @@ class AcceleratorTester(AccelerateTestCase):
# check that pickle roundtrip works
model_loaded = pickle.loads(pickle.dumps(model))
model_loaded(inputs)
@parameterized.expand([True, False])
def test_can_pickle_dataloader(self, dispatch_batches):
"""
Test that pickling a prepared dataloader works.
"""
data = torch.arange(10).to(torch_device)
ds = torch.utils.data.TensorDataset(data)
dl = torch.utils.data.DataLoader(ds)
skip_dl = skip_first_batches(dl, 2)
# Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality
# TODO: Add support for pickling StatefulDataLoader
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)
accelerator = Accelerator(dataloader_config=dataloader_config)
original_dl, _ = accelerator.prepare(dl, skip_dl)
if dispatch_batches:
assert isinstance(original_dl, DataLoaderDispatcher)
else:
assert isinstance(original_dl, DataLoaderShard)
prepared_model_dumps = pickle.dumps(accelerator)
model_loaded = pickle.loads(prepared_model_dumps)
assert len(model_loaded._dataloaders) == 2
# Assert equality of recovered and original dataloader
loaded_dl = model_loaded._dataloaders[0]
assert isinstance(loaded_dl, DataLoader)
if dispatch_batches:
assert isinstance(loaded_dl, DataLoaderDispatcher)
else:
assert isinstance(loaded_dl, DataLoaderShard)
assert len(loaded_dl) == len(original_dl)
assert [i for i in loaded_dl] == [i for i in original_dl]
# Test skip dataloader works as expected as well
loaded_skip_dl = model_loaded._dataloaders[1]
assert isinstance(loaded_skip_dl, DataLoader)
if dispatch_batches:
assert isinstance(loaded_dl, DataLoaderDispatcher)
else:
assert isinstance(loaded_dl, DataLoaderShard)
assert len(loaded_skip_dl) == len(original_dl) - 2
assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]
# Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.
@require_torchdata_stateful_dataloader
def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
"""Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object."""
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
model, optimizer, scheduler, train_dl, valid_dl = create_components()
(
prepared_model,
prepared_optimizer,
prepared_scheduler,
prepared_train_dl,
prepared_valid_dl,
) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
assert prepared_model in accelerator._models
assert prepared_optimizer in accelerator._optimizers
assert prepared_scheduler in accelerator._schedulers
assert prepared_train_dl in accelerator._dataloaders
assert prepared_valid_dl in accelerator._dataloaders
assert isinstance(prepared_train_dl, StatefulDataLoader)
assert isinstance(prepared_valid_dl, StatefulDataLoader)
@parameterized.expand(
itertools.product([True, False], [True, False], [0, 2], [True, False]),
name_func=parameterized_custom_name_func,
)
@require_torchdata_stateful_dataloader
def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights, num_workers, dispatch_batches):
"""
Test that saving and loading a model with a stateful dataloader returns the same model,
and that the dataloader's iterator is restored properly."""
set_seed(42)
n_train_batches = 64 # Use enough batches to ensure we can get partial iterations on large compute
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)
train_dl, valid_dl = create_dataloaders_for_test(n_train_batches=n_train_batches, num_workers=num_workers)
model = ModelForTest()
(
prepared_model,
prepared_optimizer,
prepared_scheduler,
prepared_train_dl,
prepared_valid_dl,
) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
assert isinstance(prepared_train_dl, StatefulDataLoader)
assert isinstance(prepared_valid_dl, StatefulDataLoader)
# Perform 3 training iterations to ensure the dataloader's iterator is advanced
num_batches_to_skip = 3
model.train()
untrained_batches = []
with tempfile.TemporaryDirectory() as tmpdirname:
for step, batch in enumerate(prepared_train_dl):
x, y = batch
outputs = prepared_model(x)
loss = torch.nn.functional.mse_loss(outputs, y)
accelerator.backward(loss)
prepared_optimizer.step()
prepared_scheduler.step()
prepared_optimizer.zero_grad()
if step == num_batches_to_skip - 1:
# Save the state once we've gone through a few batches
accelerator.save_state(f"{tmpdirname}/state", safe_serialization=use_safetensors)
if step >= num_batches_to_skip:
untrained_batches.append(batch)
not_skipped_batches = accelerator.gather(untrained_batches)
# We then unwrap the trained model
unwrapped_model = accelerator.unwrap_model(prepared_model)
original_linear1 = unwrapped_model.linear1.weight.clone()
original_batchnorm = unwrapped_model.batchnorm.weight.clone()
original_linear2 = unwrapped_model.linear2.weight.clone()
# Resume the state
accelerator.load_state(f"{tmpdirname}/state")
# Train this to the end of the DataLoader
batches_seen_with_loaded_dl = 0
for batch in prepared_train_dl:
x, y = batch
outputs = prepared_model(x)
loss = torch.nn.functional.mse_loss(outputs, y)
accelerator.backward(loss)
prepared_optimizer.step()
prepared_scheduler.step()
prepared_optimizer.zero_grad()
batches_seen_with_loaded_dl += 1
unwrapped_model_2 = accelerator.unwrap_model(prepared_model)
new_linear1 = unwrapped_model_2.linear1.weight
new_batchnorm = unwrapped_model_2.batchnorm.weight
new_linear2 = unwrapped_model_2.linear2.weight
# Assert equalities
assert batches_seen_with_loaded_dl == len(not_skipped_batches)
assert torch.allclose(original_linear1, new_linear1)
assert torch.allclose(original_batchnorm, new_batchnorm)
assert torch.allclose(original_linear2, new_linear2)
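Condensed from `test_save_model_with_stateful_dataloader` above, a sketch of checkpointing mid-epoch with `save_state`/`load_state`; it assumes a build with `use_stateful_dataloader` support and `torchdata` >= 0.8.0:

```python
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
model = torch.nn.Linear(3, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(24, 3), torch.randn(24, 5))
model, optimizer, dataloader = accelerator.prepare(model, optimizer, DataLoader(dataset, batch_size=4))

with tempfile.TemporaryDirectory() as tmpdir:
    for step, (x, y) in enumerate(dataloader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        if step == 2:
            accelerator.save_state(tmpdir)  # captures model, optimizer and dataloader position
    accelerator.load_state(tmpdir)          # on resume, the remaining batches of the epoch are replayed
```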

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from pathlib import Path
from unittest.mock import patch
@ -19,13 +20,13 @@ from unittest.mock import patch
import torch
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
import accelerate.commands.test as accelerate_test_cmd
from accelerate.commands.config.config_args import BaseConfig, ClusterConfig, SageMakerConfig, load_config_from_file
from accelerate.commands.estimate import estimate_command, estimate_command_parser, gather_data
from accelerate.commands.launch import _validate_launch_command, launch_command, launch_command_parser
from accelerate.commands.tpu import tpu_command_launcher, tpu_command_parser
from accelerate.commands.launch import _validate_launch_command, launch_command_parser
from accelerate.test_utils import execute_subprocess_async
from accelerate.test_utils.testing import (
capture_call_output,
DEFAULT_LAUNCH_COMMAND,
get_launch_command,
path_in_accelerate_package,
require_multi_device,
require_timm,
@ -52,7 +53,6 @@ class AccelerateLauncherTester(unittest.TestCase):
changed_path = config_folder / "_default_config.yaml"
test_config_path = Path("tests/test_configs")
parser = launch_command_parser()
@classmethod
def setUpClass(cls):
@ -65,11 +65,12 @@ class AccelerateLauncherTester(unittest.TestCase):
cls.changed_path.rename(cls.config_path)
def test_no_config(self):
args = ["--monitor_interval", "0.1", str(self.test_file_path)]
if torch.cuda.is_available() and (torch.cuda.device_count() > 1):
args = ["--multi_gpu"] + args
args = self.parser.parse_args(["--monitor_interval", "0.1", str(self.test_file_path)])
launch_command(args)
cmd = get_launch_command(multi_gpu=True)
else:
cmd = DEFAULT_LAUNCH_COMMAND
cmd.append(self.test_file_path)
execute_subprocess_async(cmd, env=os.environ.copy())
def test_config_compatibility(self):
invalid_configs = ["fp8", "invalid", "mpi", "sagemaker"]
@ -77,21 +78,20 @@ class AccelerateLauncherTester(unittest.TestCase):
if any(invalid_config in str(config) for invalid_config in invalid_configs):
continue
with self.subTest(config_file=config):
args = self.parser.parse_args(["--config_file", str(config), str(self.test_file_path)])
launch_command(args)
cmd = get_launch_command(config_file=config) + [self.test_file_path]
execute_subprocess_async(cmd)
def test_invalid_keys(self):
config_path = self.test_config_path / "invalid_keys.yaml"
with self.assertRaises(
ValueError,
RuntimeError,
msg="The config file at 'invalid_keys.yaml' had unknown keys ('another_invalid_key', 'invalid_key')",
):
args = self.parser.parse_args(["--config_file", str(config_path), str(self.test_file_path)])
launch_command(args)
cmd = get_launch_command(config_file=config_path) + [self.test_file_path]
execute_subprocess_async(cmd)
def test_accelerate_test(self):
args = accelerate_test_cmd.test_command_parser().parse_args([])
accelerate_test_cmd.test_command(args)
execute_subprocess_async(["accelerate", "test"])
@require_multi_device
def test_notebook_launcher(self):
@ -276,19 +276,18 @@ class TpuConfigTester(unittest.TestCase):
command_file = "tests/test_samples/test_command_file.sh"
gcloud = "Running gcloud compute tpus tpu-vm ssh"
def setUp(self):
self.parser = tpu_command_parser()
def test_base(self):
args = self.parser.parse_args(
["--command", self.command, "--tpu_zone", self.tpu_zone, "--tpu_name", self.tpu_name, "--debug"]
output = run_command(
self.cmd
+ ["--command", self.command, "--tpu_zone", self.tpu_zone, "--tpu_name", self.tpu_name, "--debug"],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output
def test_base_backward_compatibility(self):
args = self.parser.parse_args(
[
output = run_command(
self.cmd
+ [
"--config_file",
"tests/test_configs/0_12_0.yaml",
"--command",
@ -298,29 +297,31 @@ class TpuConfigTester(unittest.TestCase):
"--tpu_name",
self.tpu_name,
"--debug",
]
],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output
def test_with_config_file(self):
args = self.parser.parse_args(["--config_file", "tests/test_configs/latest.yaml", "--debug"])
output = capture_call_output(tpu_command_launcher, args)
output = run_command(
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--debug"], return_stdout=True
)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
in output
)
def test_with_config_file_and_command(self):
args = self.parser.parse_args(
["--config_file", "tests/test_configs/latest.yaml", "--command", self.command, "--debug"]
output = run_command(
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--command", self.command, "--debug"],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output
def test_with_config_file_and_multiple_command(self):
args = self.parser.parse_args(
[
output = run_command(
self.cmd
+ [
"--config_file",
"tests/test_configs/latest.yaml",
"--command",
@ -328,27 +329,29 @@ class TpuConfigTester(unittest.TestCase):
"--command",
'echo "Hello World"',
"--debug",
]
],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls; echo "Hello World" --worker all'
in output
)
def test_with_config_file_and_command_file(self):
args = self.parser.parse_args(
["--config_file", "tests/test_configs/latest.yaml", "--command_file", self.command_file, "--debug"]
output = run_command(
self.cmd
+ ["--config_file", "tests/test_configs/latest.yaml", "--command_file", self.command_file, "--debug"],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
in output
)
def test_with_config_file_and_command_file_backward_compatibility(self):
args = self.parser.parse_args(
[
output = run_command(
self.cmd
+ [
"--config_file",
"tests/test_configs/0_12_0.yaml",
"--command_file",
@ -358,36 +361,37 @@ class TpuConfigTester(unittest.TestCase):
"--tpu_name",
self.tpu_name,
"--debug",
]
],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
in output
)
def test_accelerate_install(self):
args = self.parser.parse_args(
["--config_file", "tests/test_configs/latest.yaml", "--install_accelerate", "--debug"]
output = run_command(
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--install_accelerate", "--debug"],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate -U; echo "hello world"; echo "this is a second command" --worker all'
in output
)
def test_accelerate_install_version(self):
args = self.parser.parse_args(
[
output = run_command(
self.cmd
+ [
"--config_file",
"tests/test_configs/latest.yaml",
"--install_accelerate",
"--accelerate_version",
"12.0.0",
"--debug",
]
],
return_stdout=True,
)
output = capture_call_output(tpu_command_launcher, args)
assert (
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate==12.0.0; echo "hello world"; echo "this is a second command" --worker all'
in output

View File

@ -15,39 +15,18 @@
import random
import unittest
import pytest
import torch
from parameterized import parameterized
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
from accelerate import Accelerator, PartialState
from accelerate import Accelerator
from accelerate.data_loader import (
BatchSamplerShard,
DataLoaderDispatcher,
DataLoaderShard,
DataLoaderStateMixin,
IterableDatasetShard,
SkipBatchSampler,
SkipDataLoader,
prepare_data_loader,
skip_first_batches,
)
from accelerate.state import GradientState
from accelerate.test_utils.testing import require_torchdata_stateful_dataloader
from accelerate.utils import is_torchdata_stateful_dataloader_available
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import (
StatefulDataLoader,
)
def parameterized_custom_name_func(func, param_num, param):
# customize the test name generator so the num_workers value appears in the sub-test name
# (by default only the first param is shown)
param_based_name = f"num_workers_{param.args[0]}"
return f"{func.__name__}_{param_based_name}"
class RandomIterableDataset(IterableDataset):
@ -65,21 +44,6 @@ class RandomIterableDataset(IterableDataset):
stop = random.random() < self.p_stop
class SimpleIterableDataset(IterableDataset):
def __init__(self, num_samples=1000):
self.num_samples = num_samples
def __iter__(self):
for _ in range(self.num_samples):
yield torch.rand(1)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
class DataLoaderTester(unittest.TestCase):
def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
batch_sampler_shards = [
@ -400,48 +364,11 @@ class DataLoaderTester(unittest.TestCase):
self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=True)
self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, split_batches=True)
def test_iterable_dataset_using_none_batch_size(self):
dataset = SimpleIterableDataset(100)
dataloader = DataLoader(dataset, batch_size=None)
dataloader = prepare_data_loader(dataloader)
for d in dataloader:
assert isinstance(d, torch.Tensor)
def test_skip_batch_sampler(self):
batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
new_batch_sampler = SkipBatchSampler(batch_sampler, 2)
assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]]
def test_dataloader_inheritance(self):
"""
Since `DataLoaderAdapter`'s parent classes are constructed dynamically, assert that subclasses of DataLoaderAdapter
are still instances of DataLoader and DataLoaderStateMixin (see the sketch after this test).
"""
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
dl_shard = DataLoaderShard(range(16), batch_size=4)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
# Test dataloaders are instances of instantiated classes
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
assert isinstance(skip_dl, SkipDataLoader)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(dl_dispatcher, DataLoaderDispatcher)
# Test dataloaders are instances of base classes
assert isinstance(skip_dl, DataLoader)
assert isinstance(dl_shard, DataLoader)
assert isinstance(dl_dispatcher, DataLoader)
assert isinstance(dl_shard, DataLoaderStateMixin)
assert isinstance(dl_dispatcher, DataLoaderStateMixin)
assert isinstance(skip_dl.base_dataloader, DataLoader)
assert isinstance(dl_shard.base_dataloader, DataLoader)
assert isinstance(dl_dispatcher.base_dataloader, DataLoader)
with pytest.raises(AttributeError):
_ = DataLoaderShard.base_dataloader
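The sketch below illustrates the general pattern the docstring above refers to: picking a base class at runtime and building the wrapper class on the fly so isinstance() checks against both the wrapper and the base pass. Names here are illustrative only; this is not accelerate's actual implementation.

```python
# Minimal sketch of dynamically constructed parent classes (illustrative only).
from torch.utils.data import DataLoader


class _StateMixin:
    # stand-in for DataLoaderStateMixin
    end_of_dataloader = False


def make_adapter(dataset, **kwargs):
    base_cls = DataLoader  # could be swapped for StatefulDataLoader when torchdata is available
    adapter_cls = type("DynamicAdapter", (base_cls, _StateMixin), {})
    return adapter_cls(dataset, **kwargs)


dl = make_adapter(range(16), batch_size=4)
assert isinstance(dl, DataLoader) and isinstance(dl, _StateMixin)
```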
def test_skip_data_loader(self):
dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2)
assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
@ -461,6 +388,7 @@ class DataLoaderTester(unittest.TestCase):
assert dataloader.end_of_dataloader == (idx == 3)
def test_end_of_dataloader_dispatcher(self):
Accelerator()
dataloader = DataLoaderDispatcher(range(16), batch_size=4)
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
@ -468,342 +396,3 @@ class DataLoaderTester(unittest.TestCase):
# Test it also works on the second iteration
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
class StatefulDataLoaderTester(unittest.TestCase):
@require_torchdata_stateful_dataloader
def test_skip_data_loader(self):
dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
assert isinstance(dataloader, StatefulDataLoader)
assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
@require_torchdata_stateful_dataloader
def test_end_of_dataloader(self):
dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True)
assert dataloader.use_stateful_dataloader
assert isinstance(dataloader, StatefulDataLoader)
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
# Test it also works on the second iteration
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
@require_torchdata_stateful_dataloader
def test_end_of_dataloader_dispatcher(self):
dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
assert isinstance(dataloader, StatefulDataLoader)
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
# Test it also works on the second iteration
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
@require_torchdata_stateful_dataloader
def test_dataloader_state_dict(self, num_workers):
"""
Test that saving a stateful dataloader's state, then loading it back, gives the same results.
"""
dataset = list(range(16))
dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
assert dataloader.use_stateful_dataloader
assert isinstance(dataloader, StatefulDataLoader)
vals = []
for idx, val in enumerate(dataloader):
vals.append(val)
if idx == 1:
sd = dataloader.state_dict()
assert len(vals) == 4
dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
dataloader2.load_state_dict(sd)
data1 = vals[2:]
data2 = list(dataloader2)
assert len(data1) == len(data2)
for d1, d2 in zip(data1, data2):
assert torch.allclose(d1, d2)
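As a usage-level companion to the round trip tested above, a minimal checkpoint-and-resume sketch; it assumes `torchdata` is installed and mirrors the classes and arguments used in this file.

```python
# Save the loader's progress mid-iteration, then hand the state to a fresh loader to continue.
from accelerate.data_loader import DataLoaderShard

dataset = list(range(16))
loader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True)

state = None
for step, batch in enumerate(loader):
    if step == 1:
        state = loader.state_dict()  # two batches have been yielded at this point
        if hasattr(loader, "end"):
            loader.end()  # as in the tests above, close the loader before breaking out early
        break

resumed = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True)
resumed.load_state_dict(state)
print([b.tolist() for b in resumed])  # expected: [[8, 9, 10, 11], [12, 13, 14, 15]]
```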
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
@require_torchdata_stateful_dataloader
def test_dataloader_dispatcher_state_dict(self, num_workers):
"""
Test that saving a stateful dataloader's state, then loading it back, gives the same results.
"""
dataset = list(range(16))
dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
assert dataloader.use_stateful_dataloader
assert isinstance(dataloader, StatefulDataLoader)
vals = []
for idx, val in enumerate(dataloader):
vals.append(val)
if idx == 1:
sd = dataloader.state_dict()
assert len(vals) == 4
dataloader2 = DataLoaderDispatcher(
dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers
)
dataloader2.load_state_dict(sd)
data1 = vals[2:]
data2 = list(dataloader2)
assert len(data1) == len(data2)
for d1, d2 in zip(data1, data2):
assert torch.allclose(d1, d2)
@require_torchdata_stateful_dataloader
def test_dataloader_inheritance(self):
"""
Since `DataLoaderAdapter`'s parent classes are constructed dynamically, assert that, when use_stateful_dataloader=True,
subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.
"""
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
# Test dataloaders are instances of instantiated classes
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
assert isinstance(skip_dl, SkipDataLoader)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(dl_dispatcher, DataLoaderDispatcher)
assert isinstance(skip_dl, StatefulDataLoader)
assert isinstance(dl_shard, StatefulDataLoader)
assert isinstance(dl_dispatcher, StatefulDataLoader)
assert isinstance(dl_shard, DataLoaderStateMixin)
assert isinstance(dl_dispatcher, DataLoaderStateMixin)
assert isinstance(skip_dl.base_dataloader, StatefulDataLoader)
assert isinstance(dl_shard.base_dataloader, StatefulDataLoader)
assert isinstance(dl_dispatcher.base_dataloader, StatefulDataLoader)
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
@require_torchdata_stateful_dataloader
def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
"""
Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader`.
"""
dataset = list(range(64))
# Set the seed for reproducibility
def g():
return torch.Generator().manual_seed(42)
accelerator = Accelerator()
stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]
num_batches_to_skip = 8
def get_first_n_batches(dl, n, device):
"""
Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
"""
batches = []
for idx, batch in enumerate(dl):
if idx == n - 1:
if hasattr(dl, "end"):
dl.end()
break
batches.append(batch.to(device))
return batches
# Iterate over all of the dataloaders identically, expect the same values
expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device)
batches_from_dataloaders = [
get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test
]
for dl_batches in batches_from_dataloaders:
for expected, actual in zip(expected_batches, dl_batches):
assert torch.allclose(expected, actual)
# The adapters should all produce the same state_dict as the reference stateful dataloader
expected_state_dict = stateful_dl.state_dict()
skip_dl_state_dict = skip_dl.state_dict()
dl_shard_state_dict = dl_shard.state_dict()
dl_dispatcher_state_dict = dl_dispatcher.state_dict()
assert expected_state_dict == skip_dl_state_dict
assert expected_state_dict == dl_shard_state_dict
assert expected_state_dict == dl_dispatcher_state_dict
# Load the state dict into new dataloaders
manual_skip_dl = SkipDataLoader(
dataset,
batch_size=4,
num_workers=num_workers,
generator=g(),
skip_batches=num_batches_to_skip,
use_stateful_dataloader=True,
)
loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
loaded_stateful_dl.load_state_dict(expected_state_dict)
loaded_skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_skip_dl.load_state_dict(expected_state_dict)
loaded_dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_shard.load_state_dict(expected_state_dict)
loaded_dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_dispatcher.load_state_dict(expected_state_dict)
# Continue the iteration, expecting identical behavior across the board
def get_all_batches(dl, device):
"""
Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
"""
batches = []
num_batches_yielded = 0
for batch in dl:
batches.append(batch.to(device))
num_batches_yielded += 1
return (batches, num_batches_yielded)
expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device)
dataloader_batch_results = [
get_all_batches(dl, accelerator.device)
for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
]
for dl_results in dataloader_batch_results:
for expected, actual in zip(expected_batch_results[0], dl_results[0]):
assert torch.allclose(expected, actual)
assert expected_batch_results[1] == dl_results[1]
assert accelerator.gradient_state.active_dataloader is None
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
@require_torchdata_stateful_dataloader
def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
"""
Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader` when *not* using
Accelerator (and instead using the decoupled `PartialState` workflow).
"""
dataset = list(range(64))
# Set the seed for reproducibility
def g():
return torch.Generator().manual_seed(42)
state = PartialState()
stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]
num_batches_to_skip = 8
def get_first_n_batches(dl, n, device):
"""
Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
"""
batches = []
for idx, batch in enumerate(dl):
if idx == n - 1:
if hasattr(dl, "end"):
dl.end()
break
batches.append(batch.to(device))
return batches
# Iterate over all of the dataloaders identically, expect the same values
expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, state.device)
batches_from_dataloaders = [
get_first_n_batches(dl, num_batches_to_skip, state.device) for dl in dataloaders_under_test
]
for dl_batches in batches_from_dataloaders:
for expected, actual in zip(expected_batches, dl_batches):
assert torch.allclose(expected, actual)
# The adapters should all produce the same state_dict as the reference stateful dataloader
expected_state_dict = stateful_dl.state_dict()
skip_dl_state_dict = skip_dl.state_dict()
dl_shard_state_dict = dl_shard.state_dict()
dl_dispatcher_state_dict = dl_dispatcher.state_dict()
assert expected_state_dict == skip_dl_state_dict
assert expected_state_dict == dl_shard_state_dict
assert expected_state_dict == dl_dispatcher_state_dict
# Load the state dict into new dataloaders
manual_skip_dl = SkipDataLoader(
dataset,
batch_size=4,
num_workers=num_workers,
generator=g(),
skip_batches=num_batches_to_skip,
use_stateful_dataloader=True,
)
loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
loaded_stateful_dl.load_state_dict(expected_state_dict)
loaded_skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_skip_dl.load_state_dict(expected_state_dict)
loaded_dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_shard.load_state_dict(expected_state_dict)
loaded_dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_dispatcher.load_state_dict(expected_state_dict)
# Continue the iteration, expecting identical behavior across the board
def get_all_batches(dl, device):
"""
Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
"""
batches = []
num_batches_yielded = 0
for batch in dl:
batches.append(batch.to(device))
num_batches_yielded += 1
return (batches, num_batches_yielded)
expected_batch_results = get_all_batches(loaded_stateful_dl, state.device)
dataloader_batch_results = [
get_all_batches(dl, state.device)
for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
]
for dl_results in dataloader_batch_results:
for expected, actual in zip(expected_batch_results[0], dl_results[0]):
assert torch.allclose(expected, actual)
assert expected_batch_results[1] == dl_results[1]
# Using the decoupled (`PartialState`) workflow, GradientState should be automatically initialized (with
# default parameters) by `DataLoaderDispatcher`
assert GradientState._shared_state != {}, "GradientState should already be initialized!"
gradient_state = GradientState()
assert gradient_state.active_dataloader is None
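A condensed, standalone version of the decoupled check above (assumes `torchdata` is installed):

```python
# With only PartialState (no Accelerator), constructing a DataLoaderDispatcher is expected to
# initialize GradientState with default parameters.
from accelerate import PartialState
from accelerate.data_loader import DataLoaderDispatcher
from accelerate.state import GradientState

PartialState()
_ = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
assert GradientState._shared_state != {}, "GradientState should already be initialized!"
assert GradientState().active_dataloader is None
```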

View File

@ -19,7 +19,7 @@ import shutil
import tempfile
import unittest
from pathlib import Path
from unittest import mock, skip
from unittest import mock
import torch
@ -45,13 +45,11 @@ from accelerate.utils import write_basic_config
EXCLUDE_EXAMPLES = [
"cross_validation.py",
"checkpointing.py",
"gradient_accumulation.py",
"local_sgd.py",
"multi_process_metrics.py",
"memory.py",
"schedule_free.py",
"tracking.py",
"automatic_gradient_accumulation.py",
"fsdp_with_peak_mem_tracking.py",
"deepspeed_with_config_support.py",
@ -261,9 +259,6 @@ class FeatureExamplesTests(TempDirTestCase):
testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
run_command(self.launch_args + testargs)
@skip(
reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
)
@require_multi_device
def test_distributed_inference_examples_stable_diffusion(self):
testargs = ["examples/inference/distributed/stable_diffusion.py"]

View File

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
from accelerate.test_utils import require_transformer_engine
from accelerate.test_utils.testing import TempDirTestCase, require_import_timer
from accelerate.utils import is_import_timer_available
@ -33,7 +31,7 @@ def convert_list_to_string(data):
def run_import_time(command: str):
output = subprocess.run([sys.executable, "-X", "importtime", "-c", command], capture_output=True, text=True)
output = subprocess.run(["python3", "-X", "importtime", "-c", command], capture_output=True, text=True)
return output.stderr
@ -83,18 +81,3 @@ class ImportSpeedTester(TempDirTestCase):
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
self.assertLess(pct_more, 20, err_msg)
@require_transformer_engine
class LazyImportTester(TempDirTestCase):
"""
Test suite that checks whether specific packages are lazy-loaded.
Eagerly importing them can trigger circular imports in some cases,
e.g. huggingface/accelerate#3056.
"""
def test_te_import(self):
output = run_import_time("import accelerate, accelerate.utils.transformer_engine")
self.assertFalse(" transformer_engine" in output, "`transformer_engine` should not be imported on import")

View File

@ -14,9 +14,6 @@
import unittest
import numpy as np
from packaging import version
from accelerate import debug_launcher
from accelerate.test_utils import (
DEFAULT_LAUNCH_COMMAND,
@ -32,7 +29,6 @@ from accelerate.utils import patch_environment
@require_huggingface_suite
@unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "Test requires numpy version < 2.0")
class MetricTester(unittest.TestCase):
def setUp(self):
self.test_file_path = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_metrics.py")

View File

@ -27,7 +27,6 @@ from unittest import mock
import numpy as np
import torch
from packaging import version
# We use TF to parse the logs
from accelerate import Accelerator
@ -69,7 +68,6 @@ logger = logging.getLogger(__name__)
@require_tensorboard
class TensorBoardTrackingTest(unittest.TestCase):
@unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "TB doesn't support numpy 2.0")
def test_init_trackers(self):
project_name = "test_project_with_config"
with tempfile.TemporaryDirectory() as dirpath:

View File

@ -49,6 +49,7 @@ from accelerate.utils import (
listify,
pad_across_processes,
pad_input_tensors,
parse,
patch_environment,
recursively_apply,
save,
@ -411,3 +412,8 @@ class UtilsTester(unittest.TestCase):
tqdm(True, range(3), disable=True)
assert "Passing `True` as the first argument to" in cm.pop().message.args[0]
tqdm(range(3), main_process_only=True, disable=True)
def test_dev0_parsing(self):
v1 = parse("0.34.0.dev0")
v2 = parse("0.34.0")
assert v1 == v2
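The `parse` exercised above is imported from `accelerate.utils` and its implementation is not part of this diff; purely as an illustration, one way to obtain the dev-suffix-tolerant comparison the test expects on top of `packaging` (hypothetical helper name):

```python
# Hypothetical sketch: compare versions by their base_version so ".dev0" pre-releases
# are treated as equal to the corresponding release.
from packaging import version


def parse_ignoring_dev(v: str) -> version.Version:
    return version.parse(version.parse(v).base_version)


assert parse_ignoring_dev("0.34.0.dev0") == parse_ignoring_dev("0.34.0")
```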