Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-18 08:35:38 +08:00)

Compare commits: 1 commit, `v0.34.2...make-versi` (9a04b8b58e)
@@ -82,23 +82,3 @@ jobs:
          push: true
          tags: huggingface/accelerate:gpu-deepspeed-release-${{needs.get-version.outputs.version}}

  version-cuda-fp8-transformerengine:
    name: "Latest Accelerate GPU FP8 TransformerEngine [version]"
    runs-on:
      group: aws-g6-4xlarge-plus
    needs: get-version
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Login to DockerHub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and Push GPU
        uses: docker/build-push-action@v4
        with:
          file: docker/accelerate-gpu/Dockerfile
          push: true
          tags: huggingface/accelerate:gpu-fp8-transformerengine-release-${{needs.get-version.outputs.version}}
.github/workflows/build_docker_images.yml (vendored): 22 changes
@@ -86,25 +86,3 @@ jobs:
            huggingface/accelerate:gpu-deepspeed-nightly
            huggingface/accelerate:gpu-deepspeed-nightly-${{ env.date }}

  latest-cuda-fp8-transformerengine:
    name: "Latest Accelerate GPU FP8 TransformerEngine [dev]"
    runs-on:
      group: aws-g6-4xlarge-plus
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Login to DockerHub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}
      - name: Get current date
        id: date
        run: |
          echo "date=$(date '+%Y-%m-%d')" >> $GITHUB_ENV
      - name: Build and Push GPU
        uses: docker/build-push-action@v4
        with:
          file: benchmarks/fp8/Dockerfile
          push: true
          tags: huggingface/accelerate:gpu-fp8-transformerengine-nightly-${{ env.date }}
@@ -123,15 +123,12 @@ Follow these steps to start contributing:
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:

```bash
$ pip install -e ".[dev]"
$ pip install -e ".[quality]"
```

This will install all testing and linting/code quality dependencies for the library (see `quality`, `test_dev`,
`test_prod` targets in [`setup.py`](./setup.py)).

(If accelerate was already installed in the virtual environment, remove
it with `pip uninstall accelerate` before reinstalling it in editable
mode with the `-e` flag).
mode with the `-e` flag.)

Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
@@ -157,8 +157,6 @@ accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py

To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).

Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/)

## Launching multi-CPU run using MPI

🤗 Here is another way to launch multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
@@ -258,7 +256,7 @@ pip install accelerate
- multi-GPU on several nodes (machines)
- TPU
- FP16/BFloat16 mixed precision
- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine)
- DeepSpeed support (Experimental)
- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
- Megatron-LM support (Experimental)
@@ -15,8 +15,6 @@ To run them, it's recommended to use a docker image (see the attached `Dockerfil

## Running:

There are official Docker images located at `huggingface/accelerate:gpu-fp8-transformerengine-nightly` which can be used.

You can run all scripts using the core `accelerate launch` command without any `accelerate config` being needed.

For single GPU, run it via `python`:
@@ -109,8 +109,7 @@ def evaluate_model(model, dataloader, metric, accelerator=None):
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        if accelerator is not None and accelerator.num_processes > 1:
            predictions, references = accelerator.gather_for_metrics((predictions, references))
            predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
        metric.add_batch(predictions=predictions, references=references)
    return metric.compute()
@@ -33,7 +33,6 @@ huggingface/accelerate:{accelerator}-{nightly,release}
* `cpu`: Comes compiled off of `python:3.9-slim` and is designed for non-CUDA based workloads.
* More to come soon
* `gpu-deepspeed`: Comes compiled off of the `nvidia/cuda` image and includes core parts like `bitsandbytes` as well as the latest `deepspeed` version. Runs off python 3.10.
* `gpu-fp8-transformerengine`: Comes compiled off of `nvcr.io/nvidia/pytorch` and is specifically for running the `benchmarks/fp8` scripts on devices which support FP8 operations using the `TransformerEngine` library (RTX 4090, H100, etc)

## Nightlies vs Releases
@@ -219,9 +219,3 @@ During training, you may want to save the current state of the model, optimizer,
To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.

Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function.

<Note>

If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with a `load_state_dict` and `state_dict` function, and makes it so `Accelerator.save_state` and `Accelerator.load_state` also track how far into the training dataset it has read when persisting the model.

</Note>
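For illustration, here is a minimal sketch of the workflow described above; the toy model, optimizer, and scheduler are placeholders and not part of the documented text.

```python
import torch
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, ProjectConfiguration

# Placeholder training objects, only here to make the sketch self-contained.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

accelerator = Accelerator(
    project_config=ProjectConfiguration(project_dir="runs", automatic_checkpoint_naming=True),
    # Optional: requires torchdata>=0.8.0 so saved states also track dataloader position.
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True),
)
model, optimizer = accelerator.prepare(model, optimizer)

# Any object with `state_dict`/`load_state_dict` can be tracked alongside the model and optimizer.
accelerator.register_for_checkpointing(scheduler)

# With automatic_checkpoint_naming enabled, this writes to runs/checkpoints/checkpoint_0,
# and `Accelerator.load_state` can later restore everything that was registered.
accelerator.save_state()
```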
@@ -69,10 +69,4 @@ setting the same seed in the main random number generator in all processes.

</Tip>

<Note>

If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`.

</Note>

For more details about the internals, see the [Internals page](package_reference/torch_wrappers).
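As a small illustration of the note above, a sketch of what the stateful wrapper exposes, assuming `torchdata>=0.8.0` is installed; the toy dataset is a placeholder.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator, DataLoaderConfiguration

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)
dataloader = accelerator.prepare(
    DataLoader(TensorDataset(torch.arange(8, dtype=torch.float32)), batch_size=2)
)

iterator = iter(dataloader)
_ = next(iterator)                   # consume one batch
snapshot = dataloader.state_dict()   # the wrapper tracks how far it has iterated
# A dataloader prepared the same way can pick up from this snapshot later:
dataloader.load_state_dict(snapshot)
```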
@@ -50,7 +50,7 @@ The `TransformerEngine` can receive many different arguments that customize how

* `margin`: The margin to use for the gradient scaling.
* `interval`: The interval to use for how often the scaling factor is recomputed.
* `fp8_format`: The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training, `E4M3` for evaluation)
* `fp8_format`: The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
* `amax_history_len`: The length of the history to use for the scaling factor computation
* `amax_compute_algo`: The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
* `override_linear_precision`: Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
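In Accelerate these arguments are typically passed through an FP8 recipe kwargs handler. A minimal sketch, assuming a GPU with FP8 support and `transformer-engine` installed (the handler class name may differ between Accelerate versions):

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Values mirror the arguments described above; adjust them for your own recipe.
fp8_kwargs = FP8RecipeKwargs(
    backend="TE",
    fp8_format="HYBRID",       # generally HYBRID for training, E4M3 for evaluation
    margin=0,
    interval=1,
    amax_history_len=1024,
    amax_compute_algo="max",
)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_kwargs])
```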
@@ -53,8 +53,6 @@ accelerate launch path_to_script.py --args_for_the_script

To learn more, check out the [Launch distributed code](basic_tutorials/launch) tutorial for more information about launching your scripts.

We also have a [configuration zoo](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates) which showcases a number of premade **minimal** example configurations for a variety of setups you can run.

## Adapt training code

The next main feature of Accelerate is the [`Accelerator`] class which adapts your PyTorch code to run on different distributed setups.
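For a quick sense of what that adaptation looks like, a minimal sketch of the usual pattern (prepare your objects, then use `accelerator.backward(loss)` in place of `loss.backward()`); the toy model and data are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Placeholder model and data, only here to keep the sketch self-contained.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
dataloader = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
```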
@@ -56,7 +56,7 @@ fp8_config:
  amax_compute_algorithm: max
  amax_history_length: 1024
  backend: TE
  fp8_format: HYBRID
  fp8_format: E4M3
  interval: 1
  margin: 0
  override_linear_precision: false

@@ -117,7 +117,7 @@ fp8_config:
  amax_compute_algorithm: max
  amax_history_length: 1024
  backend: TE
  fp8_format: HYBRID
  fp8_format: E4M3
  interval: 1
  margin: 0
  override_linear_precision: false
@@ -208,13 +208,23 @@ To run it in each of these various modes, use the following commands:

- [huggan project](https://github.com/huggingface/community-events/tree/main/huggan)

### Using AWS SageMaker integration
- [Examples showcasing AWS SageMaker integration of 🤗 Accelerate.](https://github.com/pacman100/accelerate-aws-sagemaker)

## Configuration zoo
In [/config_yaml_templates](./config_yaml_templates/) we have a variety of *minimal* `config.yaml` templates and examples to help you learn
how to create your own configuration files depending on the scenario.

## Simple Multi-GPU Hardware Launcher

[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate
on multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can
easily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then
run the script to automatically launch multi GPU training on remote hardware.

This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
with `pip install runhouse`, and you can refer to
[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
for hardware setup instructions, or this
[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.

## SLURM Scripts
In [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) and [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we present two scripts for running the examples on a machine with [SLURM](https://slurm.schedmd.com/documentation.html) workload manager.
@@ -241,20 +251,6 @@ export PYTHONPATH=/home/nct01/nct01328/transformers-in-supercomputers:$PYTHONPAT
export GPUS_PER_NODE=4
```

## Simple Multi-GPU Hardware Launcher (using an external platform)

[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate
on multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can
easily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then
run the script to automatically launch multi GPU training on remote hardware.

This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
with `pip install runhouse`, and you can refer to
[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
for hardware setup instructions, or this
[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.

## Finer Examples

While the first two scripts are extremely barebones when it comes to what you can do with accelerate, more advanced features are documented in two other locations.
@ -217,7 +217,6 @@ def training_function(config, args):
|
||||
# And call it at the end with no arguments
|
||||
# Note: You could also refactor this outside of your training loop function
|
||||
inner_training_loop()
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -19,10 +19,9 @@ import torch
|
||||
from datasets import load_dataset
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
||||
|
||||
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
|
||||
from accelerate.utils import set_seed
|
||||
from accelerate import Accelerator, DistributedType
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -126,8 +125,7 @@ def training_function(config, args):
|
||||
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
|
||||
config["num_epochs"] = 2
|
||||
# Initialize accelerator
|
||||
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
|
||||
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
|
||||
lr = config["lr"]
|
||||
num_epochs = int(config["num_epochs"])
|
||||
@ -219,11 +217,8 @@ def training_function(config, args):
|
||||
model.train()
|
||||
# New Code #
|
||||
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
|
||||
# We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader
|
||||
if not args.use_stateful_dataloader:
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
else:
|
||||
active_dataloader = train_dataloader
|
||||
# We need to skip steps until we reach the resumed step
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
overall_step += resume_step
|
||||
else:
|
||||
# After the first iteration though, we need to go back to the original dataloader
|
||||
@ -253,6 +248,7 @@ def training_function(config, args):
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
model.eval()
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
|
||||
@ -265,6 +261,7 @@ def training_function(config, args):
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
)
|
||||
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
@ -279,7 +276,6 @@ def training_function(config, args):
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
@ -312,11 +308,6 @@ def main():
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stateful_dataloader",
|
||||
action="store_true",
|
||||
help="If the dataloader should be a resumable stateful dataloader.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
|
||||
training_function(config, args)
|
||||
|
||||
@ -255,7 +255,6 @@ def training_function(config, args):
|
||||
preds = torch.stack(test_predictions, dim=0).sum(dim=0).div(int(args.num_folds)).argmax(dim=-1)
|
||||
test_metric = metric.compute(predictions=preds, references=test_references)
|
||||
accelerator.print("Average test metrics from all folds:", test_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -192,7 +192,6 @@ def training_function(config, args):
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -716,7 +716,6 @@ def main():
|
||||
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump({"perplexity": perplexity, "eval_loss": eval_loss.item()}, f)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -222,7 +222,6 @@ def training_function(config, args):
|
||||
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -399,7 +399,8 @@ def training_function(config, args):
|
||||
step=epoch,
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
if args.with_tracking:
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -197,7 +197,6 @@ def training_function(config, args):
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -202,7 +202,6 @@ def training_function(config, args):
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -703,7 +703,6 @@ def main():
|
||||
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump({"perplexity": perplexity}, f)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -210,7 +210,6 @@ def training_function(config, args):
|
||||
# And call it at the end with no arguments
|
||||
# Note: You could also refactor this outside of your training loop function
|
||||
inner_training_loop()
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -214,7 +214,6 @@ def training_function(config, args):
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -203,7 +203,6 @@ def training_function(config, args):
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -202,7 +202,6 @@ def training_function(config, args):
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -236,7 +236,11 @@ def training_function(config, args):
|
||||
step=epoch,
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
# New Code #
|
||||
# When a run is finished, you should call `accelerator.end_training()`
|
||||
# to close all of the open trackers
|
||||
if args.with_tracking:
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -23,7 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
|
||||
|
||||
from accelerate import Accelerator, DataLoaderConfiguration
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -72,19 +72,12 @@ class PetsDataset(Dataset):
|
||||
|
||||
def training_function(config, args):
|
||||
# Initialize accelerator
|
||||
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
|
||||
if args.with_tracking:
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="all",
|
||||
project_dir=args.project_dir,
|
||||
dataloader_config=dataloader_config,
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
|
||||
)
|
||||
else:
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
|
||||
)
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
|
||||
|
||||
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
|
||||
lr = config["lr"]
|
||||
@ -269,7 +262,8 @@ def training_function(config, args):
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
accelerator.end_training()
|
||||
if args.with_tracking:
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
@ -304,11 +298,6 @@ def main():
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stateful_dataloader",
|
||||
action="store_true",
|
||||
help="If the dataloader should be a resumable stateful dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
|
||||
@ -21,7 +21,7 @@ from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
||||
|
||||
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
|
||||
from accelerate import Accelerator, DistributedType
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -49,19 +49,12 @@ EVAL_BATCH_SIZE = 32
|
||||
|
||||
def training_function(config, args):
|
||||
# Initialize accelerator
|
||||
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
|
||||
if args.with_tracking:
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu,
|
||||
mixed_precision=args.mixed_precision,
|
||||
dataloader_config=dataloader_config,
|
||||
log_with="all",
|
||||
project_dir=args.project_dir,
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
|
||||
)
|
||||
else:
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
|
||||
)
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
|
||||
|
||||
if hasattr(args.checkpointing_steps, "isdigit"):
|
||||
if args.checkpointing_steps == "epoch":
|
||||
@ -201,10 +194,7 @@ def training_function(config, args):
|
||||
total_loss = 0
|
||||
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
|
||||
# We need to skip steps until we reach the resumed step
|
||||
if not args.use_stateful_dataloader:
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
else:
|
||||
active_dataloader = train_dataloader
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
overall_step += resume_step
|
||||
else:
|
||||
# After the first iteration though, we need to go back to the original dataloader
|
||||
@ -266,7 +256,8 @@ def training_function(config, args):
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
accelerator.end_training()
|
||||
if args.with_tracking:
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
@ -293,11 +284,6 @@ def main():
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stateful_dataloader",
|
||||
action="store_true",
|
||||
help="If the dataloader should be a resumable stateful dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
|
||||
@@ -1,10 +0,0 @@
# Config Zoo

This folder contains a variety of minimal configurations for `Accelerate` achieving certain goals. You can use these
config YAMLs directly, or build off of them for your own YAMLs.

These are highly annotated versions, aiming to teach you what each section does.

Each config can be run via `accelerate launch --config_file {file} run_me.py`

`run_me.py` will then print out how the current environment is set up (the contents of the `AcceleratorState`)
@ -1,15 +0,0 @@
|
||||
# Similar to FSDP, we set the distributed type as DEEPSPEED
|
||||
distributed_type: DEEPSPEED
|
||||
# With DeepSpeed, we utilize a deepspeed config file for the entire configuration
|
||||
deepspeed_config:
|
||||
# Can also be any of the config json's in accelerate/examples/deepspeed_config_templates
|
||||
deepspeed_config_file: ../deepspeed_config_templates/zero_stage1_config.json
|
||||
# If using ZeRO-3 and wanting to load big models in, this should be set to `true` so
|
||||
# `transformers` uses the right `init` function
|
||||
zero3_init_flag: false # true
|
||||
|
||||
# Finally we need to specify the number of GPUs to use
|
||||
num_processes: 2
|
||||
# Optionally we can set the mixed precision now instead of in the deepspeed config file,
|
||||
# however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file
|
||||
# mixed_precision: "bf16"
|
||||
@ -1,18 +0,0 @@
|
||||
# This config template simply setups up the TransformersEngine config (and a config for a single GPU),
|
||||
# this can interop with the other configs in this folder
|
||||
distributed_type: "NO"
|
||||
mixed_precision: "fp8"
|
||||
# Then we specify the fp8 configuration:
|
||||
fp8_config:
|
||||
backend: TE # Can be TE | MS-AMP
|
||||
# The following are TE specific arguments.
|
||||
# See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details
|
||||
amax_history_length: 1024
|
||||
fp8_format: E4M3
|
||||
interval: 1
|
||||
margin: 0
|
||||
override_linear_precision: false
|
||||
# Generally this should always be set to `false` to have the most realistic fp8 eval performance
|
||||
use_autocast_during_eval: false
|
||||
# If using MS-AMP, we ignore all of the prior and set a opt_level
|
||||
#opt_level: O1
|
||||
@ -1,18 +0,0 @@
|
||||
# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
|
||||
distributed_type: FSDP
|
||||
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)
|
||||
mixed_precision: 'bf16'
|
||||
# Specify the number of GPUs to use
|
||||
num_processes: 2
|
||||
# Then we can specify the FSDP config
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: false
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: true
|
||||
@ -1,6 +0,0 @@
|
||||
# Specify distributed_type as `MULTI_GPU` for DDP
|
||||
distributed_type: "MULTI_GPU"
|
||||
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
|
||||
mixed_precision: "bf16"
|
||||
# Specify the number of GPUs to use
|
||||
num_processes: 2
|
||||
@ -1,16 +0,0 @@
|
||||
# This config template is for a multi-node setup. This assumes DDP, but can be interop'd with the other configs in this folder
|
||||
# Generally it's recommended to look at the SLURM config template for a more robust multi-node setup
|
||||
distributed_type: MULTI_GPU
|
||||
# We need to specify the current machine's rank
|
||||
machine_rank: 0
|
||||
# We then need to specify the IP address and port of the main process
|
||||
main_process_ip: '1234'
|
||||
main_process_port: 9999
|
||||
# We need to specify the number of machines
|
||||
num_machines: 2
|
||||
# We need to specify the *total* number of processes
|
||||
num_processes: 8
|
||||
# And then we need to specify how rdvz comms will be handled
|
||||
rdzv_backend: static # or c10d
|
||||
# If the compute nodes are on the same network (cloud will more than likely be false)
|
||||
same_network: false
|
||||
@ -1,26 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
A base script which outputs the accelerate config for the given environment
|
||||
"""
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
accelerator.print(f"Accelerator state from the current environment:\n{accelerator.state}")
|
||||
if accelerator.fp8_recipe_handler is not None:
|
||||
accelerator.print(f"FP8 config:\n{accelerator.fp8_recipe_handler}")
|
||||
accelerator.end_training()
|
||||
@ -1,4 +0,0 @@
|
||||
# Since this is single GPU, we don't need distributed training
|
||||
distributed_type: "NO"
|
||||
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
|
||||
mixed_precision: "bf16"
|
||||
@ -180,7 +180,6 @@ def training_function(config, args):
|
||||
eval_metric = accurate.item() / num_elems
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -32,7 +32,7 @@ model.eval()
|
||||
input = torch.randint(
|
||||
low=0,
|
||||
high=model.config.vocab_size,
|
||||
size=(1, 512), # bs x seq_len
|
||||
size=(2, 512), # bs x seq_len
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
requires_grad=False,
|
||||
@ -49,16 +49,6 @@ model = prepare_pippy(model, split_points="auto", example_args=(input,))
|
||||
# available on all GPUs
|
||||
# model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
|
||||
|
||||
# Create new inputs of the expected size (n_processes)
|
||||
input = torch.randint(
|
||||
low=0,
|
||||
high=model.config.vocab_size,
|
||||
size=(2, 512), # bs x seq_len
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Move the inputs to the first device
|
||||
input = input.to("cuda:0")
|
||||
|
||||
@ -86,4 +76,3 @@ if PartialState().is_last_process:
|
||||
output = torch.stack(tuple(output[0]))
|
||||
print(f"Time of first pass: {first_batch}")
|
||||
print(f"Average time per batch: {(end_time - start_time) / 5}")
|
||||
PartialState().destroy_process_group()
|
||||
|
||||
@ -32,7 +32,7 @@ model.eval()
|
||||
input = torch.randint(
|
||||
low=0,
|
||||
high=model.config.vocab_size,
|
||||
size=(1, 1024), # bs x seq_len
|
||||
size=(2, 1024), # bs x seq_len
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
requires_grad=False,
|
||||
@ -48,16 +48,6 @@ model = prepare_pippy(model, split_points="auto", example_args=(input,))
|
||||
# available on all GPUs
|
||||
# model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
|
||||
|
||||
# Create new inputs of the expected size (n_processes)
|
||||
input = torch.randint(
|
||||
low=0,
|
||||
high=model.config.vocab_size,
|
||||
size=(2, 1024), # bs x seq_len
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Move the inputs to the first device
|
||||
input = input.to("cuda:0")
|
||||
|
||||
@ -85,4 +75,3 @@ if PartialState().is_last_process:
|
||||
output = torch.stack(tuple(output[0]))
|
||||
print(f"Time of first pass: {first_batch}")
|
||||
print(f"Average time per batch: {(end_time - start_time) / 5}")
|
||||
PartialState().destroy_process_group()
|
||||
|
||||
@ -27,7 +27,7 @@ model.eval()
|
||||
# Input configs
|
||||
# Create example inputs for the model
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
prompts = ("I would like to", "I really like to") # bs = 2, sending 2 per process
|
||||
prompts = ("I would like to", "I really like to", "The weather is pretty") # bs = 3
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
|
||||
|
||||
@ -43,8 +43,6 @@ model = prepare_pippy(model, split_points="auto", example_kwargs=inputs)
|
||||
|
||||
# currently we don't support `model.generate`
|
||||
# output = model.generate(**inputs, max_new_tokens=1)
|
||||
prompts = ("I would like to", "I really like to", "The weather is pretty") # bs = 3
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
|
||||
inputs = inputs.to(0)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
@ -54,4 +52,3 @@ if PartialState().is_last_process:
|
||||
next_token_logits = output[0][:, -1, :]
|
||||
next_token = torch.argmax(next_token_logits, dim=-1)
|
||||
print(tokenizer.batch_decode(next_token))
|
||||
PartialState().destroy_process_group()
|
||||
|
||||
@ -14,21 +14,12 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from accelerate import PartialState, prepare_pippy
|
||||
from accelerate import __version__ as accelerate_version
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
|
||||
if version.parse(accelerate_version) > version.parse("0.33.0"):
|
||||
raise RuntimeError(
|
||||
"Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. "
|
||||
"Please use a lower accelerate version and `torchpippy`, which this example uses."
|
||||
)
|
||||
|
||||
|
||||
# Set the random seed to have reproducible outputs
|
||||
set_seed(42)
|
||||
|
||||
@ -96,4 +87,3 @@ if PartialState().is_last_process:
|
||||
output = torch.stack(tuple(output[0]))
|
||||
print(f"Time of first pass: {first_batch}")
|
||||
print(f"Average time per batch: {(end_time - start_time) / 5}")
|
||||
PartialState().destroy_process_group()
|
||||
|
||||
@ -185,7 +185,6 @@ def training_function(config, args):
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -1,12 +0,0 @@
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: false
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: true
|
||||
@ -1,43 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
#SBATCH --job-name=multinode
|
||||
#SBATCH -D .
|
||||
#SBATCH --output=O-%x.%j
|
||||
#SBATCH --error=E-%x.%j
|
||||
#SBATCH --nodes=4 # number of nodes
|
||||
#SBATCH --ntasks-per-node=1 # number of MP tasks
|
||||
#SBATCH --gres=gpu:4 # number of GPUs per node
|
||||
#SBATCH --cpus-per-task=160 # number of cores per tasks
|
||||
#SBATCH --time=01:59:00 # maximum execution time (HH:MM:SS)
|
||||
|
||||
######################
|
||||
### Set environment ###
|
||||
######################
|
||||
source activateEnvironment.sh
|
||||
export GPUS_PER_NODE=4
|
||||
######################
|
||||
|
||||
######################
|
||||
#### Set network #####
|
||||
######################
|
||||
head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
||||
######################
|
||||
export ACCELERATE_DIR="${ACCELERATE_DIR:-/accelerate}"
|
||||
|
||||
export LAUNCHER="accelerate launch \
|
||||
--config ${ACCELERATE_DIR}/examples/slurm/fsdp_config.yaml \
|
||||
--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \
|
||||
--num_machines $SLURM_NNODES \
|
||||
--rdzv_backend c10d \
|
||||
--main_process_ip $head_node_ip \
|
||||
--main_process_port 29500 \
|
||||
"
|
||||
export SCRIPT="${ACCELERATE_DIR}/examples/complete_nlp_example.py"
|
||||
export SCRIPT_ARGS=" \
|
||||
--mixed_precision fp16 \
|
||||
--output_dir ${ACCELERATE_DIR}/examples/output \
|
||||
"
|
||||
|
||||
# This step is necessary because accelerate launch does not handle multiline arguments properly
|
||||
export CMD="$LAUNCHER $SCRIPT $SCRIPT_ARGS"
|
||||
srun $CMD
|
||||
setup.py: 5 changes
@ -27,7 +27,6 @@ extras["test_dev"] = [
|
||||
"datasets",
|
||||
"diffusers",
|
||||
"evaluate",
|
||||
"torchdata>=0.8.0",
|
||||
"torchpippy>=0.2.0",
|
||||
"transformers",
|
||||
"scipy",
|
||||
@ -49,7 +48,7 @@ extras["sagemaker"] = [
|
||||
|
||||
setup(
|
||||
name="accelerate",
|
||||
version="0.34.2",
|
||||
version="0.34.0.dev0",
|
||||
description="Accelerate",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
@ -71,7 +70,7 @@ setup(
|
||||
},
|
||||
python_requires=">=3.8.0",
|
||||
install_requires=[
|
||||
"numpy>=1.17,<3.0.0",
|
||||
"numpy>=1.17,<2.0.0",
|
||||
"packaging>=20.0",
|
||||
"psutil",
|
||||
"pyyaml",
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
__version__ = "0.34.2"
|
||||
__version__ = "0.34.0.dev0"
|
||||
|
||||
from .accelerator import Accelerator
|
||||
from .big_modeling import (
|
||||
|
||||
@ -166,8 +166,8 @@ class Accelerator:
|
||||
Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,
|
||||
etc...).
|
||||
mixed_precision (`str`, *optional*):
|
||||
Whether or not to use mixed precision training. Choose from 'no','fp16','bf16' or 'fp8'. Will default to
|
||||
the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the
|
||||
Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. Will default to the
|
||||
value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the
|
||||
accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp8'
|
||||
requires the installation of transformers-engine.
|
||||
gradient_accumulation_steps (`int`, *optional*, default to 1):
|
||||
@ -310,7 +310,7 @@ class Accelerator:
|
||||
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
|
||||
fsdp_plugin, FullyShardedDataParallelPlugin
|
||||
):
|
||||
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
|
||||
if is_torch_version("<", FSDP_PYTORCH_VERSION):
|
||||
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
|
||||
|
||||
if fsdp_plugin is None: # init from env variables
|
||||
@ -583,12 +583,6 @@ class Accelerator:
|
||||
def non_blocking(self):
|
||||
return self.dataloader_config.non_blocking
|
||||
|
||||
@property
|
||||
def use_stateful_dataloader(self):
|
||||
if hasattr(self.dataloader_config, "use_stateful_dataloader"):
|
||||
return self.dataloader_config.use_stateful_dataloader
|
||||
return False
|
||||
|
||||
@property
|
||||
def project_dir(self):
|
||||
return self.project_configuration.project_dir
|
||||
@ -2074,7 +2068,6 @@ class Accelerator:
|
||||
slice_fn_for_dispatch=slice_fn_for_dispatch,
|
||||
use_seedable_sampler=self.use_seedable_sampler,
|
||||
non_blocking=self.non_blocking,
|
||||
use_stateful_dataloader=self.use_stateful_dataloader,
|
||||
)
|
||||
self._dataloaders.append(prepared_data_loader)
|
||||
return prepared_data_loader
|
||||
@ -2734,7 +2727,9 @@ class Accelerator:
|
||||
for tracker in self.trackers:
|
||||
tracker.finish()
|
||||
|
||||
self.state.destroy_process_group()
|
||||
if torch.distributed.is_initialized():
|
||||
# needed when using torch.distributed.init_process_group
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
def save(self, obj, f, safe_serialization=False):
|
||||
"""
|
||||
|
||||
@ -127,11 +127,6 @@ def save_accelerator_state(
|
||||
sampler = dataloader.get_sampler()
|
||||
if isinstance(sampler, SeedableRandomSampler):
|
||||
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
|
||||
if getattr(dataloader, "use_stateful_dataloader", False):
|
||||
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
|
||||
output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
|
||||
state_dict = dataloader.state_dict()
|
||||
torch.save(state_dict, output_dataloader_state_dict_file)
|
||||
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
|
||||
|
||||
# GradScaler state
|
||||
@ -246,12 +241,6 @@ def load_accelerator_state(
|
||||
sampler = dataloader.get_sampler()
|
||||
if isinstance(sampler, SeedableRandomSampler):
|
||||
sampler = dataloader.set_sampler(torch.load(input_sampler_file))
|
||||
if getattr(dataloader, "use_stateful_dataloader", False):
|
||||
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
|
||||
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
|
||||
if input_dataloader_state_dict_file.exists():
|
||||
state_dict = torch.load(input_dataloader_state_dict_file)
|
||||
dataloader.load_state_dict(state_dict)
|
||||
logger.info("All dataloader sampler states loaded successfully")
|
||||
|
||||
# GradScaler state
|
||||
|
||||
@ -735,8 +735,8 @@ def get_cluster_input():
|
||||
)
|
||||
fp8_config["fp8_format"] = _ask_options(
|
||||
"Which weight format should be used?",
|
||||
["HYBRID", "E4M3"],
|
||||
lambda x: "HYBRID" if x == 0 else "E4M3",
|
||||
["E4M3", "HYBRID"],
|
||||
lambda x: "E4M3" if x == 0 else "HYBRID",
|
||||
default=0,
|
||||
)
|
||||
fp8_config["amax_history_length"] = _ask_field(
|
||||
|
||||
@ -99,17 +99,13 @@ class BaseConfig:
|
||||
result = {k: v for k, v in result.items() if v is not None}
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def process_config(config_dict):
|
||||
"""
|
||||
Processes `config_dict` and sets default values for any missing keys
|
||||
"""
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file=None):
|
||||
json_file = default_json_config_file if json_file is None else json_file
|
||||
with open(json_file, encoding="utf-8") as f:
|
||||
config_dict = json.load(f)
|
||||
if "compute_environment" not in config_dict:
|
||||
config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
|
||||
if "distributed_type" not in config_dict:
|
||||
raise ValueError("A `distributed_type` must be specified in the config file.")
|
||||
if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
|
||||
config_dict["num_processes"] = 1
|
||||
if "mixed_precision" not in config_dict:
|
||||
config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
|
||||
if "fp16" in config_dict: # Convert the config to the new format.
|
||||
@ -123,14 +119,6 @@ class BaseConfig:
|
||||
config_dict["debug"] = False
|
||||
if "enable_cpu_affinity" not in config_dict:
|
||||
config_dict["enable_cpu_affinity"] = False
|
||||
return config_dict
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file=None):
|
||||
json_file = default_json_config_file if json_file is None else json_file
|
||||
with open(json_file, encoding="utf-8") as f:
|
||||
config_dict = json.load(f)
|
||||
config_dict = cls.process_config(config_dict)
|
||||
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
|
||||
if len(extra_keys) > 0:
|
||||
raise ValueError(
|
||||
@ -150,7 +138,23 @@ class BaseConfig:
|
||||
yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
|
||||
with open(yaml_file, encoding="utf-8") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
config_dict = cls.process_config(config_dict)
|
||||
if "compute_environment" not in config_dict:
|
||||
config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
|
||||
if "mixed_precision" not in config_dict:
|
||||
config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
|
||||
if isinstance(config_dict["mixed_precision"], bool) and not config_dict["mixed_precision"]:
|
||||
config_dict["mixed_precision"] = "no"
|
||||
if "fp16" in config_dict: # Convert the config to the new format.
|
||||
del config_dict["fp16"]
|
||||
if "dynamo_backend" in config_dict: # Convert the config to the new format.
|
||||
dynamo_backend = config_dict.pop("dynamo_backend")
|
||||
config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
|
||||
if "use_cpu" not in config_dict:
|
||||
config_dict["use_cpu"] = False
|
||||
if "debug" not in config_dict:
|
||||
config_dict["debug"] = False
|
||||
if "enable_cpu_affinity" not in config_dict:
|
||||
config_dict["enable_cpu_affinity"] = False
|
||||
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
|
||||
if len(extra_keys) > 0:
|
||||
raise ValueError(
|
||||
@ -177,7 +181,7 @@ class BaseConfig:
|
||||
|
||||
@dataclass
|
||||
class ClusterConfig(BaseConfig):
|
||||
num_processes: int = -1 # For instance if we use SLURM and the user manually passes it in
|
||||
num_processes: int
|
||||
machine_rank: int = 0
|
||||
num_machines: int = 1
|
||||
gpu_ids: Optional[str] = None
|
||||
|
||||
@ -1074,8 +1074,6 @@ def _validate_launch_command(args):
|
||||
# Silently set the default here
|
||||
if args.dynamo_backend is None:
|
||||
args.dynamo_backend = "no"
|
||||
if args.num_processes == -1:
|
||||
raise ValueError("You need to manually pass in `--num_processes` using this config yaml.")
|
||||
else:
|
||||
if args.num_processes is None:
|
||||
if args.use_xpu and is_xpu_available():
|
||||
|
||||
@ -20,7 +20,7 @@ import torch
|
||||
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
|
||||
|
||||
from .logging import get_logger
|
||||
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
|
||||
from .state import AcceleratorState, DistributedType, GradientState, PartialState, is_torch_xla_available
|
||||
from .utils import (
|
||||
RNGType,
|
||||
broadcast,
|
||||
@ -30,7 +30,6 @@ from .utils import (
|
||||
get_data_structure,
|
||||
initialize_tensors,
|
||||
is_torch_version,
|
||||
is_torchdata_stateful_dataloader_available,
|
||||
send_to_device,
|
||||
slice_tensors,
|
||||
synchronize_rng_states,
|
||||
@ -365,13 +364,6 @@ class DataLoaderStateMixin:
|
||||
- **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
|
||||
batch size
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
|
||||
`self.gradient_state`.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
@ -396,94 +388,9 @@ class DataLoaderStateMixin:
|
||||
self.gradient_state._remove_dataloader(self)
|
||||
|
||||
|
||||
class DataLoaderAdapter:
|
||||
class DataLoaderShard(DataLoader, DataLoaderStateMixin):
|
||||
"""
|
||||
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
|
||||
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
|
||||
self.use_stateful_dataloader = use_stateful_dataloader
|
||||
if is_torchdata_stateful_dataloader_available():
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
if use_stateful_dataloader and not is_torchdata_stateful_dataloader_available():
|
||||
raise ImportError(
|
||||
"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
|
||||
)
|
||||
if use_stateful_dataloader:
|
||||
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
|
||||
else:
|
||||
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
|
||||
|
||||
if hasattr(self.base_dataloader, "state_dict"):
|
||||
self.dl_state_dict = self.base_dataloader.state_dict()
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
|
||||
if name == "base_dataloader":
|
||||
raise AttributeError()
|
||||
# Delegate attribute access to the internal dataloader
|
||||
return getattr(self.base_dataloader, name)
|
||||
|
||||
def state_dict(self):
|
||||
return self.dl_state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.base_dataloader.load_state_dict(state_dict)
|
||||
|
||||
@property
|
||||
def __class__(self):
|
||||
"""
|
||||
In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
|
||||
object.
|
||||
"""
|
||||
return self.base_dataloader.__class__
|
||||
|
||||
def __len__(self):
|
||||
return len(self.base_dataloader)
|
||||
|
||||
def adjust_state_dict_for_prefetch(self):
|
||||
"""
|
||||
Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
|
||||
`self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
|
||||
overridden.
|
||||
|
||||
This should modify `self.dl_state_dict` directly
|
||||
"""
|
||||
# The state dict will be off by a factor of `n-1` batch too many during DDP,
|
||||
# so we need to adjust it here
|
||||
if PartialState().distributed_type != DistributedType.NO:
|
||||
factor = PartialState().num_processes - 1
|
||||
if self.dl_state_dict["_sampler_iter_yielded"] > 0:
|
||||
self.dl_state_dict["_sampler_iter_yielded"] -= factor
|
||||
if self.dl_state_dict["_num_yielded"] > 0:
|
||||
self.dl_state_dict["_num_yielded"] -= factor
|
||||
if self.dl_state_dict["_index_sampler_state"] is not None:
|
||||
if (
|
||||
"samples_yielded" in self.dl_state_dict["_index_sampler_state"]
|
||||
and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
|
||||
):
|
||||
self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
|
||||
|
||||
def _update_state_dict(self):
|
||||
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
|
||||
# E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
|
||||
# what it wants to yield.
|
||||
#
|
||||
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
|
||||
if hasattr(self.base_dataloader, "state_dict"):
|
||||
self.dl_state_dict = self.base_dataloader.state_dict()
|
||||
# Potentially modify the state_dict to adjust for prefetching
|
||||
self.adjust_state_dict_for_prefetch()
|
||||
# Then tag if we are at the end of the dataloader
|
||||
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
|
||||
|
||||
|
||||
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
"""
|
||||
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
|
||||
Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
|
||||
|
||||
Args:
|
||||
dataset (`torch.utils.data.dataset.Dataset`):
|
||||
@ -502,8 +409,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
A random number generator to keep synchronized across processes.
|
||||
skip_batches (`int`, *optional*, defaults to 0):
|
||||
The number of batches to skip at the beginning.
|
||||
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
|
||||
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
|
||||
**kwargs (additional keyword arguments, *optional*):
|
||||
All other keyword arguments to pass to the regular `DataLoader` initialization.
|
||||
|
||||
@ -523,12 +428,11 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
rng_types=None,
|
||||
synchronized_generator=None,
|
||||
skip_batches=0,
|
||||
use_stateful_dataloader=False,
|
||||
_drop_last: bool = False,
|
||||
_non_blocking: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
|
||||
super().__init__(dataset, **kwargs)
|
||||
self.device = device
|
||||
self.rng_types = rng_types
|
||||
self.synchronized_generator = synchronized_generator
|
||||
@ -544,7 +448,7 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
self.begin()
|
||||
|
||||
self.set_epoch(self.iteration)
|
||||
dataloader_iter = self.base_dataloader.__iter__()
|
||||
dataloader_iter = super().__iter__()
|
||||
# We iterate one batch ahead to check when we are at the end
|
||||
try:
|
||||
current_batch = next(dataloader_iter)
|
||||
@ -557,7 +461,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
# But we still move it to the device so it is done before `StopIteration` is reached
|
||||
if self.device is not None:
|
||||
current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
|
||||
self._update_state_dict()
|
||||
next_batch = next(dataloader_iter)
|
||||
if batch_index >= self.skip_batches:
|
||||
yield current_batch
|
||||
@ -565,7 +468,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
current_batch = next_batch
|
||||
except StopIteration:
|
||||
self.end_of_dataloader = True
|
||||
self._update_state_dict()
|
||||
if batch_index >= self.skip_batches:
|
||||
yield current_batch
|
||||
break
|
||||
@ -573,15 +475,6 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
self.iteration += 1
|
||||
self.end()
|
||||
|
||||
    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderShard, *args[1:])

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
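The `__reduce__` override above exists because `DataLoaderAdapter` rewrites `__class__` at runtime, which breaks default pickling. A self-contained sketch of the same trick, with invented class names, is:

```python
# Minimal sketch of the pickling problem the `__reduce__` overrides above solve.
# The class names are invented for illustration; they are not the library's classes.
import pickle


class Base:
    def __init__(self, value):
        self.value = value
        # Mimic an adapter that swaps __class__ for a dynamically created subclass,
        # which is what breaks default pickling (pickle cannot look the class up).
        self.__class__ = type("PatchedBase", (Base,), {})

    def __reduce__(self):
        # Rebuild from the stable, importable class instead of the dynamic one.
        return (Base, (self.value,))


obj = Base(3)
restored = pickle.loads(pickle.dumps(obj))
assert restored.value == 3
```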
@ -654,10 +547,6 @@ if is_torch_xla_available():
|
||||
|
||||
return super().__iter__()
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
if hasattr(self.dataloader, "set_epoch"):
|
||||
self.dataloader.set_epoch(epoch)
|
||||
|
||||
@property
|
||||
def total_batch_size(self):
|
||||
return self._loader.total_batch_size
|
||||
@ -675,10 +564,10 @@ if is_torch_xla_available():
|
||||
return self._loader
|
||||
|
||||
|
||||
class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
|
||||
"""
|
||||
Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
|
||||
their part of the batch.
|
||||
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
|
||||
process their part of the batch.
|
||||
|
||||
Args:
|
||||
split_batches (`bool`, *optional*, defaults to `False`):
|
||||
@ -690,8 +579,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
size of the `dataloader` is a round multiple of `batch_size`.
|
||||
skip_batches (`int`, *optional*, defaults to 0):
|
||||
The number of batches to skip at the beginning of an iteration.
|
||||
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
|
||||
Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
|
||||
|
||||
**Available attributes:**
|
||||
|
||||
@ -707,7 +594,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
dataset,
|
||||
split_batches: bool = False,
|
||||
skip_batches=0,
|
||||
use_stateful_dataloader=False,
|
||||
_drop_last: bool = False,
|
||||
_non_blocking: bool = False,
|
||||
slice_fn=None,
|
||||
@ -720,13 +606,13 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
# We need to save the shuffling state of the DataPipe
|
||||
if isinstance(dataset, ShufflerIterDataPipe):
|
||||
shuffle = dataset._shuffle_enabled
|
||||
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
|
||||
super().__init__(dataset, **kwargs)
|
||||
self.split_batches = split_batches
|
||||
if shuffle:
|
||||
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
|
||||
|
||||
self.gradient_state = GradientState()
|
||||
self.state = PartialState()
|
||||
self.state = AcceleratorState()
|
||||
self._drop_last = _drop_last
|
||||
self._non_blocking = _non_blocking
|
||||
self.skip_batches = skip_batches
|
||||
@ -741,14 +627,12 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
        try:
            if self.split_batches:
                # One batch of the main iterator is dispatched and split.
                self._update_state_dict()
                batch = next(iterator)
            else:
                # num_processes batches of the main iterator are concatenated then dispatched and split.
                # We add the batches one by one so we have the remainder available when drop_last=False.
                batches = []
                for _ in range(self.state.num_processes):
                    self._update_state_dict()
                    batches.append(next(iterator))
                try:
                    batch = concatenate(batches, dim=0)
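To make the dispatch arithmetic above concrete, here is a hedged, torch-only sketch of how a concatenated super-batch would be split back into per-rank slices; the sizes are invented and no distributed calls are made.

```python
# Rough sketch: process 0 reads one batch per process, concatenates them, and each rank
# keeps its slice. Plain tensors stand in for real batches.
import torch

num_processes = 4
batch_size = 8  # per-process batch size, illustrative

# What process 0 would see after concatenating one batch per process
full_batch = torch.arange(num_processes * batch_size)


def slice_for_rank(batch, rank, num_processes):
    # Equal split across ranks; the real implementation also handles a short remainder batch.
    per_rank = batch.shape[0] // num_processes
    return batch[rank * per_rank : (rank + 1) * per_rank]


shards = [slice_for_rank(full_batch, rank, num_processes) for rank in range(num_processes)]
assert torch.cat(shards).equal(full_batch)
```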
@ -789,9 +673,9 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
# NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
|
||||
# shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
|
||||
# But, we only iterate through the DataLoader on process 0.
|
||||
main_iterator = self.base_dataloader.__iter__()
|
||||
main_iterator = super().__iter__()
|
||||
elif self.state.process_index == 0:
|
||||
main_iterator = self.base_dataloader.__iter__()
|
||||
main_iterator = super().__iter__()
|
||||
stop_iteration = False
|
||||
self._stop_iteration = False
|
||||
first_batch = None
|
||||
@ -849,7 +733,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
|
||||
if stop_iteration:
|
||||
self.end_of_dataloader = True
|
||||
self._update_state_dict()
|
||||
self.remainder = observed_batch_size
|
||||
if batch_index >= self.skip_batches:
|
||||
yield batch
|
||||
@ -861,13 +744,13 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
# In case it is manually passed in, the user can set it to what they like
|
||||
if self.iteration != epoch:
|
||||
self.iteration = epoch
|
||||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
|
||||
if hasattr(self.batch_sampler.sampler, "set_epoch"):
|
||||
self.batch_sampler.sampler.set_epoch(epoch)
|
||||
elif hasattr(self.dataset, "set_epoch"):
|
||||
self.dataset.set_epoch(epoch)
|
||||
|
||||
    def __len__(self):
        whole_length = len(self.base_dataloader)
        whole_length = super().__len__()
        if self.split_batches:
            return whole_length
        elif self._drop_last:
@ -875,15 +758,6 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
        else:
            return math.ceil(whole_length / self.state.num_processes)

    def __reduce__(self):
        """
        Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
        be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
        `__class__` member.
        """
        args = super().__reduce__()
        return (DataLoaderDispatcher, *args[1:])

    @property
    def total_batch_size(self):
        return (
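A quick sanity check of the `__len__` logic above (the `drop_last` branch is cut off by the hunk header, so floor division is assumed there):

```python
# Back-of-the-envelope check of how many batches each process sees for a dataloader of
# `whole_length` batches. The numbers are made up.
import math


def dispatched_length(whole_length, num_processes, split_batches, drop_last):
    if split_batches:
        return whole_length            # every rank sees every (smaller) batch
    if drop_last:
        return whole_length // num_processes   # assumed floor division, truncated in the hunk
    return math.ceil(whole_length / num_processes)


assert dispatched_length(10, 4, split_batches=True, drop_last=False) == 10
assert dispatched_length(10, 4, split_batches=False, drop_last=True) == 2
assert dispatched_length(10, 4, split_batches=False, drop_last=False) == 3
```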
@ -938,7 +812,6 @@ def prepare_data_loader(
|
||||
slice_fn_for_dispatch: Optional[Callable] = None,
|
||||
use_seedable_sampler: bool = False,
|
||||
non_blocking: bool = False,
|
||||
use_stateful_dataloader: bool = False,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
|
||||
@ -952,9 +825,10 @@ def prepare_data_loader(
|
||||
device (`torch.device`):
|
||||
The target device for the returned `DataLoader`.
|
||||
num_processes (`int`, *optional*):
|
||||
The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
|
||||
The number of processes running concurrently. Will default to the value given by
|
||||
[`~state.AcceleratorState`].
|
||||
process_index (`int`, *optional*):
|
||||
The index of the current process. Will default to the value given by [`~state.PartialState`].
|
||||
The index of the current process. Will default to the value given by [`~state.AcceleratorState`].
|
||||
split_batches (`bool`, *optional*, defaults to `False`):
|
||||
Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
|
||||
yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
|
||||
@ -999,10 +873,6 @@ def prepare_data_loader(
|
||||
non_blocking (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
|
||||
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
|
||||
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
|
||||
"If set to true, the dataloader prepared by the Accelerator will be backed by "
|
||||
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
|
||||
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
|
||||
|
||||
|
||||
Returns:
|
||||
@ -1023,8 +893,8 @@ def prepare_data_loader(
|
||||
|
||||
if dispatch_batches and not put_on_device:
|
||||
raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
|
||||
# Grab defaults from PartialState
|
||||
state = PartialState()
|
||||
# Grab defaults from AcceleratorState
|
||||
state = AcceleratorState()
|
||||
if num_processes is None:
|
||||
num_processes = state.num_processes
|
||||
if process_index is None:
|
||||
@ -1136,7 +1006,6 @@ def prepare_data_loader(
|
||||
_drop_last=dataloader.drop_last,
|
||||
_non_blocking=non_blocking,
|
||||
slice_fn=slice_fn_for_dispatch,
|
||||
use_stateful_dataloader=use_stateful_dataloader,
|
||||
**kwargs,
|
||||
)
|
||||
elif sampler_is_batch_sampler:
|
||||
@ -1149,7 +1018,6 @@ def prepare_data_loader(
|
||||
_drop_last=dataloader.drop_last,
|
||||
_non_blocking=non_blocking,
|
||||
synchronized_generator=synchronized_generator,
|
||||
use_stateful_dataloader=use_stateful_dataloader,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
@ -1161,7 +1029,6 @@ def prepare_data_loader(
|
||||
synchronized_generator=synchronized_generator,
|
||||
_drop_last=dataloader.drop_last,
|
||||
_non_blocking=non_blocking,
|
||||
use_stateful_dataloader=use_stateful_dataloader,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1175,7 +1042,6 @@ def prepare_data_loader(
|
||||
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    Should not be used if the original dataloader is a `StatefulDataLoader`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
@ -1195,10 +1061,9 @@ class SkipBatchSampler(BatchSampler):
|
||||
return len(self.batch_sampler) - self.skip_batches
|
||||
|
||||
|
||||
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
class SkipDataLoader(DataLoader):
|
||||
"""
|
||||
Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
|
||||
`skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
|
||||
Subclass of a PyTorch `DataLoader` that will skip the first batches.
|
||||
|
||||
Args:
|
||||
dataset (`torch.utils.data.dataset.Dataset`):
|
||||
@ -1209,36 +1074,19 @@ class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
All other keyword arguments to pass to the regular `DataLoader` initialization.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
|
||||
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
|
||||
def __init__(self, dataset, skip_batches=0, **kwargs):
|
||||
super().__init__(dataset, **kwargs)
|
||||
self.skip_batches = skip_batches
|
||||
self.gradient_state = GradientState()
|
||||
|
||||
def __iter__(self):
|
||||
self.begin()
|
||||
for index, batch in enumerate(self.base_dataloader.__iter__()):
|
||||
for index, batch in enumerate(super().__iter__()):
|
||||
if index >= self.skip_batches:
|
||||
self._update_state_dict()
|
||||
yield batch
|
||||
self.end()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.base_dataloader) - self.skip_batches
|
||||
|
||||
def __reduce__(self):
|
||||
"""
|
||||
Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
|
||||
explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
|
||||
`__class__` member.
|
||||
"""
|
||||
args = super().__reduce__()
|
||||
return (SkipDataLoader, *args[1:])
|
||||
|
||||
|
||||
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
    the original dataloader is a `StatefulDataLoader`.
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
    """
    state = PartialState()
    if state.distributed_type == DistributedType.XLA:

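A hedged usage sketch of `skip_first_batches` for resuming mid-epoch; the dataset, batch size, and resume step are invented:

```python
# Sketch: continue a run from batch `resume_step` of the current epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.data_loader import skip_first_batches

accelerator = Accelerator()
dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8))

resume_step = 3  # e.g. read back from a checkpoint
skipped = skip_first_batches(dataloader, num_batches=resume_step)
for batch in skipped:
    pass  # training would continue from batch `resume_step` onward
```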
@ -79,21 +79,22 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
|
||||
`AcceleratorState.num_processes`
|
||||
"""
|
||||
# Note: We import here to reduce import time from general modules, and isolate outside dependencies
|
||||
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
|
||||
from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
|
||||
from pippy.PipelineStage import PipelineStage
|
||||
|
||||
# We need to annotate the split points in the model for PiPPy
|
||||
state = PartialState()
|
||||
split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
|
||||
pipe = pipeline(
|
||||
model,
|
||||
mb_args=args,
|
||||
mb_kwargs=kwargs,
|
||||
split_spec=split_spec,
|
||||
)
|
||||
stage = pipe.build_stage(state.local_process_index, device=state.device)
|
||||
schedule = ScheduleGPipe(stage, num_chunks)
|
||||
annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
|
||||
found_batch_size = find_pippy_batch_size(args, kwargs)
|
||||
if found_batch_size != num_chunks:
|
||||
if args is not None:
|
||||
args = pad_input_tensors(args, found_batch_size, num_chunks)
|
||||
if kwargs is not None:
|
||||
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
|
||||
pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
|
||||
stage = PipelineStage(pipe, state.local_process_index, device=state.device)
|
||||
|
||||
return schedule
|
||||
return stage
|
||||
|
||||
|
||||
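The hunk above swaps the old `torchpippy` tracing path for `torch.distributed.pipelining`. A condensed sketch of the new-style build, mirroring the calls shown in the hunk, might look like the following; it assumes torch >= 2.4 and an already-initialized process group (e.g. launched via `torchrun`), and the argument names are placeholders.

```python
# Sketch only: build a GPipe schedule for one rank's stage of a split model.
import torch
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline


def build_gpipe_schedule(model, split_points, example_args, num_chunks, stage_index, device):
    # Split the model at the beginning of each named submodule
    split_spec = {name: SplitPoint.BEGINNING for name in split_points}
    pipe = pipeline(model, mb_args=example_args, split_spec=split_spec)
    # Materialize this rank's stage and wrap it in a GPipe schedule
    stage = pipe.build_stage(stage_index, device=device)
    return ScheduleGPipe(stage, num_chunks)
```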
def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
|
||||
@ -142,12 +143,11 @@ def prepare_pippy(
|
||||
no_split_module_classes (`List[str]`):
|
||||
A list of class names for layers we don't want to be split.
|
||||
example_args (tuple of model inputs):
|
||||
The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
|
||||
this method if possible.
|
||||
The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
|
||||
example_kwargs (dict of model inputs)
|
||||
The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
|
||||
*highly* limiting structure that requires the same keys be present at *all* inference calls. Not
|
||||
recommended unless the prior condition is true for all cases.
|
||||
The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
|
||||
that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
|
||||
is true for all cases.
|
||||
num_chunks (`int`, defaults to the number of available GPUs):
|
||||
The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
|
||||
this can be tuned and played with. In general one should have num_chunks >= num_gpus.
|
||||
@ -155,7 +155,10 @@ def prepare_pippy(
|
||||
If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
|
||||
"""
|
||||
if not is_pippy_available():
|
||||
raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
|
||||
raise ImportError(
|
||||
"`pippy` was not found to be installed on your system. Please "
|
||||
"install using `pip install torchpippy` or ensure you have at least version 0.2.0"
|
||||
)
|
||||
state = PartialState()
|
||||
example_args = send_to_device(example_args, "cpu")
|
||||
example_kwargs = send_to_device(example_kwargs, "cpu")
|
||||
@ -174,7 +177,7 @@ def prepare_pippy(
|
||||
model.hf_split_points = split_points
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
|
||||
return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)
|
||||
|
||||
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
|
||||
# Note: creates an infinite recursion loop with `generate`
|
||||
|
||||
@ -125,15 +125,13 @@ class AcceleratedOptimizer(torch.optim.Optimizer):
        """
        Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
        """
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()
            return self.optimizer.train()

    def eval(self):
        """
        Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
        """
        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
            self.optimizer.eval()
            return self.optimizer.eval()

    def step(self, closure=None):
        if is_lomo_available():
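The `train`/`eval` forwarding above targets schedule-free style optimizers. A small, self-contained sketch of the guard pattern, using a stand-in optimizer class rather than a real one:

```python
# Only forward train()/eval() when the wrapped optimizer actually provides them.
class FakeScheduleFree:
    def __init__(self):
        self.mode = "eval"

    def train(self):
        self.mode = "train"

    def eval(self):
        self.mode = "eval"


class Wrapper:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def train(self):
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            return self.optimizer.train()


wrapped = Wrapper(FakeScheduleFree())
wrapped.train()
assert wrapped.optimizer.mode == "train"
```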
@ -695,14 +695,12 @@ class PartialState:
|
||||
return torch.device("mlu")
|
||||
elif is_musa_available():
|
||||
return torch.device("musa")
|
||||
# NPU should be checked before CUDA when using `transfer_to_npu`
|
||||
# See issue #3020: https://github.com/huggingface/accelerate/issues/3020
|
||||
elif is_npu_available():
|
||||
return torch.device("npu")
|
||||
elif torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif is_xpu_available():
|
||||
return torch.device("xpu:0")
|
||||
elif is_npu_available():
|
||||
return torch.device("npu")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
@ -726,15 +724,13 @@ class PartialState:
|
||||
elif is_musa_available():
|
||||
backend = "mccl"
|
||||
distributed_type = DistributedType.MULTI_MUSA
|
||||
# NPU should be checked before CUDA when using `transfer_to_npu`
|
||||
# See issue #3020: https://github.com/huggingface/accelerate/issues/3020
|
||||
elif is_npu_available():
|
||||
backend = "hccl"
|
||||
distributed_type = DistributedType.MULTI_NPU
|
||||
elif torch.cuda.is_available():
|
||||
if backend is None:
|
||||
backend = "nccl"
|
||||
distributed_type = DistributedType.MULTI_GPU
|
||||
elif is_npu_available():
|
||||
backend = "hccl"
|
||||
distributed_type = DistributedType.MULTI_NPU
|
||||
|
||||
if distributed_type is None and (
|
||||
int(os.environ.get("LOCAL_RANK", -1)) != -1
|
||||
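Both hunks above reorder the probes so NPU is checked before CUDA, because `torch_npu`'s `transfer_to_npu` patch can make `torch.cuda.is_available()` report `True` on Ascend machines. A hedged sketch of that ordering, with the availability helpers stubbed out:

```python
# Sketch of the device-probe ordering; in the library the helpers come from accelerate.utils.
import torch


def pick_device(is_npu_available=lambda: False, is_xpu_available=lambda: False):
    # NPU before CUDA: see issue #3020 about transfer_to_npu making CUDA look available.
    if is_npu_available():
        return torch.device("npu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    if is_xpu_available():
        return torch.device("xpu:0")
    return torch.device("cpu")


print(pick_device())  # falls back to cuda or cpu on a standard machine
```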
@ -789,16 +785,6 @@ class PartialState:
|
||||
self.device = torch.device(device, device_index)
|
||||
device_module.set_device(self.device)
|
||||
|
||||
def destroy_process_group(self, group=None):
|
||||
"""
|
||||
Destroys the process group. If one is not specified, the default process group is destroyed.
|
||||
"""
|
||||
if self.fork_launched and group is None:
|
||||
return
|
||||
# needed when using torch.distributed.init_process_group
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group(group)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
# By this point we know that no attributes of `self` contain `name`,
|
||||
# so we just modify the error message
|
||||
@ -993,18 +979,6 @@ class AcceleratorState:
|
||||
if reset_partial_state:
|
||||
PartialState._reset_state()
|
||||
|
||||
def destroy_process_group(self, group=None):
|
||||
"""
|
||||
Destroys the process group. If one is not specified, the default process group is destroyed.
|
||||
|
||||
If `self.fork_lauched` is `True` and `group` is `None`, nothing happens.
|
||||
"""
|
||||
PartialState().destroy_process_group(group)
|
||||
|
||||
@property
|
||||
def fork_launched(self):
|
||||
return PartialState().fork_launched
|
||||
|
||||
@property
|
||||
def use_distributed(self):
|
||||
"""
|
||||
|
||||
@ -15,7 +15,6 @@ from .testing import (
|
||||
DEFAULT_LAUNCH_COMMAND,
|
||||
are_the_same_tensors,
|
||||
assert_exception,
|
||||
capture_call_output,
|
||||
device_count,
|
||||
execute_subprocess_async,
|
||||
get_launch_command,
|
||||
|
||||
@ -223,7 +223,6 @@ def training_function(config, args):
|
||||
if accelerator.is_main_process:
|
||||
with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f:
|
||||
json.dump(state, f)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -294,7 +294,6 @@ def main():
|
||||
if accelerator.is_local_main_process:
|
||||
print("**Test that `drop_last` is taken into account**")
|
||||
test_gather_for_metrics_drop_last()
|
||||
accelerator.end_training()
|
||||
accelerator.state._reset_state()
|
||||
|
||||
|
||||
|
||||
@ -240,7 +240,6 @@ def training_function(config, args):
|
||||
if accelerator.is_main_process:
|
||||
with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f:
|
||||
json.dump(train_total_peak_memory, f)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -205,7 +205,6 @@ def training_function(config, args):
|
||||
if accelerator.is_main_process:
|
||||
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||
json.dump(performance_metric, f)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -12,19 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torchvision.models import resnet34
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
GPT2Config,
|
||||
GPT2ForSequenceClassification,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
from accelerate import PartialState
|
||||
from accelerate.inference import prepare_pippy
|
||||
from accelerate.utils import DistributedType, set_seed
|
||||
from accelerate.utils import DistributedType, send_to_device, set_seed
|
||||
|
||||
|
||||
model_to_config = {
|
||||
"t5": (T5ForConditionalGeneration, T5Config, 1024),
|
||||
"bert": (BertForMaskedLM, BertConfig, 512),
|
||||
"gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
|
||||
}
|
||||
@ -38,35 +42,23 @@ def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
|
||||
# config_args["pad_token_id"] = 0
|
||||
model_config = config(**config_args)
|
||||
model = initializer(model_config)
|
||||
kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False)
|
||||
trace_input = torch.randint(size=(1, seq_len), **kwargs)
|
||||
inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs)
|
||||
return model, trace_input, inference_inputs
|
||||
|
||||
|
||||
def test_bert(batch_size: int = 2):
|
||||
set_seed(42)
|
||||
state = PartialState()
|
||||
model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
|
||||
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
|
||||
# For inference args need to be a tuple
|
||||
inputs = inference_inputs.to("cuda")
|
||||
with torch.no_grad():
|
||||
output = model(inputs)
|
||||
# Zach: Check that we just grab the real outputs we need at the end
|
||||
if not state.is_last_process:
|
||||
assert output is None, "Output was not generated on just the last process!"
|
||||
else:
|
||||
assert output is not None, "Output was not generated in the last process!"
|
||||
return model, torch.randint(
|
||||
low=0,
|
||||
high=model_config.vocab_size,
|
||||
size=(num_processes, seq_len),
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
|
||||
def test_gpt2(batch_size: int = 2):
|
||||
set_seed(42)
|
||||
state = PartialState()
|
||||
model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
|
||||
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
|
||||
model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
|
||||
model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
|
||||
# For inference args need to be a tuple
|
||||
inputs = inference_inputs.to("cuda")
|
||||
inputs = inputs.to("cuda")
|
||||
with torch.no_grad():
|
||||
output = model(inputs)
|
||||
# Zach: Check that we just grab the real outputs we need at the end
|
||||
@ -76,41 +68,62 @@ def test_gpt2(batch_size: int = 2):
|
||||
assert output is not None, "Output was not generated in the last process!"
|
||||
|
||||
|
||||
# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
|
||||
# def test_resnet(batch_size: int = 2):
|
||||
# set_seed(42)
|
||||
# state = PartialState()
|
||||
# model = resnet34()
|
||||
# input_tensor = torch.rand(1, 3, 224, 224)
|
||||
# model = prepare_pippy(
|
||||
# model,
|
||||
# example_args=(input_tensor,),
|
||||
# )
|
||||
# inference_inputs = torch.rand(batch_size, 3, 224, 224)
|
||||
# inputs = send_to_device(inference_inputs, "cuda:0")
|
||||
# with torch.no_grad():
|
||||
# output = model(inputs)
|
||||
# # Zach: Check that we just grab the real outputs we need at the end
|
||||
# if not state.is_last_process:
|
||||
# assert output is None, "Output was not generated on just the last process!"
|
||||
# else:
|
||||
# assert output is not None, "Output was not generated in the last process!"
|
||||
def test_t5(batch_size: int = 2):
|
||||
set_seed(42)
|
||||
state = PartialState()
|
||||
model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size)
|
||||
example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
|
||||
model = prepare_pippy(
|
||||
model,
|
||||
no_split_module_classes=model._no_split_modules,
|
||||
example_kwargs=example_inputs,
|
||||
)
|
||||
# For inference args need to be a tuple
|
||||
inputs = send_to_device(example_inputs, "cuda:0")
|
||||
with torch.no_grad():
|
||||
output = model(*inputs.values())
|
||||
# Zach: Check that we just grab the real outputs we need at the end
|
||||
if not state.is_last_process:
|
||||
assert output is None, "Output was not generated on just the last process!"
|
||||
else:
|
||||
assert output is not None, "Output was not generated in the last process!"
|
||||
|
||||
|
||||
def test_resnet(batch_size: int = 2):
|
||||
set_seed(42)
|
||||
state = PartialState()
|
||||
model = resnet34()
|
||||
input_tensor = torch.rand(batch_size, 3, 224, 224)
|
||||
model = prepare_pippy(
|
||||
model,
|
||||
example_args=(input_tensor,),
|
||||
)
|
||||
inputs = send_to_device(input_tensor, "cuda:0")
|
||||
with torch.no_grad():
|
||||
output = model(inputs)
|
||||
# Zach: Check that we just grab the real outputs we need at the end
|
||||
if not state.is_last_process:
|
||||
assert output is None, "Output was not generated on just the last process!"
|
||||
else:
|
||||
assert output is not None, "Output was not generated in the last process!"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
state = PartialState()
|
||||
state.print("Testing pippy integration...")
|
||||
try:
|
||||
if state.distributed_type == DistributedType.MULTI_GPU:
|
||||
state.print("Testing GPT2...")
|
||||
test_gpt2()
|
||||
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
|
||||
# due to references
|
||||
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
|
||||
# test_gpt2(3)
|
||||
state.print("Testing BERT...")
|
||||
test_bert()
|
||||
else:
|
||||
print("Less than two GPUs found, not running tests!")
|
||||
finally:
|
||||
state.destroy_process_group()
|
||||
if state.distributed_type == DistributedType.MULTI_GPU:
|
||||
state.print("Testing GPT2...")
|
||||
test_gpt2()
|
||||
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
|
||||
# due to references
|
||||
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
|
||||
# test_gpt2(3)
|
||||
state.print("Testing T5...")
|
||||
test_t5()
|
||||
test_t5(1)
|
||||
test_t5(3)
|
||||
state.print("Testing CV model...")
|
||||
test_resnet()
|
||||
test_resnet(3)
|
||||
else:
|
||||
print("Less than two GPUs found, not running tests!")
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState
|
||||
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
|
||||
|
||||
|
||||
class MockModel(torch.nn.Module):
|
||||
@ -71,7 +71,6 @@ def main():
|
||||
]:
|
||||
print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
|
||||
test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
|
||||
PartialState().destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -14,8 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pickle
|
||||
import tempfile
|
||||
|
||||
import warnings
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
@ -78,17 +77,12 @@ def create_accelerator(even_batches=True):
|
||||
return accelerator
|
||||
|
||||
|
||||
def create_dataloader(
|
||||
accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
|
||||
):
|
||||
def create_dataloader(accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False):
|
||||
"""
|
||||
Create a simple DataLoader to use during the test cases
|
||||
"""
|
||||
values = torch.as_tensor(range(dataset_size))
|
||||
if shuffle:
|
||||
values = values[torch.randperm(values.size(0))]
|
||||
if iterable:
|
||||
dataset = DummyIterableDataset(values)
|
||||
dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size)))
|
||||
else:
|
||||
dataset = TensorDataset(torch.as_tensor(range(dataset_size)))
|
||||
|
||||
@ -248,16 +242,6 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
|
||||
assert "only supported for map-style datasets" in str(w[-1].message)
|
||||
|
||||
|
||||
def test_pickle_accelerator():
|
||||
accelerator = create_accelerator()
|
||||
data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
|
||||
_ = accelerator.prepare(data_loader)
|
||||
pickled_accelerator = pickle.dumps(accelerator)
|
||||
unpickled_accelerator = pickle.loads(pickled_accelerator)
|
||||
# TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
|
||||
assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__
|
||||
|
||||
|
||||
def test_data_loader(data_loader, accelerator):
|
||||
# Prepare the DataLoader
|
||||
data_loader = accelerator.prepare(data_loader)
|
||||
@ -276,81 +260,6 @@ def test_data_loader(data_loader, accelerator):
|
||||
), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
|
||||
|
||||
|
||||
def test_stateful_dataloader(accelerator):
|
||||
"""
|
||||
Tests that a stateful dataloader can be iterated over, saved after a few batches using `load_state_dict`, and then
|
||||
resumed from the saved state.
|
||||
|
||||
The result should be the same as the rest of the data that iterated over after saving.
|
||||
"""
|
||||
old_dataloader_config = accelerator.dataloader_config
|
||||
try:
|
||||
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
|
||||
prepared_dl = create_dataloader(
|
||||
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
|
||||
)
|
||||
untrained_batches = []
|
||||
# Calculate what step that will be
|
||||
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
|
||||
last_batch_num = total_batches - 1
|
||||
for step, batch in enumerate(prepared_dl):
|
||||
# Step just before
|
||||
if step == last_batch_num - 1:
|
||||
state_dict = prepared_dl.state_dict()
|
||||
if step >= last_batch_num:
|
||||
# Otherwise grab the "unseen" batches
|
||||
untrained_batches.append(batch)
|
||||
not_skipped_batches = accelerator.gather(untrained_batches)
|
||||
prepared_dl.load_state_dict(state_dict)
|
||||
resumed_batches = []
|
||||
for batch in prepared_dl:
|
||||
resumed_batches.append(batch)
|
||||
resumed_batches = accelerator.gather(resumed_batches)
|
||||
for b1, b2 in zip(not_skipped_batches, resumed_batches):
|
||||
for v1, v2 in zip(b1, b2):
|
||||
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
|
||||
finally:
|
||||
accelerator.dataloader_config = old_dataloader_config
|
||||
|
||||
|
||||
def test_stateful_dataloader_save_state(accelerator):
|
||||
"""
|
||||
Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`,
|
||||
and then resumed from the saved state.
|
||||
|
||||
The result should be the same as the rest of the data that iterated over after saving.
|
||||
"""
|
||||
old_dataloader_config = accelerator.dataloader_config
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
|
||||
prepared_dl = create_dataloader(
|
||||
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
|
||||
)
|
||||
untrained_batches = []
|
||||
# Calculate what step that will be
|
||||
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
|
||||
last_batch_num = total_batches - 1
|
||||
for step, batch in enumerate(prepared_dl):
|
||||
# Step just before
|
||||
if step == last_batch_num - 1:
|
||||
accelerator.save_state(tmpdir)
|
||||
if step >= last_batch_num:
|
||||
# Otherwise grab the "unseen" batches
|
||||
untrained_batches.append(batch)
|
||||
not_skipped_batches = accelerator.gather(untrained_batches)
|
||||
accelerator.load_state(tmpdir)
|
||||
resumed_batches = []
|
||||
for batch in prepared_dl:
|
||||
resumed_batches.append(batch)
|
||||
resumed_batches = accelerator.gather(resumed_batches)
|
||||
for b1, b2 in zip(not_skipped_batches, resumed_batches):
|
||||
for v1, v2 in zip(b1, b2):
|
||||
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
|
||||
finally:
|
||||
accelerator.dataloader_config = old_dataloader_config
|
||||
|
||||
|
||||
def main():
|
||||
accelerator = create_accelerator()
|
||||
torch.manual_seed(accelerator.process_index)
|
||||
@ -379,9 +288,6 @@ def main():
|
||||
test_join_raises_warning_for_non_ddp_distributed(accelerator)
|
||||
accelerator.state.distributed_type = original_state
|
||||
|
||||
accelerator.print("Test pickling an accelerator")
|
||||
test_pickle_accelerator()
|
||||
|
||||
dataset = DummyDataset()
|
||||
# Conventional Dataloader with shuffle=False
|
||||
loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
|
||||
@ -400,10 +306,6 @@ def main():
|
||||
sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
|
||||
loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
|
||||
test_data_loader(loader, accelerator)
|
||||
test_stateful_dataloader(accelerator)
|
||||
test_stateful_dataloader_save_state(accelerator)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -158,4 +158,3 @@ if __name__ == "__main__":
|
||||
if accelerator.is_main_process:
|
||||
shutil.rmtree(out_path)
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.end_training()
|
||||
|
||||
@ -110,8 +110,6 @@ def main():
|
||||
if is_bnb_available():
|
||||
print("Test problematic imports (bnb)")
|
||||
test_problematic_imports()
|
||||
if NUM_PROCESSES > 1:
|
||||
PartialState().destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -173,7 +173,6 @@ def main():
|
||||
test_op_checker(state)
|
||||
state.print("testing sending tensors across devices")
|
||||
test_copy_tensor_to_devices(state)
|
||||
state.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -822,8 +822,6 @@ def main():
|
||||
print("\n**Test reinstantiated state**")
|
||||
test_reinstantiated_state()
|
||||
|
||||
state.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -20,7 +20,7 @@ from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin
|
||||
from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
|
||||
from accelerate.state import GradientState
|
||||
from accelerate.test_utils import RegressionDataset, RegressionModel
|
||||
from accelerate.utils import DistributedType, set_seed
|
||||
@ -249,9 +249,9 @@ def test_gradient_accumulation_with_opt_and_scheduler(
|
||||
split_batches=False, dispatch_batches=False, sync_each_batch=False
|
||||
):
|
||||
gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
|
||||
dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
|
||||
accelerator = Accelerator(
|
||||
dataloader_config=dataloader_config,
|
||||
split_batches=split_batches,
|
||||
dispatch_batches=dispatch_batches,
|
||||
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
||||
)
|
||||
# Test that context manager behaves properly
|
||||
@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(
|
||||
|
||||
def test_dataloader_break():
|
||||
accelerator = Accelerator()
|
||||
|
||||
first_dset = RegressionDataset(length=80)
|
||||
first_dataloader = DataLoader(first_dset, batch_size=16)
|
||||
second_dset = RegressionDataset(length=96)
|
||||
second_dataloader = DataLoader(second_dset, batch_size=16)
|
||||
first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)
|
||||
|
||||
assert accelerator.gradient_state.active_dataloader is None
|
||||
for iteration, _ in enumerate(first_dataloader):
|
||||
assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
|
||||
@ -392,7 +392,6 @@ def main():
|
||||
f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
|
||||
)
|
||||
test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
|
||||
state.destroy_process_group()
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
@ -53,7 +52,6 @@ from ..utils import (
|
||||
is_timm_available,
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
is_torchdata_stateful_dataloader_available,
|
||||
is_torchvision_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformers_available,
|
||||
@ -431,18 +429,6 @@ def require_trackers(test_case):
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torchdata_stateful_dataloader(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torchdata.stateful_dataloader.
|
||||
|
||||
These tests are skipped when torchdata with stateful_dataloader module isn't installed.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(
|
||||
is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader"
|
||||
)(test_case)
|
||||
|
||||
|
||||
class TempDirTestCase(unittest.TestCase):
|
||||
"""
|
||||
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
|
||||
@ -671,19 +657,3 @@ def assert_exception(exception_class: Exception, msg: str = None) -> bool:
|
||||
assert msg in str(e), f"Expected message '{msg}' to be in exception but got '{str(e)}'"
|
||||
if was_ran:
|
||||
raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.")
|
||||
|
||||
|
||||
def capture_call_output(func, *args, **kwargs):
    """
    Takes in a `func` with `args` and `kwargs` and returns the captured stdout as a string
    """
    captured_output = io.StringIO()
    original_stdout = sys.stdout
    try:
        sys.stdout = captured_output
        func(*args, **kwargs)
    except Exception as e:
        raise e
    finally:
        sys.stdout = original_stdout
    return captured_output.getvalue()

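A short usage sketch for `capture_call_output`; the greeting function is invented:

```python
from accelerate.test_utils import capture_call_output


def greet(name):
    print(f"hello {name}")


captured = capture_call_output(greet, "accelerate")
assert captured == "hello accelerate\n"
```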
@ -107,8 +107,6 @@ from .imports import (
|
||||
is_tensorboard_available,
|
||||
is_timm_available,
|
||||
is_torch_xla_available,
|
||||
is_torchdata_available,
|
||||
is_torchdata_stateful_dataloader_available,
|
||||
is_torchvision_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformers_available,
|
||||
@ -177,7 +175,7 @@ from .operations import (
|
||||
send_to_device,
|
||||
slice_tensors,
|
||||
)
|
||||
from .versions import compare_versions, is_torch_version
|
||||
from .versions import compare_versions, is_torch_version, parse
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
|
||||
@ -37,9 +37,7 @@ FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHA
|
||||
FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]
|
||||
FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"]
|
||||
FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
|
||||
FSDP_PYTORCH_VERSION = (
|
||||
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
|
||||
)
|
||||
FSDP_PYTORCH_VERSION = "2.1.0"
|
||||
FSDP_MODEL_NAME = "pytorch_model_fsdp"
|
||||
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
|
||||
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
|
||||
|
||||
@ -313,9 +313,8 @@ class FP8RecipeKwargs(KwargsHandler):
|
||||
The margin to use for the gradient scaling.
|
||||
interval (`int`, *optional*, default to 1):
|
||||
The interval to use for how often the scaling factor is recomputed.
|
||||
fp8_format (`str`, *optional*, default to "HYBRID"):
|
||||
The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training,
|
||||
`E4M3` for evaluation)
|
||||
fp8_format (`str`, *optional*, default to "E4M3"):
|
||||
The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
|
||||
amax_history_len (`int`, *optional*, default to 1024):
|
||||
The length of the history to use for the scaling factor computation
|
||||
amax_compute_algo (`str`, *optional*, default to "most_recent"):
|
||||
@ -365,7 +364,7 @@ class FP8RecipeKwargs(KwargsHandler):
|
||||
if self.interval is None:
|
||||
self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
|
||||
if self.fp8_format is None:
|
||||
self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID")
|
||||
self.fp8_format = os.environ.get(env_prefix + "FORMAT", "E4M3")
|
||||
self.fp8_format = self.fp8_format.upper()
|
||||
if self.fp8_format not in get_args(FP8Format):
|
||||
raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
|
||||
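For completeness, a hedged sketch of pinning the recipe explicitly instead of relying on the default format discussed above; actually running in FP8 additionally needs TransformerEngine (or MS-AMP) on a capable GPU, and the field values here are illustrative:

```python
# Sketch: pass an explicit FP8 recipe so the default format does not matter.
from accelerate.utils import FP8RecipeKwargs

fp8_kwargs = FP8RecipeKwargs(fp8_format="HYBRID", amax_history_len=1024, amax_compute_algo="most_recent")
# On an FP8-capable machine this would be handed to the Accelerator, e.g.
# Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_kwargs])
```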
@ -749,7 +748,7 @@ class DataLoaderConfiguration:
|
||||
metadata={
|
||||
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
|
||||
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
|
||||
" underlying dataset is an `IterableDataset`, `False` otherwise."
|
||||
" underlying dataset is an `IterableDataslet`, `False` otherwise."
|
||||
},
|
||||
)
|
||||
even_batches: bool = field(
|
||||
@ -777,13 +776,6 @@ class DataLoaderConfiguration:
|
||||
" prepared dataloader has `pin_memory` set to `True` to work properly."
|
||||
},
|
||||
)
|
||||
    use_stateful_dataloader: bool = field(
        default=False,
        metadata={
            "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by "
            "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
        },
    )


@dataclass
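A hedged sketch of opting into the field above; it needs `torchdata>=0.8.0`, and the dataloader is a throwaway example:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils.dataclasses import DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
loader = accelerator.prepare(DataLoader(TensorDataset(torch.arange(16)), batch_size=4))
state = loader.state_dict()  # available because the prepared loader is now stateful
```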
@ -1345,11 +1337,10 @@ class FullyShardedDataParallelPlugin:
|
||||
},
|
||||
)
|
||||
sync_module_states: bool = field(
|
||||
default=None,
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 "
|
||||
"to ensure they are the same across all ranks after initialization. Defaults to `False` unless "
|
||||
"`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled."
|
||||
"to ensure they are the same across all ranks after initialization. Defaults to `True`"
|
||||
},
|
||||
)
|
||||
forward_prefetch: bool = field(
|
||||
@ -1498,9 +1489,9 @@ class FullyShardedDataParallelPlugin:
|
||||
# when using `sync_module_states`
|
||||
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
|
||||
|
||||
def set_state_dict_type(self, state_dict_type=None):
|
||||
def set_state_dict_type(self):
|
||||
"""
|
||||
Set the state dict config based on the `StateDictType`.
|
||||
Set the state dict config based on the `StateDictType.
|
||||
"""
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
FullOptimStateDictConfig,
|
||||
@ -1510,11 +1501,6 @@ class FullyShardedDataParallelPlugin:
|
||||
StateDictType,
|
||||
)
|
||||
|
||||
# Override the state_dict_type if provided, typical use case:
|
||||
# user trains with sharded, but final save is with full
|
||||
if state_dict_type is not None:
|
||||
self.state_dict_type = state_dict_type
|
||||
|
||||
if self.state_dict_type is None:
|
||||
self.state_dict_type = os.environ.get("FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT")
|
||||
if isinstance(self.state_dict_type, str):
|
||||
@ -1543,7 +1529,9 @@ class FullyShardedDataParallelPlugin:
|
||||
|
||||
# First base off of `_no_split_modules`
|
||||
no_split_modules = getattr(model, "_no_split_modules", None)
|
||||
default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []
|
||||
default_transformer_cls_names_to_wrap = (
|
||||
",".join(model._no_split_modules) if no_split_modules is not None else ""
|
||||
)
|
||||
if self.auto_wrap_policy == transformer_auto_wrap_policy:
|
||||
if self.transformer_cls_names_to_wrap is None:
|
||||
self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap
|
||||
|
||||
@ -19,11 +19,9 @@ import warnings
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from packaging.version import parse
|
||||
|
||||
from .environment import parse_flag_from_env, str_to_bool
|
||||
from .versions import compare_versions, is_torch_version
|
||||
from .versions import compare_versions, is_torch_version, parse
|
||||
|
||||
|
||||
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
|
||||
@ -178,7 +176,11 @@ def is_deepspeed_available():
|
||||
|
||||
|
||||
def is_pippy_available():
|
||||
return is_torch_version(">=", "2.4.0")
|
||||
package_exists = _is_package_available("pippy", "torchpippy")
|
||||
if package_exists:
|
||||
pippy_version = parse(importlib.metadata.version("torchpippy"))
|
||||
return compare_versions(pippy_version, ">", "0.1.1")
|
||||
return False
|
||||
|
||||
|
||||
def is_bf16_available(ignore_tpu=False):
|
||||
@ -195,7 +197,7 @@ def is_bf16_available(ignore_tpu=False):
|
||||
def is_4bit_bnb_available():
|
||||
package_exists = _is_package_available("bitsandbytes")
|
||||
if package_exists:
|
||||
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
|
||||
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
|
||||
return compare_versions(bnb_version, ">=", "0.39.0")
|
||||
return False
|
||||
|
||||
@ -203,7 +205,7 @@ def is_4bit_bnb_available():
|
||||
def is_8bit_bnb_available():
|
||||
package_exists = _is_package_available("bitsandbytes")
|
||||
if package_exists:
|
||||
bnb_version = version.parse(importlib.metadata.version("bitsandbytes"))
|
||||
bnb_version = parse(importlib.metadata.version("bitsandbytes"))
|
||||
return compare_versions(bnb_version, ">=", "0.37.2")
|
||||
return False
|
||||
|
||||
@ -251,7 +253,7 @@ def is_triton_available():
|
||||
def is_aim_available():
|
||||
package_exists = _is_package_available("aim")
|
||||
if package_exists:
|
||||
aim_version = version.parse(importlib.metadata.version("aim"))
|
||||
aim_version = parse(importlib.metadata.version("aim"))
|
||||
return compare_versions(aim_version, "<", "4.0.0")
|
||||
return False
|
||||
|
||||
@ -320,7 +322,7 @@ def is_mps_available(min_version="1.12"):
|
||||
|
||||
def is_ipex_available():
|
||||
def get_major_and_minor_from_version(full_version):
|
||||
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
||||
return str(parse(full_version).major) + "." + str(parse(full_version).minor)
|
||||
|
||||
_torch_version = importlib.metadata.version("torch")
|
||||
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
||||
@ -427,16 +429,3 @@ def is_xpu_available(check_device=False):
|
||||
|
||||
def is_dvclive_available():
|
||||
return _is_package_available("dvclive")
|
||||
|
||||
|
||||
def is_torchdata_available():
|
||||
return _is_package_available("torchdata")
|
||||
|
||||
|
||||
# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
|
||||
def is_torchdata_stateful_dataloader_available():
|
||||
package_exists = _is_package_available("torchdata")
|
||||
if package_exists:
|
||||
torchdata_version = version.parse(importlib.metadata.version("torchdata"))
|
||||
return compare_versions(torchdata_version, ">=", "0.8.0")
|
||||
return False
|
||||
|
||||
@ -20,7 +20,8 @@ from .imports import is_fp8_available
|
||||
from .operations import GatheredParameters
|
||||
|
||||
|
||||
# Do not import `transformer_engine` at package level to avoid potential issues
|
||||
if is_fp8_available():
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
|
||||
def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
|
||||
@ -29,8 +30,6 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
|
||||
has_bias = module.bias is not None
|
||||
@ -88,8 +87,6 @@ def has_transformer_engine_layers(model):
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)):
|
||||
return True
|
||||
@ -101,8 +98,6 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
|
||||
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
|
||||
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
|
||||
from transformer_engine.pytorch import fp8_autocast
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -120,8 +115,7 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
|
||||
"""
|
||||
Applies FP8 context manager to the model's forward method
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
|
||||
# Import here to keep base imports fast
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
|
||||
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
|
||||
|
||||
@ -15,11 +15,20 @@
|
||||
import importlib.metadata
from typing import Union

from packaging.version import Version, parse
from packaging.version import Version
from packaging.version import parse as _parse

from .constants import STR_OPERATION_TO_FUNC


def parse(version: str):
    """
    Same as `packaging.version.parse`, but grabs strictly the base version.
    """
    version = _parse(version)
    return _parse(version.base_version)


torch_version = parse(importlib.metadata.version("torch"))

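What the `parse` wrapper above buys is that prerelease and local build suffixes (such as the MS-AMP `2.1.0.a0+32f93b1` torch build mentioned earlier in the constants hunk) compare like their base release:

```python
from packaging.version import parse as _parse


def base_parse(version: str):
    return _parse(_parse(version).base_version)


assert _parse("2.1.0.a0+32f93b1") < _parse("2.1.0")       # raw parse treats it as a prerelease
assert base_parse("2.1.0.a0+32f93b1") == _parse("2.1.0")  # base-version parse does not
```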
@ -49,7 +49,6 @@ from accelerate.utils.other import patch_environment
|
||||
set_seed(42)
|
||||
|
||||
BERT_BASE_CASED = "bert-base-cased"
|
||||
LLAMA_TESTING = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
FP16 = "fp16"
|
||||
BF16 = "bf16"
|
||||
dtypes = [FP16, BF16]
|
||||
@ -136,49 +135,39 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
assert fsdp_plugin.state_dict_config.offload_to_cpu
|
||||
assert fsdp_plugin.state_dict_config.rank0_only
|
||||
|
||||
# We can also override the state_dict_type,
|
||||
# typical case: user trains with sharded, but final save is with full
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(state_dict_type="FULL_STATE_DICT")
|
||||
fsdp_plugin.set_state_dict_type("SHARDED_STATE_DICT")
|
||||
assert fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT
|
||||
|
||||
def test_auto_wrap_policy(self):
|
||||
for model_name in [LLAMA_TESTING, BERT_BASE_CASED]:
|
||||
model = AutoModel.from_pretrained(model_name)
|
||||
layer_to_wrap = "LlamaDecoderLayer" if model_name == LLAMA_TESTING else "BertLayer"
|
||||
for policy in FSDP_AUTO_WRAP_POLICY:
|
||||
env = self.fsdp_env.copy()
|
||||
env["FSDP_AUTO_WRAP_POLICY"] = policy
|
||||
transformer_cls_to_wrap = None
|
||||
min_num_params = None
|
||||
env.pop("FSDP_TRANSFORMER_CLS_TO_WRAP", None)
|
||||
env.pop("FSDP_MIN_NUM_PARAMS", None)
|
||||
if policy == "TRANSFORMER_BASED_WRAP":
|
||||
env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = layer_to_wrap
|
||||
transformer_cls_to_wrap = layer_to_wrap
|
||||
elif policy == "SIZE_BASED_WRAP":
|
||||
env["FSDP_MIN_NUM_PARAMS"] = "2000"
|
||||
min_num_params = 2000
|
||||
# First test via env
|
||||
with mockenv_context(**env):
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin()
|
||||
fsdp_plugin.set_auto_wrap_policy(model)
|
||||
if policy == "NO_WRAP":
|
||||
assert fsdp_plugin.auto_wrap_policy is None
|
||||
else:
|
||||
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
|
||||
|
||||
# Then manually set the policy
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=policy,
|
||||
transformer_cls_names_to_wrap=transformer_cls_to_wrap,
|
||||
min_num_params=min_num_params,
|
||||
)
|
||||
model = AutoModel.from_pretrained(BERT_BASE_CASED)
|
||||
for policy in FSDP_AUTO_WRAP_POLICY:
|
||||
env = self.fsdp_env.copy()
|
||||
env["FSDP_AUTO_WRAP_POLICY"] = policy
|
||||
transformer_cls_to_wrap = None
|
||||
min_num_params = None
|
||||
if policy == "TRANSFORMER_BASED_WRAP":
|
||||
env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer"
|
||||
transformer_cls_to_wrap = "BertLayer"
|
||||
elif policy == "SIZE_BASED_WRAP":
|
||||
env["FSDP_MIN_NUM_PARAMS"] = "2000"
|
||||
min_num_params = 2000
|
||||
# First test via env
|
||||
with mockenv_context(**env):
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin()
|
||||
fsdp_plugin.set_auto_wrap_policy(model)
|
||||
if policy == "NO_WRAP":
|
||||
assert fsdp_plugin.auto_wrap_policy is None
|
||||
else:
|
||||
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
|
||||
if policy == "NO_WRAP":
|
||||
assert fsdp_plugin.auto_wrap_policy is None
|
||||
else:
|
||||
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
|
||||
|
||||
# Then manually set the policy
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=policy,
|
||||
transformer_cls_names_to_wrap=transformer_cls_to_wrap,
|
||||
min_num_params=min_num_params,
|
||||
)
|
||||
fsdp_plugin.set_auto_wrap_policy(model)
|
||||
if policy == "NO_WRAP":
|
||||
assert fsdp_plugin.auto_wrap_policy is None
|
||||
else:
|
||||
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
|
||||
|
||||
env = self.fsdp_env.copy()
|
||||
env["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
|
||||
|
||||
@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
@ -27,7 +26,6 @@ from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
|
||||
from accelerate.accelerator import Accelerator
|
||||
from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches
|
||||
from accelerate.state import GradientState, PartialState
|
||||
from accelerate.test_utils import (
|
||||
require_bnb,
|
||||
@ -37,20 +35,9 @@ from accelerate.test_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from accelerate.test_utils.testing import (
|
||||
AccelerateTestCase,
|
||||
require_cuda,
|
||||
require_non_torch_xla,
|
||||
require_torchdata_stateful_dataloader,
|
||||
)
|
||||
from accelerate.utils import FP8RecipeKwargs, is_torchdata_stateful_dataloader_available, patch_environment
|
||||
from accelerate.utils.dataclasses import DataLoaderConfiguration
|
||||
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
|
||||
from accelerate.utils import FP8RecipeKwargs, patch_environment
|
||||
from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model
|
||||
from accelerate.utils.random import set_seed
|
||||
|
||||
|
||||
if is_torchdata_stateful_dataloader_available():
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
|
||||
class ModelWithTiedWeights(torch.nn.Module):
|
||||
@ -71,6 +58,7 @@ def create_components(tied_weights=False):
|
||||
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)
|
||||
train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3])))
|
||||
valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6])))
|
||||
|
||||
return model, optimizer, scheduler, train_dl, valid_dl
|
||||
|
||||
|
||||
@ -85,21 +73,6 @@ class ModelForTest(torch.nn.Module):
|
||||
return self.linear2(self.batchnorm(self.linear1(x)))
|
||||
|
||||
|
||||
def create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0):
|
||||
"Generates a tuple of dummy DataLoaders to test with"
|
||||
|
||||
def get_dataset(n_batches):
|
||||
x = torch.randn(batch_size * n_batches, 3)
|
||||
y = torch.randn(batch_size * n_batches, 5)
|
||||
return TensorDataset(x, y)
|
||||
|
||||
train_dataset = get_dataset(n_train_batches)
|
||||
valid_dataset = get_dataset(n_valid_batches)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
|
||||
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers)
|
||||
return (train_dataloader, valid_dataloader)
|
||||
|
||||
|
||||
def get_signature(model):
|
||||
return sum(param.abs().sum().item() for param in model.parameters())
|
||||
|
||||
@ -116,12 +89,7 @@ def parameterized_custom_name_func(func, param_num, param):
|
||||
# customize the test name generator function as we want both params to appear in the sub-test
|
||||
# name, as by default it shows only the first param
|
||||
param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch"
|
||||
if len(param.args) > 1:
|
||||
param_based_name += "_tied_weights" if param.args[1] is True else ""
|
||||
if len(param.args) > 2:
|
||||
param_based_name += f"_num_workers_{param.args[2]}"
|
||||
if len(param.args) > 3:
|
||||
param_based_name += "_dispatch_batches" if param.args[3] is True else "_no_dispatch_batches"
|
||||
param_based_name += "_tied_weights" if (len(param.args) == 2 and param.args[1] is True) else ""
|
||||
return f"{func.__name__}_{param_based_name}"
|
||||
|
||||
|
||||
@ -647,156 +615,3 @@ class AcceleratorTester(AccelerateTestCase):
|
||||
# check that pickle roundtrip works
|
||||
model_loaded = pickle.loads(pickle.dumps(model))
|
||||
model_loaded(inputs)
|
||||
|
||||
@parameterized.expand([True, False])
|
||||
def test_can_pickle_dataloader(self, dispatch_batches):
|
||||
"""
|
||||
Test that pickling a prepared dataloader works.
|
||||
"""
|
||||
data = torch.arange(10).to(torch_device)
|
||||
ds = torch.utils.data.TensorDataset(data)
|
||||
dl = torch.utils.data.DataLoader(ds)
|
||||
skip_dl = skip_first_batches(dl, 2)
|
||||
|
||||
# Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality
|
||||
# TODO: Add support for pickling StatefulDataLoader
|
||||
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)
|
||||
accelerator = Accelerator(dataloader_config=dataloader_config)
|
||||
|
||||
original_dl, _ = accelerator.prepare(dl, skip_dl)
|
||||
if dispatch_batches:
|
||||
assert isinstance(original_dl, DataLoaderDispatcher)
|
||||
else:
|
||||
assert isinstance(original_dl, DataLoaderShard)
|
||||
|
||||
prepared_model_dumps = pickle.dumps(accelerator)
|
||||
|
||||
model_loaded = pickle.loads(prepared_model_dumps)
|
||||
assert len(model_loaded._dataloaders) == 2
|
||||
|
||||
# Assert equality of recovered and original dataloader
|
||||
loaded_dl = model_loaded._dataloaders[0]
|
||||
assert isinstance(loaded_dl, DataLoader)
|
||||
if dispatch_batches:
|
||||
assert isinstance(loaded_dl, DataLoaderDispatcher)
|
||||
else:
|
||||
assert isinstance(loaded_dl, DataLoaderShard)
|
||||
assert len(loaded_dl) == len(original_dl)
|
||||
assert [i for i in loaded_dl] == [i for i in original_dl]
|
||||
|
||||
# Test skip dataloader works as expected as well
|
||||
loaded_skip_dl = model_loaded._dataloaders[1]
|
||||
assert isinstance(loaded_skip_dl, DataLoader)
|
||||
if dispatch_batches:
|
||||
assert isinstance(loaded_dl, DataLoaderDispatcher)
|
||||
else:
|
||||
assert isinstance(loaded_dl, DataLoaderShard)
|
||||
assert len(loaded_skip_dl) == len(original_dl) - 2
|
||||
assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]
|
||||
|
||||
# Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
|
||||
"""Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object."""
|
||||
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
|
||||
accelerator = Accelerator(dataloader_config=dataloader_config)
|
||||
model, optimizer, scheduler, train_dl, valid_dl = create_components()
|
||||
|
||||
(
|
||||
prepared_model,
|
||||
prepared_optimizer,
|
||||
prepared_scheduler,
|
||||
prepared_train_dl,
|
||||
prepared_valid_dl,
|
||||
) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
|
||||
|
||||
assert prepared_model in accelerator._models
|
||||
assert prepared_optimizer in accelerator._optimizers
|
||||
assert prepared_scheduler in accelerator._schedulers
|
||||
assert prepared_train_dl in accelerator._dataloaders
|
||||
assert prepared_valid_dl in accelerator._dataloaders
|
||||
assert isinstance(prepared_train_dl, StatefulDataLoader)
|
||||
assert isinstance(prepared_valid_dl, StatefulDataLoader)
|
||||
|
||||
@parameterized.expand(
|
||||
itertools.product([True, False], [True, False], [0, 2], [True, False]),
|
||||
name_func=parameterized_custom_name_func,
|
||||
)
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights, num_workers, dispatch_batches):
|
||||
"""
|
||||
Test that saving and loading a model with a stateful dataloader returns the same model,
|
||||
and that the dataloader's iterator is restored properly."""
|
||||
set_seed(42)
|
||||
n_train_batches = 64 # Use enough batches to ensure we can get partial iterations on large compute
|
||||
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True)
|
||||
accelerator = Accelerator(dataloader_config=dataloader_config)
|
||||
|
||||
model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)
|
||||
train_dl, valid_dl = create_dataloaders_for_test(n_train_batches=n_train_batches, num_workers=num_workers)
|
||||
model = ModelForTest()
|
||||
|
||||
(
|
||||
prepared_model,
|
||||
prepared_optimizer,
|
||||
prepared_scheduler,
|
||||
prepared_train_dl,
|
||||
prepared_valid_dl,
|
||||
) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
|
||||
|
||||
assert isinstance(prepared_train_dl, StatefulDataLoader)
|
||||
assert isinstance(prepared_valid_dl, StatefulDataLoader)
|
||||
|
||||
# Perform 3 training iterations to ensure the dataloader's iterator is advanced
|
||||
num_batches_to_skip = 3
|
||||
model.train()
|
||||
untrained_batches = []
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
for step, batch in enumerate(prepared_train_dl):
|
||||
x, y = batch
|
||||
outputs = prepared_model(x)
|
||||
loss = torch.nn.functional.mse_loss(outputs, y)
|
||||
accelerator.backward(loss)
|
||||
prepared_optimizer.step()
|
||||
prepared_scheduler.step()
|
||||
prepared_optimizer.zero_grad()
|
||||
if step == num_batches_to_skip - 1:
|
||||
# Save the state once we've gone through a few batches
|
||||
accelerator.save_state(f"{tmpdirname}/state", safe_serialization=use_safetensors)
|
||||
if step >= num_batches_to_skip:
|
||||
untrained_batches.append(batch)
|
||||
|
||||
not_skipped_batches = accelerator.gather(untrained_batches)
|
||||
# We then unwrap the trained model
|
||||
unwrapped_model = accelerator.unwrap_model(prepared_model)
|
||||
|
||||
original_linear1 = unwrapped_model.linear1.weight.clone()
|
||||
original_batchnorm = unwrapped_model.batchnorm.weight.clone()
|
||||
original_linear2 = unwrapped_model.linear2.weight.clone()
|
||||
|
||||
# Resume the state
|
||||
accelerator.load_state(f"{tmpdirname}/state")
|
||||
|
||||
# Train this to the end of the DataLoader
|
||||
batches_seen_with_loaded_dl = 0
|
||||
for batch in prepared_train_dl:
|
||||
x, y = batch
|
||||
outputs = prepared_model(x)
|
||||
loss = torch.nn.functional.mse_loss(outputs, y)
|
||||
accelerator.backward(loss)
|
||||
prepared_optimizer.step()
|
||||
prepared_scheduler.step()
|
||||
prepared_optimizer.zero_grad()
|
||||
batches_seen_with_loaded_dl += 1
|
||||
|
||||
unwrapped_model_2 = accelerator.unwrap_model(prepared_model)
|
||||
|
||||
new_linear1 = unwrapped_model_2.linear1.weight
|
||||
new_batchnorm = unwrapped_model_2.batchnorm.weight
|
||||
new_linear2 = unwrapped_model_2.linear2.weight
|
||||
|
||||
# Assert equalities
|
||||
assert batches_seen_with_loaded_dl == len(not_skipped_batches)
|
||||
assert torch.allclose(original_linear1, new_linear1)
|
||||
assert torch.allclose(original_batchnorm, new_batchnorm)
|
||||
assert torch.allclose(original_linear2, new_linear2)
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
@ -19,13 +20,13 @@ from unittest.mock import patch
|
||||
import torch
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
import accelerate.commands.test as accelerate_test_cmd
|
||||
from accelerate.commands.config.config_args import BaseConfig, ClusterConfig, SageMakerConfig, load_config_from_file
|
||||
from accelerate.commands.estimate import estimate_command, estimate_command_parser, gather_data
|
||||
from accelerate.commands.launch import _validate_launch_command, launch_command, launch_command_parser
|
||||
from accelerate.commands.tpu import tpu_command_launcher, tpu_command_parser
|
||||
from accelerate.commands.launch import _validate_launch_command, launch_command_parser
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from accelerate.test_utils.testing import (
|
||||
capture_call_output,
|
||||
DEFAULT_LAUNCH_COMMAND,
|
||||
get_launch_command,
|
||||
path_in_accelerate_package,
|
||||
require_multi_device,
|
||||
require_timm,
|
||||
@ -52,7 +53,6 @@ class AccelerateLauncherTester(unittest.TestCase):
|
||||
changed_path = config_folder / "_default_config.yaml"
|
||||
|
||||
test_config_path = Path("tests/test_configs")
|
||||
parser = launch_command_parser()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -65,11 +65,12 @@ class AccelerateLauncherTester(unittest.TestCase):
|
||||
cls.changed_path.rename(cls.config_path)
|
||||
|
||||
def test_no_config(self):
|
||||
args = ["--monitor_interval", "0.1", str(self.test_file_path)]
|
||||
if torch.cuda.is_available() and (torch.cuda.device_count() > 1):
|
||||
args = ["--multi_gpu"] + args
|
||||
args = self.parser.parse_args(["--monitor_interval", "0.1", str(self.test_file_path)])
|
||||
launch_command(args)
|
||||
cmd = get_launch_command(multi_gpu=True)
|
||||
else:
|
||||
cmd = DEFAULT_LAUNCH_COMMAND
|
||||
cmd.append(self.test_file_path)
|
||||
execute_subprocess_async(cmd, env=os.environ.copy())
|
||||
|
||||
def test_config_compatibility(self):
|
||||
invalid_configs = ["fp8", "invalid", "mpi", "sagemaker"]
|
||||
@ -77,21 +78,20 @@ class AccelerateLauncherTester(unittest.TestCase):
|
||||
if any(invalid_config in str(config) for invalid_config in invalid_configs):
|
||||
continue
|
||||
with self.subTest(config_file=config):
|
||||
args = self.parser.parse_args(["--config_file", str(config), str(self.test_file_path)])
|
||||
launch_command(args)
|
||||
cmd = get_launch_command(config_file=config) + [self.test_file_path]
|
||||
execute_subprocess_async(cmd)
|
||||
|
||||
def test_invalid_keys(self):
|
||||
config_path = self.test_config_path / "invalid_keys.yaml"
|
||||
with self.assertRaises(
|
||||
ValueError,
|
||||
RuntimeError,
|
||||
msg="The config file at 'invalid_keys.yaml' had unknown keys ('another_invalid_key', 'invalid_key')",
|
||||
):
|
||||
args = self.parser.parse_args(["--config_file", str(config_path), str(self.test_file_path)])
|
||||
launch_command(args)
|
||||
cmd = get_launch_command(config_file=config_path) + [self.test_file_path]
|
||||
execute_subprocess_async(cmd)
|
||||
|
||||
def test_accelerate_test(self):
|
||||
args = accelerate_test_cmd.test_command_parser().parse_args([])
|
||||
accelerate_test_cmd.test_command(args)
|
||||
execute_subprocess_async(["accelerate", "test"])
|
||||
|
||||
@require_multi_device
|
||||
def test_notebook_launcher(self):
|
||||
@ -276,19 +276,18 @@ class TpuConfigTester(unittest.TestCase):
|
||||
command_file = "tests/test_samples/test_command_file.sh"
|
||||
gcloud = "Running gcloud compute tpus tpu-vm ssh"
|
||||
|
||||
def setUp(self):
|
||||
self.parser = tpu_command_parser()
|
||||
|
||||
def test_base(self):
|
||||
args = self.parser.parse_args(
|
||||
["--command", self.command, "--tpu_zone", self.tpu_zone, "--tpu_name", self.tpu_name, "--debug"]
|
||||
output = run_command(
|
||||
self.cmd
|
||||
+ ["--command", self.command, "--tpu_zone", self.tpu_zone, "--tpu_name", self.tpu_name, "--debug"],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output
|
||||
|
||||
def test_base_backward_compatibility(self):
|
||||
args = self.parser.parse_args(
|
||||
[
|
||||
output = run_command(
|
||||
self.cmd
|
||||
+ [
|
||||
"--config_file",
|
||||
"tests/test_configs/0_12_0.yaml",
|
||||
"--command",
|
||||
@ -298,29 +297,31 @@ class TpuConfigTester(unittest.TestCase):
|
||||
"--tpu_name",
|
||||
self.tpu_name,
|
||||
"--debug",
|
||||
]
|
||||
],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output
|
||||
|
||||
def test_with_config_file(self):
|
||||
args = self.parser.parse_args(["--config_file", "tests/test_configs/latest.yaml", "--debug"])
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
output = run_command(
|
||||
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--debug"], return_stdout=True
|
||||
)
|
||||
assert (
|
||||
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
|
||||
in output
|
||||
)
|
||||
|
||||
def test_with_config_file_and_command(self):
|
||||
args = self.parser.parse_args(
|
||||
["--config_file", "tests/test_configs/latest.yaml", "--command", self.command, "--debug"]
|
||||
output = run_command(
|
||||
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--command", self.command, "--debug"],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert f"{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls --worker all" in output
|
||||
|
||||
def test_with_config_file_and_multiple_command(self):
|
||||
args = self.parser.parse_args(
|
||||
[
|
||||
output = run_command(
|
||||
self.cmd
|
||||
+ [
|
||||
"--config_file",
|
||||
"tests/test_configs/latest.yaml",
|
||||
"--command",
|
||||
@ -328,27 +329,29 @@ class TpuConfigTester(unittest.TestCase):
|
||||
"--command",
|
||||
'echo "Hello World"',
|
||||
"--debug",
|
||||
]
|
||||
],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert (
|
||||
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; ls; echo "Hello World" --worker all'
|
||||
in output
|
||||
)
|
||||
|
||||
def test_with_config_file_and_command_file(self):
|
||||
args = self.parser.parse_args(
|
||||
["--config_file", "tests/test_configs/latest.yaml", "--command_file", self.command_file, "--debug"]
|
||||
output = run_command(
|
||||
self.cmd
|
||||
+ ["--config_file", "tests/test_configs/latest.yaml", "--command_file", self.command_file, "--debug"],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert (
|
||||
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
|
||||
in output
|
||||
)
|
||||
|
||||
def test_with_config_file_and_command_file_backward_compatibility(self):
|
||||
args = self.parser.parse_args(
|
||||
[
|
||||
output = run_command(
|
||||
self.cmd
|
||||
+ [
|
||||
"--config_file",
|
||||
"tests/test_configs/0_12_0.yaml",
|
||||
"--command_file",
|
||||
@ -358,36 +361,37 @@ class TpuConfigTester(unittest.TestCase):
|
||||
"--tpu_name",
|
||||
self.tpu_name,
|
||||
"--debug",
|
||||
]
|
||||
],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert (
|
||||
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; echo "hello world"; echo "this is a second command" --worker all'
|
||||
in output
|
||||
)
|
||||
|
||||
def test_accelerate_install(self):
|
||||
args = self.parser.parse_args(
|
||||
["--config_file", "tests/test_configs/latest.yaml", "--install_accelerate", "--debug"]
|
||||
output = run_command(
|
||||
self.cmd + ["--config_file", "tests/test_configs/latest.yaml", "--install_accelerate", "--debug"],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert (
|
||||
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate -U; echo "hello world"; echo "this is a second command" --worker all'
|
||||
in output
|
||||
)
|
||||
|
||||
def test_accelerate_install_version(self):
|
||||
args = self.parser.parse_args(
|
||||
[
|
||||
output = run_command(
|
||||
self.cmd
|
||||
+ [
|
||||
"--config_file",
|
||||
"tests/test_configs/latest.yaml",
|
||||
"--install_accelerate",
|
||||
"--accelerate_version",
|
||||
"12.0.0",
|
||||
"--debug",
|
||||
]
|
||||
],
|
||||
return_stdout=True,
|
||||
)
|
||||
output = capture_call_output(tpu_command_launcher, args)
|
||||
assert (
|
||||
f'{self.gcloud} test-tpu --zone us-central1-a --command {self.base_output}; pip install accelerate==12.0.0; echo "hello world"; echo "this is a second command" --worker all'
|
||||
in output
|
||||
|
||||
@ -15,39 +15,18 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
|
||||
|
||||
from accelerate import Accelerator, PartialState
|
||||
from accelerate import Accelerator
|
||||
from accelerate.data_loader import (
|
||||
BatchSamplerShard,
|
||||
DataLoaderDispatcher,
|
||||
DataLoaderShard,
|
||||
DataLoaderStateMixin,
|
||||
IterableDatasetShard,
|
||||
SkipBatchSampler,
|
||||
SkipDataLoader,
|
||||
prepare_data_loader,
|
||||
skip_first_batches,
|
||||
)
|
||||
from accelerate.state import GradientState
|
||||
from accelerate.test_utils.testing import require_torchdata_stateful_dataloader
|
||||
from accelerate.utils import is_torchdata_stateful_dataloader_available
|
||||
|
||||
|
||||
if is_torchdata_stateful_dataloader_available():
|
||||
from torchdata.stateful_dataloader import (
|
||||
StatefulDataLoader,
|
||||
)
|
||||
|
||||
|
||||
def parameterized_custom_name_func(func, param_num, param):
|
||||
# customize the test name generator function as we want both params to appear in the sub-test
|
||||
# name, as by default it shows only the first param
|
||||
param_based_name = f"num_workers_{param.args[0]}"
|
||||
return f"{func.__name__}_{param_based_name}"
|
||||
|
||||
|
||||
class RandomIterableDataset(IterableDataset):
|
||||
@ -65,21 +44,6 @@ class RandomIterableDataset(IterableDataset):
|
||||
stop = random.random() < self.p_stop
|
||||
|
||||
|
||||
class SimpleIterableDataset(IterableDataset):
|
||||
def __init__(self, num_samples=1000):
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __iter__(self):
|
||||
for _ in range(self.num_samples):
|
||||
yield torch.rand(1)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
class DataLoaderTester(unittest.TestCase):
|
||||
def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
|
||||
batch_sampler_shards = [
|
||||
@ -400,48 +364,11 @@ class DataLoaderTester(unittest.TestCase):
|
||||
self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=True)
|
||||
self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, split_batches=True)
|
||||
|
||||
def test_iterable_dataset_using_none_batch_size(self):
|
||||
dataset = SimpleIterableDataset(100)
|
||||
dataloader = DataLoader(dataset, batch_size=None)
|
||||
dataloader = prepare_data_loader(dataloader)
|
||||
for d in dataloader:
|
||||
assert isinstance(d, torch.Tensor)
|
||||
|
||||
def test_skip_batch_sampler(self):
|
||||
batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
|
||||
new_batch_sampler = SkipBatchSampler(batch_sampler, 2)
|
||||
assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]]
|
||||
|
||||
def test_dataloader_inheritance(self):
|
||||
"""
|
||||
`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter
|
||||
are instances of DataLoader and DataLoaderStateMixin.
|
||||
"""
|
||||
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
|
||||
dl_shard = DataLoaderShard(range(16), batch_size=4)
|
||||
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
|
||||
|
||||
# Test dataloaders are instances of instantiated classes
|
||||
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
|
||||
assert isinstance(skip_dl, SkipDataLoader)
|
||||
assert isinstance(dl_shard, DataLoaderShard)
|
||||
assert isinstance(dl_dispatcher, DataLoaderDispatcher)
|
||||
|
||||
# Test dataloaders are instances of base classes
|
||||
assert isinstance(skip_dl, DataLoader)
|
||||
assert isinstance(dl_shard, DataLoader)
|
||||
assert isinstance(dl_dispatcher, DataLoader)
|
||||
|
||||
assert isinstance(dl_shard, DataLoaderStateMixin)
|
||||
assert isinstance(dl_dispatcher, DataLoaderStateMixin)
|
||||
|
||||
assert isinstance(skip_dl.base_dataloader, DataLoader)
|
||||
assert isinstance(dl_shard.base_dataloader, DataLoader)
|
||||
assert isinstance(dl_dispatcher.base_dataloader, DataLoader)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_ = DataLoaderShard.base_dataloader
|
||||
|
||||
def test_skip_data_loader(self):
|
||||
dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2)
|
||||
assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
|
||||
@ -461,6 +388,7 @@ class DataLoaderTester(unittest.TestCase):
|
||||
assert dataloader.end_of_dataloader == (idx == 3)
|
||||
|
||||
def test_end_of_dataloader_dispatcher(self):
|
||||
Accelerator()
|
||||
dataloader = DataLoaderDispatcher(range(16), batch_size=4)
|
||||
for idx, _ in enumerate(dataloader):
|
||||
assert dataloader.end_of_dataloader == (idx == 3)
|
||||
@ -468,342 +396,3 @@ class DataLoaderTester(unittest.TestCase):
|
||||
# Test it also works on the second iteration
|
||||
for idx, _ in enumerate(dataloader):
|
||||
assert dataloader.end_of_dataloader == (idx == 3)
|
||||
|
||||
|
||||
class StatefulDataLoaderTester(unittest.TestCase):
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_skip_data_loader(self):
|
||||
dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
|
||||
assert isinstance(dataloader, StatefulDataLoader)
|
||||
assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
|
||||
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_end_of_dataloader(self):
|
||||
dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True)
|
||||
assert dataloader.use_stateful_dataloader
|
||||
assert isinstance(dataloader, StatefulDataLoader)
|
||||
for idx, _ in enumerate(dataloader):
|
||||
assert dataloader.end_of_dataloader == (idx == 3)
|
||||
|
||||
# Test it also works on the second iteration
|
||||
for idx, _ in enumerate(dataloader):
|
||||
assert dataloader.end_of_dataloader == (idx == 3)
|
||||
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_end_of_dataloader_dispatcher(self):
|
||||
dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
|
||||
assert isinstance(dataloader, StatefulDataLoader)
|
||||
for idx, _ in enumerate(dataloader):
|
||||
assert dataloader.end_of_dataloader == (idx == 3)
|
||||
|
||||
# Test it also works on the second iteration
|
||||
for idx, _ in enumerate(dataloader):
|
||||
assert dataloader.end_of_dataloader == (idx == 3)
|
||||
|
||||
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_dataloader_state_dict(self, num_workers):
|
||||
"""
|
||||
Test that saving a stateful dataloader's state, then loading it back, gives the same results.
|
||||
"""
|
||||
dataset = list(range(16))
|
||||
dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
|
||||
|
||||
assert dataloader.use_stateful_dataloader
|
||||
assert isinstance(dataloader, StatefulDataLoader)
|
||||
vals = []
|
||||
for idx, val in enumerate(dataloader):
|
||||
vals.append(val)
|
||||
if idx == 1:
|
||||
sd = dataloader.state_dict()
|
||||
assert len(vals) == 4
|
||||
|
||||
dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
|
||||
dataloader2.load_state_dict(sd)
|
||||
|
||||
data1 = vals[2:]
|
||||
data2 = list(dataloader2)
|
||||
assert len(data1) == len(data2)
|
||||
for d1, d2 in zip(data1, data2):
|
||||
assert torch.allclose(d1, d2)
|
||||
|
||||
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_dataloader_dispatcher_state_dict(self, num_workers):
|
||||
"""
|
||||
Test that saving a stateful dataloader's state, then loading it back, gives the same results.
|
||||
"""
|
||||
dataset = list(range(16))
|
||||
dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
|
||||
|
||||
assert dataloader.use_stateful_dataloader
|
||||
assert isinstance(dataloader, StatefulDataLoader)
|
||||
vals = []
|
||||
for idx, val in enumerate(dataloader):
|
||||
vals.append(val)
|
||||
if idx == 1:
|
||||
sd = dataloader.state_dict()
|
||||
assert len(vals) == 4
|
||||
dataloader2 = DataLoaderDispatcher(
|
||||
dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers
|
||||
)
|
||||
dataloader2.load_state_dict(sd)
|
||||
|
||||
data1 = vals[2:]
|
||||
data2 = list(dataloader2)
|
||||
assert len(data1) == len(data2)
|
||||
for d1, d2 in zip(data1, data2):
|
||||
assert torch.allclose(d1, d2)
|
||||
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_dataloader_inheritance(self):
|
||||
"""
|
||||
`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True,
|
||||
subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.
|
||||
"""
|
||||
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
|
||||
dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
|
||||
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
|
||||
|
||||
# Test dataloaders are instances of instantiated classes
|
||||
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
|
||||
assert isinstance(skip_dl, SkipDataLoader)
|
||||
assert isinstance(dl_shard, DataLoaderShard)
|
||||
assert isinstance(dl_dispatcher, DataLoaderDispatcher)
|
||||
|
||||
assert isinstance(skip_dl, StatefulDataLoader)
|
||||
assert isinstance(dl_shard, StatefulDataLoader)
|
||||
assert isinstance(dl_dispatcher, StatefulDataLoader)
|
||||
|
||||
assert isinstance(dl_shard, DataLoaderStateMixin)
|
||||
assert isinstance(dl_dispatcher, DataLoaderStateMixin)
|
||||
|
||||
assert isinstance(skip_dl.base_dataloader, StatefulDataLoader)
|
||||
assert isinstance(dl_shard.base_dataloader, StatefulDataLoader)
|
||||
assert isinstance(dl_dispatcher.base_dataloader, StatefulDataLoader)
|
||||
|
||||
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
|
||||
"""
|
||||
Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
|
||||
the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader`.
|
||||
"""
|
||||
dataset = list(range(64))
|
||||
|
||||
# Set the seed for reproducibility
|
||||
def g():
|
||||
return torch.Generator().manual_seed(42)
|
||||
|
||||
accelerator = Accelerator()
|
||||
stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
|
||||
skip_dl = SkipDataLoader(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
dl_shard = DataLoaderShard(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
dl_dispatcher = DataLoaderDispatcher(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
|
||||
dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]
|
||||
|
||||
num_batches_to_skip = 8
|
||||
|
||||
def get_first_n_batches(dl, n, device):
|
||||
"""
|
||||
Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
|
||||
"""
|
||||
batches = []
|
||||
for idx, batch in enumerate(dl):
|
||||
if idx == n - 1:
|
||||
if hasattr(dl, "end"):
|
||||
dl.end()
|
||||
break
|
||||
batches.append(batch.to(device))
|
||||
return batches
|
||||
|
||||
# Iterate over all of the dataloaders identically, expect the same values
|
||||
expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device)
|
||||
batches_from_dataloaders = [
|
||||
get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test
|
||||
]
|
||||
|
||||
for dl_batches in batches_from_dataloaders:
|
||||
for expected, actual in zip(expected_batches, dl_batches):
|
||||
assert torch.allclose(expected, actual)
|
||||
|
||||
# The adapters should all produce the same state_dict as the reference stateful dataloader
|
||||
expected_state_dict = stateful_dl.state_dict()
|
||||
skip_dl_state_dict = skip_dl.state_dict()
|
||||
dl_shard_state_dict = dl_shard.state_dict()
|
||||
dl_dispatcher_state_dict = dl_dispatcher.state_dict()
|
||||
|
||||
assert expected_state_dict == skip_dl_state_dict
|
||||
assert expected_state_dict == dl_shard_state_dict
|
||||
assert expected_state_dict == dl_dispatcher_state_dict
|
||||
|
||||
# Load the state dict into new dataloaders
|
||||
manual_skip_dl = SkipDataLoader(
|
||||
dataset,
|
||||
batch_size=4,
|
||||
num_workers=num_workers,
|
||||
generator=g(),
|
||||
skip_batches=num_batches_to_skip,
|
||||
use_stateful_dataloader=True,
|
||||
)
|
||||
loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
|
||||
loaded_stateful_dl.load_state_dict(expected_state_dict)
|
||||
loaded_skip_dl = SkipDataLoader(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
loaded_skip_dl.load_state_dict(expected_state_dict)
|
||||
loaded_dl_shard = DataLoaderShard(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
loaded_dl_shard.load_state_dict(expected_state_dict)
|
||||
loaded_dl_dispatcher = DataLoaderDispatcher(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
loaded_dl_dispatcher.load_state_dict(expected_state_dict)
|
||||
|
||||
# Continue the iteration, expecting identical behavior across the board
|
||||
def get_all_batches(dl, device):
|
||||
"""
|
||||
Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
|
||||
"""
|
||||
batches = []
|
||||
num_batches_yielded = 0
|
||||
for batch in dl:
|
||||
batches.append(batch.to(device))
|
||||
num_batches_yielded += 1
|
||||
return (batches, num_batches_yielded)
|
||||
|
||||
expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device)
|
||||
dataloader_batch_results = [
|
||||
get_all_batches(dl, accelerator.device)
|
||||
for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
|
||||
]
|
||||
for dl_results in dataloader_batch_results:
|
||||
for expected, actual in zip(expected_batches, dl_batches):
|
||||
assert torch.allclose(expected[0], actual[0])
|
||||
assert expected_batch_results[1] == dl_results[1]
|
||||
|
||||
assert accelerator.gradient_state.active_dataloader is None
|
||||
|
||||
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
|
||||
@require_torchdata_stateful_dataloader
|
||||
def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
|
||||
"""
|
||||
Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
|
||||
the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader` when *not* using
|
||||
Accelerator (and instead using the decoupled `PartialState` workflow).
|
||||
"""
|
||||
dataset = list(range(64))
|
||||
|
||||
# Set the seed for reproducibility
|
||||
def g():
|
||||
return torch.Generator().manual_seed(42)
|
||||
|
||||
state = PartialState()
|
||||
stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
|
||||
skip_dl = SkipDataLoader(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
dl_shard = DataLoaderShard(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
dl_dispatcher = DataLoaderDispatcher(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
|
||||
dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]
|
||||
|
||||
num_batches_to_skip = 8
|
||||
|
||||
def get_first_n_batches(dl, n, device):
|
||||
"""
|
||||
Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
|
||||
"""
|
||||
batches = []
|
||||
for idx, batch in enumerate(dl):
|
||||
if idx == n - 1:
|
||||
if hasattr(dl, "end"):
|
||||
dl.end()
|
||||
break
|
||||
batches.append(batch.to(device))
|
||||
return batches
|
||||
|
||||
# Iterate over all of the dataloaders identically, expect the same values
|
||||
expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, state.device)
|
||||
batches_from_dataloaders = [
|
||||
get_first_n_batches(dl, num_batches_to_skip, state.device) for dl in dataloaders_under_test
|
||||
]
|
||||
|
||||
for dl_batches in batches_from_dataloaders:
|
||||
for expected, actual in zip(expected_batches, dl_batches):
|
||||
assert torch.allclose(expected, actual)
|
||||
|
||||
# The adapters should all produce the same state_dict as the reference stateful dataloader
|
||||
expected_state_dict = stateful_dl.state_dict()
|
||||
skip_dl_state_dict = skip_dl.state_dict()
|
||||
dl_shard_state_dict = dl_shard.state_dict()
|
||||
dl_dispatcher_state_dict = dl_dispatcher.state_dict()
|
||||
|
||||
assert expected_state_dict == skip_dl_state_dict
|
||||
assert expected_state_dict == dl_shard_state_dict
|
||||
assert expected_state_dict == dl_dispatcher_state_dict
|
||||
|
||||
# Load the state dict into new dataloaders
|
||||
manual_skip_dl = SkipDataLoader(
|
||||
dataset,
|
||||
batch_size=4,
|
||||
num_workers=num_workers,
|
||||
generator=g(),
|
||||
skip_batches=num_batches_to_skip,
|
||||
use_stateful_dataloader=True,
|
||||
)
|
||||
loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
|
||||
loaded_stateful_dl.load_state_dict(expected_state_dict)
|
||||
loaded_skip_dl = SkipDataLoader(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
loaded_skip_dl.load_state_dict(expected_state_dict)
|
||||
loaded_dl_shard = DataLoaderShard(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
loaded_dl_shard.load_state_dict(expected_state_dict)
|
||||
loaded_dl_dispatcher = DataLoaderDispatcher(
|
||||
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
|
||||
)
|
||||
loaded_dl_dispatcher.load_state_dict(expected_state_dict)
|
||||
|
||||
# Continue the iteration, expecting identical behavior across the board
|
||||
def get_all_batches(dl, device):
|
||||
"""
|
||||
Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
|
||||
"""
|
||||
batches = []
|
||||
num_batches_yielded = 0
|
||||
for batch in dl:
|
||||
batches.append(batch.to(device))
|
||||
num_batches_yielded += 1
|
||||
return (batches, num_batches_yielded)
|
||||
|
||||
expected_batch_results = get_all_batches(loaded_stateful_dl, state.device)
|
||||
dataloader_batch_results = [
|
||||
get_all_batches(dl, state.device)
|
||||
for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
|
||||
]
|
||||
for dl_results in dataloader_batch_results:
|
||||
for expected, actual in zip(expected_batches, dl_batches):
|
||||
assert torch.allclose(expected[0], actual[0])
|
||||
assert expected_batch_results[1] == dl_results[1]
|
||||
|
||||
# Using the decoupled (`PartialState`) workflow, GradientState should be automatically initialized (with
|
||||
# default parameters) by `DataLoaderDispatcher`
|
||||
assert GradientState._shared_state != {}, "GradientState should already be initialized!"
|
||||
|
||||
gradient_state = GradientState()
|
||||
assert gradient_state.active_dataloader is None
|
||||
|
||||
@ -19,7 +19,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock, skip
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
@ -45,13 +45,11 @@ from accelerate.utils import write_basic_config
|
||||
|
||||
EXCLUDE_EXAMPLES = [
|
||||
"cross_validation.py",
|
||||
"checkpointing.py",
|
||||
"gradient_accumulation.py",
|
||||
"local_sgd.py",
|
||||
"multi_process_metrics.py",
|
||||
"memory.py",
|
||||
"schedule_free.py",
|
||||
"tracking.py",
|
||||
"automatic_gradient_accumulation.py",
|
||||
"fsdp_with_peak_mem_tracking.py",
|
||||
"deepspeed_with_config_support.py",
|
||||
@ -261,9 +259,6 @@ class FeatureExamplesTests(TempDirTestCase):
|
||||
testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
|
||||
run_command(self.launch_args + testargs)
|
||||
|
||||
@skip(
|
||||
reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
|
||||
)
|
||||
@require_multi_device
|
||||
def test_distributed_inference_examples_stable_diffusion(self):
|
||||
testargs = ["examples/inference/distributed/stable_diffusion.py"]
|
||||
|
||||
@ -12,9 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from accelerate.test_utils import require_transformer_engine
|
||||
from accelerate.test_utils.testing import TempDirTestCase, require_import_timer
|
||||
from accelerate.utils import is_import_timer_available
|
||||
|
||||
@ -33,7 +31,7 @@ def convert_list_to_string(data):
|
||||
|
||||
|
||||
def run_import_time(command: str):
|
||||
output = subprocess.run([sys.executable, "-X", "importtime", "-c", command], capture_output=True, text=True)
|
||||
output = subprocess.run(["python3", "-X", "importtime", "-c", command], capture_output=True, text=True)
|
||||
return output.stderr
|
||||
|
||||
|
||||
@ -83,18 +81,3 @@ class ImportSpeedTester(TempDirTestCase):
|
||||
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
|
||||
err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
|
||||
self.assertLess(pct_more, 20, err_msg)
|
||||
|
||||
|
||||
@require_transformer_engine
|
||||
class LazyImportTester(TempDirTestCase):
|
||||
"""
|
||||
Test suite which checks if specific packages are lazy-loaded.
|
||||
|
||||
Eager-import will trigger circular import in some case,
|
||||
e.g. in huggingface/accelerate#3056.
|
||||
"""
|
||||
|
||||
def test_te_import(self):
|
||||
output = run_import_time("import accelerate, accelerate.utils.transformer_engine")
|
||||
|
||||
self.assertFalse(" transformer_engine" in output, "`transformer_engine` should not be imported on import")
|
||||
|
||||
@ -14,9 +14,6 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
|
||||
from accelerate import debug_launcher
|
||||
from accelerate.test_utils import (
|
||||
DEFAULT_LAUNCH_COMMAND,
|
||||
@ -32,7 +29,6 @@ from accelerate.utils import patch_environment
|
||||
|
||||
|
||||
@require_huggingface_suite
|
||||
@unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "Test requires numpy version < 2.0")
|
||||
class MetricTester(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_file_path = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_metrics.py")
|
||||
|
||||
@ -27,7 +27,6 @@ from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
# We use TF to parse the logs
|
||||
from accelerate import Accelerator
|
||||
@ -69,7 +68,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@require_tensorboard
|
||||
class TensorBoardTrackingTest(unittest.TestCase):
|
||||
@unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "TB doesn't support numpy 2.0")
|
||||
def test_init_trackers(self):
|
||||
project_name = "test_project_with_config"
|
||||
with tempfile.TemporaryDirectory() as dirpath:
|
||||
|
||||
@ -49,6 +49,7 @@ from accelerate.utils import (
|
||||
listify,
|
||||
pad_across_processes,
|
||||
pad_input_tensors,
|
||||
parse,
|
||||
patch_environment,
|
||||
recursively_apply,
|
||||
save,
|
||||
@ -411,3 +412,8 @@ class UtilsTester(unittest.TestCase):
|
||||
tqdm(True, range(3), disable=True)
|
||||
assert "Passing `True` as the first argument to" in cm.pop().message.args[0]
|
||||
tqdm(range(3), main_process_only=True, disable=True)
|
||||
|
||||
def test_dev0_parsing(self):
|
||||
v1 = parse("0.34.0.dev0")
|
||||
v2 = parse("0.34.0")
|
||||
assert v1 == v2
|
||||
|
||||
Reference in New Issue
Block a user