Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-18 16:44:39 +08:00)

Compare commits: context-pa... feat/async (73 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| 50042518db | |||
| 571ca0200d | |||
| 0cb1a33475 | |||
| dfdc219018 | |||
| 45959d7b96 | |||
| 8b493524c8 | |||
| 9ead94e556 | |||
| a0bc36e8ed | |||
| 8830e58a91 | |||
| 40ebb4bea3 | |||
| ec92b1af7a | |||
| 62ede1ed2a | |||
| 9f9c490c6b | |||
| 8b55e62b2c | |||
| 0e4419b347 | |||
| 3b67c21696 | |||
| 7b981788ca | |||
| c4460e33ef | |||
| 5dd3d0b690 | |||
| 5fe4460ccd | |||
| 979d81e4a9 | |||
| 7c25f696b8 | |||
| a7d6f28f99 | |||
| 23cf4ef8a3 | |||
| ff872f5f71 | |||
| 2941a6b0fb | |||
| c0a3aefea8 | |||
| 42fdda1c1f | |||
| e23b004b30 | |||
| 898cad39e8 | |||
| 24c8157bba | |||
| 6891c57072 | |||
| 24e48f3d20 | |||
| 6640ff415c | |||
| c173b4fdd6 | |||
| cb343c63d7 | |||
| 354b0b5da3 | |||
| 9359a0194f | |||
| 2f075c724c | |||
| 7ecc2d7f39 | |||
| 12f89bb754 | |||
| 348aabaaaf | |||
| 3b13453bbf | |||
| 0408ab12d7 | |||
| 55e518a762 | |||
| 7e11ac43f0 | |||
| e2cc537db8 | |||
| 847ae58c74 | |||
| 6e104f31de | |||
| 524e5f9828 | |||
| d6c986c3f2 | |||
| 1ac8643df7 | |||
| 07ce74868c | |||
| 175fe91589 | |||
| fe16ce8bce | |||
| 5987d79a53 | |||
| 31af8d4e8e | |||
| b7493a82b1 | |||
| a16d2bb3c1 | |||
| cac22ed980 | |||
| be826a6b7b | |||
| 5939640829 | |||
| 7f9c8cbe34 | |||
| 9888c7ed23 | |||
| 42a68c30dc | |||
| 6597dae780 | |||
| 8878d93745 | |||
| 2eaf5cdbbc | |||
| 23c1d8db89 | |||
| 0af621bbec | |||
| bee04f1b01 | |||
| 8a953f08c6 | |||
| 3518c03584 |
@@ -15,7 +15,7 @@ jobs:
   outputs:
     version: ${{ steps.step1.outputs.version }}
   steps:
-    - uses: actions/checkout@4
+    - uses: actions/checkout@v4
     - id: step1
       run: echo "version=$(python setup.py --version)" >> $GITHUB_OUTPUT
.github/workflows/gaudi3_scheduled.yml (vendored, 17 changed lines)

@@ -15,7 +15,7 @@ jobs:
   group: itac-bm-emr-gaudi3-dell-2gaudi

   container:
-    image: docker://vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
+    image: docker://vault.habana.ai/gaudi-docker/1.21.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
     options: --runtime=habana --shm-size=64G --cap-add=sys_nice --env HABANA_VISIBLE_DEVICES
     env:
       OMPI_MCA_btl_vader_single_copy_mechanism: none
@@ -66,16 +66,21 @@ jobs:
         run: |
           make test_big_modeling

-      - name: Run FSDP integration tests
-        if: ${{ !cancelled() && (success() || failure()) }}
-        run: |
-          make test_fsdp
-
       - name: Run DeepSpeed integration tests
         if: ${{ !cancelled() && (success() || failure()) }}
         run: |
           make test_deepspeed

+      - name: Run FSDP integration tests
+        if: ${{ !cancelled() && (success() || failure()) }}
+        run: |
+          make test_fsdp
+
+      - name: Run TP integration tests
+        if: ${{ !cancelled() && (success() || failure()) }}
+        run: |
+          make test_tp
+
       - name: Run Examples tests
         if: ${{ !cancelled() && (success() || failure()) }}
         run: |
@@ -112,7 +112,7 @@ jobs:
           cd skorch;
           git config --global --add safe.directory '*'
           git checkout master && git pull
-          pip install .[testing]
+          pip install .[test]
           pip install flaky

       - name: Show installed libraries
Makefile (19 changed lines)

@@ -23,16 +23,23 @@ style:
 	doc-builder style src/accelerate docs/source --max_len 119

 # Run tests for the library
-test_big_modeling:
-	python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)
-
 test_core:
-	python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/deepspeed --ignore=./tests/test_big_modeling.py \
-	--ignore=./tests/fsdp --ignore=./tests/tp --ignore=./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)
+	python -m pytest -s -v ./tests/ \
+	--ignore=./tests/test_big_modeling.py \
+	--ignore=./tests/test_modeling_utils.py \
+	--ignore=./tests/test_examples.py \
+	--ignore=./tests/test_cli.py \
+	--ignore=./tests/deepspeed \
+	--ignore=./tests/fsdp \
+	--ignore=./tests/tp \
+	$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)

 test_cli:
 	python -m pytest -s -v ./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_cli.log",)

+test_big_modeling:
+	python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)
+
 test_deepspeed:
 	python -m pytest -s -v ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_deepspeed.log",)
@@ -57,7 +64,7 @@ test_examples:

 # Broken down example tests for the CI runners
 test_integrations:
-	python -m pytest -s -v ./tests/deepspeed ./tests/fsdp ./tests/tp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_integrations.log",)
+	python -m pytest -s -v ./tests/fsdp ./tests/tp ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_integrations.log",)

 test_example_differences:
 	python -m pytest -s -v ./tests/test_examples.py::ExampleDifferenceTests $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_example_diff.log",)
@@ -7,7 +7,7 @@ RUN pip install transformers evaluate datasets
 RUN git clone https://github.com/huggingface/accelerate.git

 RUN cd accelerate && \
-    pip install -e . && \
+    pip install -e .[deepspeed] && \
     cd benchmarks/fp8

 RUN /bin/bash
@@ -62,8 +62,8 @@
     title: Amazon SageMaker
   - local: usage_guides/mps
     title: Apple M1 GPUs
-  - local: usage_guides/ipex
-    title: IPEX training with CPU
+  - local: usage_guides/intel_cpu
+    title: Intel CPU
   - local: usage_guides/gaudi
     title: Intel Gaudi
   - local: usage_guides/compilation

@@ -82,8 +82,6 @@
     title: Accelerate's internal mechanism
   - local: concept_guides/big_model_inference
     title: Loading big models into memory
-  - local: concept_guides/context_parallel
-    title: Context parallelism
   - local: concept_guides/performance
     title: Comparing performance across distributed setups
   - local: concept_guides/deferring_execution
@@ -94,6 +92,8 @@
     title: FSDP vs DeepSpeed
   - local: concept_guides/fsdp1_vs_fsdp2
     title: FSDP1 vs FSDP2
+  - local: concept_guides/context_parallelism
+    title: Context parallelism
   - local: concept_guides/low_precision_training
     title: Low precision training methods
   - local: concept_guides/training_tpu
@@ -19,54 +19,70 @@ This guide will cover basics of using context parallelism in 🤗`accelerate`, f

 ## Why context parallelism?

-With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has lead to a need for more efficient ways to train models with long sequences.
+With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
 With sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given vanilla attention implementation. Granted, with usage of `flash attention` or `SDPA` which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
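A quick back-of-the-envelope check of the `~32 GB` figure above (an illustrative snippet rather than part of the guide; it treats `128k` as 128 × 1024 tokens and reports GiB):

```python
seq_len = 128 * 1024      # "128k" tokens
bytes_per_elem = 2        # bf16
attn_matrix_bytes = seq_len * seq_len * bytes_per_elem
print(attn_matrix_bytes / 2**30)  # 32.0 -> roughly the quoted ~32 GB per attention head
```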
 Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.

 ## How to use context parallelism?

 As with any other feature in 🤗`accelerate`, enabling context parallelism is as simple as passing the corresponding flags to `accelerate launch`.

 ```diff
 from accelerate.utils import ParallelismConfig, TorchContextParallelConfig

 + cp_config = TorchContextParallelConfig(
 +     cp_comm_strategy="alltoall", # no need to use cp_config at all, if you want to use the default "allgather"
 + )

 + parallelism_config = ParallelismConfig(
 +     cp_size=8,
 +     cp_handler=cp_config, # or just cp_size=8, if you want to use the default "allgather"
 + )

 accelerator = Accelerator(
     ...,
     parallelism_config=parallelism_config,
 )
 ```

 As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.
 In this case, it's no different:

 ```bash
-accelerate launch --context-parallel-size 8 --context-parallel-shard-rotation [allgather|alltoall] ...
+accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...
 ```

 Context parallelism is tightly coupled (for now) with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). Meaning, context parallelism is applied only if `FSDP2` is enabled.
 You can also enable context parallelism programatically, by passing it in the `FullyShardedDataParallelPlugin` constructor:

 > [!Tip]
 > You can also set the `cp_size` and `cp_comm_strategy` in the `accelerate config` command, which will save them in your `accelerate` configuration file, so you don't have to pass them every time you launch your script.

 ```diff
 from accelerate.utils import FullyShardedDataParallelPlugin

 > [!Tip]
 > Context parallelism is compatible with other parallelism strategies, such as data parallelism, tensor parallelism and FSDP2.
 > You can simply combine them by setting your parallelism sizes to the desired values, e.g. `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`. Or you can use the `ParallelismConfig` class to set them programmatically.

 plugin = FullyShardedDataParallelPlugin(
     ...
     fsdp_version=2,
 +   cp_size=8,
 +   cp_comm_strategy="allgather",
 )
 accelerator = Accelerator(fsdp_plugin=plugin)
 ```

 > [!Warning]
 > Context parallelism is tightly coupled with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). Meaning, context parallelism only works if you use `FullyShardedDataParallelPlugin` or `--use-fsdp` with version set to 2 to your
 > program. If no `FSDP2` is used, error will be raised.

 After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later).

 > [!Warning]
 > Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.
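One way to satisfy this requirement with 🤗 transformers models is to request the SDPA attention implementation explicitly when loading the model. A minimal sketch (the checkpoint name is borrowed from the example script elsewhere in this diff; it is not a requirement):

```python
import torch
from transformers import AutoModelForCausalLM

# Explicitly ask for SDPA so context parallelism sees a supported attention path.
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Hermes-3-Llama-3.1-8B",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    use_cache=False,
)
```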
 After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
 You can use it as follows:

 ```python
 for batch in dataloader:
-    with accelerator.context_parallel(
+    with accelerator.maybe_context_parallel(
         buffers=[batch["input_ids"], batch["attention_mask"]],
         buffer_seq_dims=[1, 1],
-        no_restore_buffers={batch["input_ids"]},
+        no_restore_buffers={batch["input_ids"], batch["labels"]},
     ):
-        outputs = model(batch)
+        outputs = model(**batch)
         ...
 ```

 > [!Warning]
 > This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.

-This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentiall enabling endless context length scaling.
+This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentially enabling endless context length scaling.

 <p align="center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />

@@ -75,7 +91,10 @@ This can scale your context size to 1M+ sequence length potentially. Below, we s
 </p>

 > [!Tip]
-> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py). For instructions on how to run it, see the [README](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/README.md) in the same folder.
+> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:
+> ```bash
+> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000
+> ```

 ## Accelerate's interface

@@ -83,19 +102,32 @@ This can scale your context size to 1M+ sequence length potentially. Below, we s
 The context manager takes a few arguments, that are used to configure the context parallelism.

 - `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask.
-- `buffer_seq_dims`: This is a list of integers, that specify the sequence dimension of the buffers, in the order of the `buffers` list.
-- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.Dtensor`s. After the context manager is exited, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). This takes some time, so it is reccomended to pass the same arguments as to the `buffers` argument, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager is exited.
+- `buffer_seq_dims`: This is a list of integers, that specify the sequence dimension of the buffers, in the order of the `buffers` list. If you pass `buffers=[input_ids, shift_labels]` with both having shape `[batch_size, sequence_length]`, you would pass `buffer_seq_dims=[1, 1]`.
+  as the sequence dimension is the second dimension of the tensors. This is required for correct computation of the model outputs.
+- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.Dtensor`s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager exits.

 > [!Warning]
 > Context parallelism is not compatible with `labels` that are a copy of `input_ids`, which models from 🤗 transformers can shift to enable causal language modeling themselves.
 > Imagine this case:
 > labels = [l1, l2, l3, l4, ... li]
 > if we apply context parallelism, each rank would end up with a part of labels, such as this:
 > labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], ...
 > after transformers modelling code shifts the labels, it would end up with:
 > labels_rank_0 = [l2, PAD], labels_rank_1 = [l3, PAD], ...
 > where `PAD` is a padding token. This would result in incorrect loss computation, as the labels are not aligned with the inputs anymore.
 > Because of this, you need to manually shift the labels before passing them in the model
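A minimal sketch of such manual shifting (an illustration rather than part of the guide; it assumes `labels` starts out as a copy of `input_ids` and that the loss uses `-100` as the ignore index):

```python
import torch

def shift_labels_for_cp(batch: dict) -> dict:
    """Shift labels by one position *before* the context manager shards the buffers."""
    labels = batch["labels"]
    # keep the sequence length unchanged so the buffer can still be sharded evenly;
    # the last position gets the ignore index instead of a wrapped-around token
    batch["shift_labels"] = torch.nn.functional.pad(labels[:, 1:], (0, 1), value=-100)
    return batch
```

Whether the model consumes a pre-shifted tensor directly or you compute the loss yourself from the logits depends on your modeling code; the key point is that the shift has to happen before sharding.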
 ## Configurable options
-Accelerate provides only a few options to configure context parallelism, which are:
+Accelerate provides only a single option to configure context parallelism (except for `cp_size`)

 - `cp_size`: The number of ranks to shard the inputs to the attention computation across the sequence dimension.
-- `cp_comm_strategy`: The rotation method to use for the shards. We strongly reccomend keeping this as `"allgather"`, as it's very likely it will outperform `"alltoall"` in most cases.
+- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it's very likely it will outperform `"alltoall"` in most cases.

 Context parallel size is rather self-explanatory, it's the number of ranks across which the inputs are to be-sharded.
 Context parallel shard rotation defines how the shards of the inputs are rotated across ranks. We'll cover the 2 options in more detail in the next section.

-You can see an end-to-end example in the [FSDP2 context parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py) file, where you can train an 8B model with 128k sequence length on 8x H100 SXM GPUs. Using multi-node training, you can scale this to 1M+ sequence length on 64x H100 SXM GPUs.
+You can see an end-to-end example in the [ND parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py) file, where you can train an 8B model with up-to 128k context length on a single 8xH100 node. Using multi-node training, you can scale this to 1M+ sequence length on multiple GPUs. You can also seamlessly combine it with other parallelism strategies to fit your needs.

 ## Technical details

@@ -110,7 +142,7 @@ We're going to be using word `shard` extensively in the following sections, so l
 Context parallelism works on sharding the `Q, K and V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays only on this rank, during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shard of `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `Ks` and `Vs` from all other ranks. What communication primitive is used, depends on the `context_parallel_shard_rotation` option.
 This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. As each rank holds `Q, K and V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.

-We can formalize this in a following pseudocode:
+We can formalize this in the following pseudocode:
 ```python
 comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]
 Qi, Ki, Vi = shard(Q, K, V, seq_dim)

@@ -132,7 +164,7 @@ In ideal scenario, all-gather finishes in the exact moment as the calculation of
 All-to-all, or sometimes called `ring-rotation` utilizes a ring-like communication pattern. After concluding `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `context_parallel_size-1` times, so that each rank sees all the shards of `K` and `V` from all other ranks once. In ideal scenario, we prefetch shards `K_i+1` and `V_i+1` from the neighbouring rank and this communication is exactly overlapped with computation of our current `attn_i`. Again, realistically, this perfect overlap doesn't ever happen. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is way larger than with all-gather.

 ## How to choose the right rotation method?
-In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also shows that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
+In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.

 You can directly see this issue in the profiler output in the image below:
 <p align="center">

@@ -144,8 +176,10 @@ You can directly see this issue in the profiler output in the image below:

 ## Why only FSDP2?

-We only support context parallelism with `FSDP2` for now, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
-utilize its full potential. In the profiler output in the image below, you can see why this is the case.
+We only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
+utilize its full potential.
+How it works is: we shard the model across the joint mesh of size `cp_size*dp_shard_size`, which maximizes the memory savings.
+This is a "free lunch" of sorts, as `FSDP` communication is fully overlapped with the computation of attention, as shown in the images below.

 <p align="center">
     <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png" alt="why FSDP2+CP" />

@@ -154,3 +188,17 @@ utilize its full potential. In the profiler output in the image below, you can s
 </p>

 In the figure above, you can also note the difference between all-to-all and all-gather. While in all-to-all (Figure 1), we launch a communication kernel N-1 times for each attention call, in all-gather (Figure 2), we launch a communication kernel only once. This results in a bigger bubble, but it only happens once per attention call, while in all-to-all, it happens N-1 times.

 ## Data dispatching in joint mesh

 We make sure to dispatch the same batch of data to the whole `cp` subgroup, so that the results are correct. (Meaning each rank in `cp` subgroup gets the same batch of data.) However, we also dispatch different batches to each rank of `dp_shard` group.
 Imagine it like this:
 ```
 # 8 GPUS, --dp_shard_size 4, --cp_size 2
 # mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]
 # model is sharded across the whole mesh (each GPU holds 1/8 of the model)
 # GPUs 0,1 = batch 0
 # GPUs 2,3 = batch 1
 ... and so on.
 ```
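The same dispatching rule can be written as a tiny helper (an illustration mirroring the 8-GPU layout above; ranks that share a `cp` subgroup share a batch):

```python
def batch_index_for_rank(rank: int, cp_size: int = 2) -> int:
    """Every `cp_size` consecutive ranks form one cp subgroup and see the same batch."""
    return rank // cp_size

# mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]  ->  GPUs 4 and 5 both get batch 2
assert [batch_index_for_rank(r) for r in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]
```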
@@ -139,7 +139,7 @@ values. They can also be passed in manually.
 * `--cpu` (`bool`) -- Whether or not to force the training on the CPU.
 * `--multi_gpu` (`bool`) -- Whether or not this should launch a distributed GPU training.
 * `--tpu` (`bool`) -- Whether or not this should launch a TPU training.
-* `--ipex` (`bool`) -- Whether or not this should launch an Intel Pytorch Extension (IPEX) training.
+* `--ipex` (`bool`) -- Whether or not this should launch an Intel Pytorch Extension (IPEX) training. **This argument is deprecated, will be removed in Accelerate v1.10**

 **Resource Selection Arguments**:

@@ -158,7 +158,7 @@ The following arguments are useful for selecting which training paradigm to use.
 * `--use_deepspeed` (`bool`) -- Whether or not to use DeepSpeed for training.
 * `--use_fsdp` (`bool`) -- Whether or not to use FullyShardedDataParallel for training.
 * `--use_megatron_lm` (`bool`) -- Whether or not to use Megatron-LM for training.
-* `--use_xpu` (`bool`) -- Whether to use IPEX plugin to speed up training on XPU specifically. **This argument is deprecated and ignored, will be removed in Accelerate v1.20**
+* `--use_xpu` (`bool`) -- Whether to use IPEX plugin to speed up training on XPU specifically. **This argument is deprecated and ignored, will be removed in Accelerate v1.10**

 **Distributed GPU Arguments**:
@@ -29,6 +29,11 @@ rendered properly in your Markdown viewer.
 [[autodoc]] tracking.WandBTracker
     - __init__

+## Trackio
+
+[[autodoc]] tracking.TrackioTracker
+    - __init__
+
 ## CometMLTracker

 [[autodoc]] tracking.CometMLTracker

@@ -48,3 +53,8 @@ rendered properly in your Markdown viewer.

 [[autodoc]] tracking.ClearMLTracker
     - __init__
+
+## SwanLabTracker
+
+[[autodoc]] tracking.SwanLabTracker
+    - __init__
@@ -245,7 +245,7 @@ As was pointed out in this [blog-post](https://huggingface.co/blog/gradient_accu

 > [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values.

-In other words, some adjustements must be made on losses that operate on a token-level basis.
+In other words, some adjustments must be made on losses that operate on a token-level basis.
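Concretely, the denominator described in the quote is just the count of label tokens that are not padded or ignored. A small illustration, assuming the usual `-100` ignore index:

```python
import torch

def count_loss_tokens(labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Number of tokens that actually contribute to the loss in one batch."""
    return (labels != ignore_index).sum()

# summed over every batch in the accumulation step (and gathered across processes),
# this count plays the role of `num_items_in_batch` in the skeleton below
```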
 ### Skeleton code

@@ -282,7 +282,7 @@ for update_step in range(total_updates):
     num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()

     for i, batch in enumerate(batch_samples):
-        # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating
+        # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unnecessary communications when accumulating
         # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
         if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
             ctx = model.no_sync

@@ -294,7 +294,7 @@ for update_step in range(total_updates):
         with ctx():
             inputs, targets = batch
             outputs = model(inputs)
-            loss = loss_function(outputs, targets) # the loss function shoud sum over samples rather than averaging
+            loss = loss_function(outputs, targets) # the loss function should sum over samples rather than averaging

             # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices
             # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps
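Put together, those two comments amount to a rescaling along these lines (a hedged sketch of the idea, not the exact line from the guide):

```python
def rescale_loss(loss, num_processes: int, gradient_accumulation_steps: int, num_items_in_batch: int):
    """Undo DDP's per-device mean and the per-step mean, then normalize by the true token count."""
    return loss * num_processes * gradient_accumulation_steps / num_items_in_batch
```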
@@ -394,7 +394,7 @@ for update_step in range(total_gradient_updates):
     for i, batch in enumerate(batch_samples):
         inputs, labels = batch["input_ids"], batch["labels"]
         total_batched_samples += 1
-        # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating
+        # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unnecessary communications when accumulating
         # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
         if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
             ctx = model.no_sync
@@ -13,34 +13,11 @@ specific language governing permissions and limitations under the License.
 rendered properly in your Markdown viewer.
 -->

-# Intel® Extension for PyTorch
-
-[IPEX](https://github.com/intel/intel-extension-for-pytorch) is optimized for CPUs with AVX-512 or above, and functionally works for CPUs with only AVX2. So, it is expected to bring performance benefit for Intel CPU generations with AVX-512 or above while CPUs with only AVX2 (e.g., AMD CPUs or older Intel CPUs) might result in a better performance under IPEX, but not guaranteed. IPEX provides performance optimizations for CPU training with both Float32 and BFloat16. The usage of BFloat16 is the main focus of the following sections.
-
-Low precision data type BFloat16 has been natively supported on the 3rd Generation Xeon® Scalable Processors (aka Cooper Lake) with AVX512 instruction set and will be supported on the next generation of Intel® Xeon® Scalable Processors with Intel® Advanced Matrix Extensions (Intel® AMX) instruction set with further boosted performance. The Auto Mixed Precision for CPU backend has been enabled since PyTorch-1.10. At the same time, the support of Auto Mixed Precision with BFloat16 for CPU and BFloat16 optimization of operators has been massively enabled in Intel® Extension for PyTorch, and partially upstreamed to PyTorch master branch. Users can get better performance and user experience with IPEX Auto Mixed Precision.
-
-## IPEX installation:
-
-IPEX release is following PyTorch, to install via pip:
-
-| PyTorch Version | IPEX version |
-| :---------------: | :----------: |
-| 2.0 | 2.0.0 |
-| 1.13 | 1.13.0 |
-| 1.12 | 1.12.300 |
-| 1.11 | 1.11.200 |
-| 1.10 | 1.10.100 |
-
-```
-pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
-```
-
-Check more approaches for [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/installation.html).
-
+# Training on Intel CPU

 ## How It Works For Training optimization in CPU

-Accelerate has integrated [IPEX](https://github.com/intel/intel-extension-for-pytorch), all you need to do is enabling it through the config.
+Accelerate has full support for Intel CPU, all you need to do is enabling it through the config.

 **Scenario 1**: Acceleration of No distributed CPU training

@@ -55,7 +32,6 @@ This machine
 Which type of machine are you using?
 No distributed training
 Do you want to run your training on CPU only (even if a GPU / Apple Silicon device is available)? [yes/NO]:yes
-Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU? [yes/NO]:yes
 Do you wish to optimize your script with torch dynamo?[yes/NO]:NO
 Do you want to use DeepSpeed? [yes/NO]: NO
 -----------------------------------------------------------------------------------------------------------------------------------------------------------

@@ -69,15 +45,12 @@ default options when doing
 accelerate launch my_script.py --args_to_my_script
 ```

-For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with IPEX enabled.
-default_config.yaml that is generated after `accelerate config`
+For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with `default_config.yaml` which is generated by `accelerate config`

 ```bash
 compute_environment: LOCAL_MACHINE
 distributed_type: 'NO'
 downcast_bf16: 'no'
-ipex_config:
-  ipex: true
 machine_rank: 0
 main_training_function: main
 mixed_precision: bf16

@@ -117,7 +90,6 @@ What is the rank of this machine?
 What is the IP address of the machine that will host the main process? 36.112.23.24
 What is the port you will use to communicate with the main process? 29500
 Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: yes
-Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU? [yes/NO]:yes
 Do you want accelerate to launch mpirun? [yes/NO]: yes
 Please enter the path to the hostfile to use with mpirun [~/hostfile]: ~/hostfile
 Enter the number of oneCCL worker threads [1]: 1

@@ -129,13 +101,11 @@ bf16
 ```
 For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with IPEX enabled for distributed CPU training.

-default_config.yaml that is generated after `accelerate config`
+`default_config.yaml` which is generated by `accelerate config`
 ```bash
 compute_environment: LOCAL_MACHINE
 distributed_type: MULTI_CPU
 downcast_bf16: 'no'
-ipex_config:
-  ipex: true
 machine_rank: 0
 main_process_ip: 36.112.23.24
 main_process_port: 29500

@@ -156,8 +126,10 @@ use_cpu: true

 Set following env and using intel MPI to launch the training

-In node0, you need to create a configuration file which contains the IP addresses of each node (for example hostfile) and pass that configuration file path as an argument.
-If you selected to have Accelerate launch `mpirun`, ensure that the location of your hostfile matches the path in the config.
+In `node0`, you need to create a configuration file which contains the IP addresses of each node (for example hostfile) and pass that configuration file path as an argument.
+
+If you selected to let Accelerate launch `mpirun`, ensure that the location of your hostfile matches the path in the config.

 ```bash
 $ cat hostfile
 xxx.xxx.xxx.xxx #node0 ip

@@ -165,18 +137,18 @@ xxx.xxx.xxx.xxx #node1 ip
 xxx.xxx.xxx.xxx #node2 ip
 xxx.xxx.xxx.xxx #node3 ip
 ```
-When Accelerate is launching `mpirun`, source the oneCCL bindings setvars.sh to get your Intel MPI environment, and then
-run your script using `accelerate launch`. Note that the python script and environment needs to exist on all of the
-machines being used for multi-CPU training.
+Before executing `accelerate launch` command, you need source the oneCCL bindings `setvars.sh` to get your Intel MPI environment properly. Note that both the python script and environment need to be available on all of the machines being used for multi-CPU training.

 ```bash
 oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
 source $oneccl_bindings_for_pytorch_path/env/setvars.sh

 accelerate launch examples/nlp_example.py
 ```
-Otherwise, if you selected not to have Accelerate launch `mpirun`, run the following command in node0 and **16DDP** will
-be enabled in node0,node1,node2,node3 with BF16 mixed precision. When using this method, the python script, python
-environment, and accelerate config file need to be present on all of the machines used for multi-CPU training.
+You can also directly launch distributed training with `mpirun` command, you need to run the following command in node0 and **16DDP** will be enabled in node0,node1,node2,node3 with BF16 mixed precision. When using this method, the python script, python environment, and accelerate config file need to be available on all of the machines used for multi-CPU training.

 ```bash
 oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
 source $oneccl_bindings_for_pytorch_path/env/setvars.sh

@@ -185,11 +157,3 @@ export MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip
 export CCL_ATL_TRANSPORT=ofi
 mpirun -f hostfile -n 16 -ppn 4 accelerate launch examples/nlp_example.py
 ```
-
-## Related Resources
-
-- [Project's github](https://github.com/intel/intel-extension-for-pytorch)
-- [API docs](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/api_doc.html)
-- [Tuning guide](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html)
-- [Blogs & Publications](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/blogs_publications.html)
@@ -20,10 +20,11 @@ Accelerate provides a general tracking API that can be used to log useful items

 ## Integrated Trackers

-Currently `Accelerate` supports seven trackers out-of-the-box:
+Currently `Accelerate` supports eight trackers out-of-the-box:

 - TensorBoard
 - WandB
+- Trackio
 - CometML
 - Aim
 - MLFlow
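Whichever tracker is chosen, the wiring looks the same. A small illustration (the project name and logged values are made up, and it assumes the chosen tracker's client library is installed):

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")  # or "trackio", "swanlab", ..., or "all"
accelerator.init_trackers(project_name="my_project", config={"learning_rate": 1e-5})
accelerator.log({"train_loss": 0.42}, step=1)
accelerator.end_training()
```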
@@ -218,7 +218,7 @@ def parse_args():
         default="all",
         help=(
             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
-            ' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
+            ' `"wandb"`, `"comet_ml"`, `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
             "Only applicable when `--with_tracking` is passed."
         ),
     )

@@ -215,7 +215,7 @@ def parse_args():
         default="all",
         help=(
             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
-            ' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
+            ' `"wandb"`, `"comet_ml"`, and `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
             "Only applicable when `--with_tracking` is passed."
         ),
     )
@@ -31,8 +31,8 @@ from accelerate.utils import ProfileKwargs
 #
 # This example trains a Bert base model on GLUE MRPC
 # in any of the following settings (with the same script):
-# - single CPU or single GPU
-# - multi GPUS (using PyTorch distributed mode)
+# - single CPU or single device (CUDA GPU, Intel XPU etc.)
+# - multi devices (using PyTorch distributed mode)
 # - (multi) TPUs
 # - fp16 (mixed-precision) or fp32 (normal precision)
 #

@@ -183,7 +183,8 @@ def training_function(config, args):
     # New Code #
     accelerator.print(
         prof.key_averages().table(
-            sort_by="self_cpu_time_total" if args.cpu else "self_cuda_time_total", row_limit=-1
+            sort_by="self_cpu_time_total" if args.cpu else f"self_{accelerator.device.type}_time_total",
+            row_limit=-1,
         )
     )
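For context, the `prof` object printed here comes from Accelerate's `ProfileKwargs` handler, which is imported at the top of this file. A stripped-down sketch of that pattern, with a toy model standing in for the example's real training code:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import ProfileKwargs

profile_kwargs = ProfileKwargs(activities=["cpu"], record_shapes=True)
accelerator = Accelerator(kwargs_handlers=[profile_kwargs])

model = torch.nn.Linear(128, 128).to(accelerator.device)
inputs = torch.randn(4, 128, device=accelerator.device)

with accelerator.profile() as prof:
    model(inputs)

# same device-agnostic sort key as in the diff above
accelerator.print(prof.key_averages().table(sort_by=f"self_{accelerator.device.type}_time_total", row_limit=10))
```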
@@ -215,7 +216,7 @@ def main():
         choices=["no", "fp16", "bf16", "fp8"],
         help="Whether to use mixed precision. Choose"
         "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-        "and an Nvidia Ampere GPU.",
+        "and an Nvidia Ampere GPU or an Intel XPU.",
     )
     # New Code #
     parser.add_argument(
@@ -8,7 +8,7 @@ deepspeed_config:
   # `transformers` uses the right `init` function
   zero3_init_flag: false # true

-# Finally we need to specify the number of GPUs to use
+# Finally we need to specify the number of accelerators to use
 num_processes: 2
 # Optionally we can set the mixed precision now instead of in the deepspeed config file,
 # however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file
@@ -11,8 +11,8 @@ fp8_config:
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: (false, false, false)
+  override_linear_precision: [false, false, false]
   # Generally this should always be set to `false` to have the most realistic fp8 eval performance
   use_autocast_during_eval: false
   # If using MS-AMP, we ignore all of the prior and set a opt_level
   #opt_level: O1
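The bracket change above matters because YAML has no tuple syntax. A quick check (outside the diff, assuming PyYAML is installed) shows how the two spellings parse:

```python
import yaml  # PyYAML

print(yaml.safe_load("override_linear_precision: (false, false, false)"))
# {'override_linear_precision': '(false, false, false)'}  <- a plain string
print(yaml.safe_load("override_linear_precision: [false, false, false]"))
# {'override_linear_precision': [False, False, False]}    <- an actual list of booleans
```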
@@ -1,8 +1,8 @@
-# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
+# Since we are doing FSDP (even though it's multi-accelerator), we need to specify the distributed type as FSDP
 distributed_type: FSDP
 # Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)
 mixed_precision: 'bf16'
-# Specify the number of GPUs to use
+# Specify the number of accelerators to use
 num_processes: 2
 # Then we can specify the FSDP config
 fsdp_config:
examples/config_yaml_templates/multi_xpu.yaml (new file, 6 lines)

@@ -0,0 +1,6 @@
# Specify distributed_type as `MULTI_XPU` for DDP
distributed_type: "MULTI_XPU"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
# Specify the number of XPUs to use
num_processes: 2
@@ -1,4 +1,4 @@
-# Since this is single GPU, we don't need distributed training
+# Since this is single GPU/XPU, we don't need distributed training
 distributed_type: "NO"
 # Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
 mixed_precision: "bf16"
@@ -1,58 +0,0 @@
# FSDP2 Examples

This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance or accuracy.

## FSDP2 + ao Float8Linear (`fsdp2_fp8.py`)

In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
which replaces the default `torch.nn.Linear` with `Float8Linear`. We also utilize `TorchDynamoPlugin` together with regional compilation to compile the model,
gaining even more speed and memory savings, as `ao` doesn't ship with any kernels by default, so we have to gain the performance from compiling the model.

Replacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. This highly depends on the model dimensions and sequence length used for training.
You can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance).

In our example, we use a 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while loss function stays roughly the same. We can also see that the FLOPS raise by using FP8.

<div style="display: flex; gap: 25px;">
  <div style="text-align: center; width: 49%;">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tps.png" alt="tps" style="width: 100%;">
    <p style="text-align: center; margin-top: 8px;">TPs per device, bf16 vs fp8</p>
  </div>
  <div style="text-align: center; width: 49%;">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tflops.png" alt="tflops" style="width: 100%;">
    <p style="text-align: center; margin-top: 8px;">TFLOPS per device, bf16 vs fp8. We cannot really compare MFU as fp8 tensor cores are used as well.</p>
  </div>

  <div style="text-align: center; width: 49%;">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_loss.png" alt="loss" style="width: 100%; max-width: 900px;">
    <p style="text-align: center; margin-top: 8px;">Loss curve, bf16 vs fp8, it's hard to see the difference as the curves mostly overlap</p>
  </div>
</div>

The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length and 1000 steps. To run the example, you can use the following command, where you can specify the precision to train in:

```bash
accelerate launch --fsdp2_fp8.py --sequence_length 8192 --num_steps 1000 --log_with wandb --precision [fp8 | bf16]
```

## FSDP2 + context parallelism (`fsdp2_context_parallel.py`)

In this file, we showcase integration of context parallelism with FSDP2. Context parallelism is a technique that allows us to scale the training to sequence length of up to a million tokens. With `accelerator.context_parallel` context manager, we replace the attention implementation with a context parallel version, which enables us to train on a sequence length of up to 128k tokens on 8x H100 GPUs, with possibility of endless scaling if we have enough GPUs.

For a detailed explanation and more details, please refer to [this guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallel). You can run the example with the following command:

```bash
accelerate launch --fsdp2_context_parallel.py --sequence_length 128000 --num_steps 1000 --log_with wandb --cp_size 8 --cp_comm_strategy allgather
```

More details about the context parallelism can be found in the [concept guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallel). You can see some results below:

<p align="center">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
  <br>
  <em>Figure 1: Memory usage and speed of context parallelism for up-to 256k context size.</em>
</p>
@@ -1,179 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with Context Parallel using FSDP2 via Accelerate.
This example demonstrates how to use Accelerate's context_parallel feature for efficient long sequence training.
"""

import argparse

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, setup_tokenizer


MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence-length", type=int, default=128_000, help="Sequence length for the dataset")
    parser.add_argument("--num-steps", type=int, default=100, help="Number of training steps")
    parser.add_argument("--log-with", type=str, default="wandb", help="Logging service to use")
    parser.add_argument("--cp-size", type=int, default=8, help="Context parallel size")
    parser.add_argument("--cp-comm-strategy", type=str, default="allgather", help="Context parallel shard rotation")
    return parser.parse_args()


def training_step(batch, model, optimizer, accelerator: Accelerator):
    """
    Perform a single training step with context parallel.

    Args:
        batch: Input batch containing input_ids and labels
        model: The model to train
        optimizer: Optimizer
        accelerator: Accelerator instance

    Returns:
        loss: Training loss
    """

    # Use context parallel for efficient long sequence processing
    outputs = model(**batch)
    loss = outputs.loss

    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)

    return loss


def main():
    set_seed(42)
    args = parse_args()

    # Configure FSDP2 plugin
    fsdp_plugin = FullyShardedDataParallelPlugin(
        auto_wrap_policy="transformer_based_wrap",
        transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
        cpu_ram_efficient_loading=True,
        activation_checkpointing=True,
        fsdp_version=2,
        cp_size=args.cp_size,
        cp_comm_strategy=args.cp_comm_strategy,
    )

    # Initialize accelerator
    accelerator = Accelerator(
        log_with=args.log_with,
        fsdp_plugin=fsdp_plugin,
        mixed_precision="bf16",
    )

    accelerator.init_trackers(
        project_name="FSDP2_context_parallel",
        config={
            "sequence_length": args.sequence_length,
            "num_steps": args.num_steps,
            "cp_size": args.cp_size,
            "cp_comm_strategy": args.cp_comm_strategy,
        },
    )

    # Prepare model and optimizer
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

    tokenizer = setup_tokenizer(MODEL_ID)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    model, optimizer = accelerator.prepare(model, optimizer)

    accelerator.print("Preparing dataset... this might take a while")
    dataset = get_dataset(
        accelerator,
        tokenizer,
        args.sequence_length,
        processing_batch_size=args.sequence_length
        // 20,  # we need to override the default processing batch size to avoid empty packed sequences
    )
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
    dataloader = accelerator.prepare(dataloader)

    model.train()

    total_num_steps = min(args.num_steps, len(dataloader))
    performance_tracker = PerformanceTracker(warmup_steps=10)

    accelerator.print(f"Starting training with context parallel for {total_num_steps} steps...")
    accelerator.print(f"Sequence length: {args.sequence_length}")
    accelerator.print("Warming up for 10 steps...")

    accelerator.print(
        "Each step takes ~10 seconds with default settings on 8x H100 SXM GPUs, seeing logs takes a while"
    )
    for step, batch in enumerate(dataloader):
        print(f"Step {step}")
        if step >= total_num_steps:
            break

        # get number of tokens before context_parallel shards the batch
        batch_tokens = batch["input_ids"].shape[0] * batch["input_ids"].shape[1]

        loss = training_step(batch, model, optimizer, accelerator)

        # each batch gets the same data, we divide by the number of processes to get the number of tokens per process
        metrics = performance_tracker.step(batch_tokens // accelerator.num_processes)

        log_metrics = {"loss": loss.item()}

        if "warmup_completed" in metrics:
            accelerator.print("Warmup completed! Starting performance tracking...")
        elif metrics:
            log_metrics.update(
                {
                    "tokens_per_second": int(metrics["tokens_per_second"]),
                    "steps_per_second": metrics["steps_per_second"],
                }
            )

        if (step % 10 == 0 or step == total_num_steps - 1) and metrics:
            accelerator.print(
                f"Step {step}/{total_num_steps} | "
                f"Loss: {loss.item():.4f} | "
                f"Tokens/s: {int(metrics['tokens_per_second'])} | "
                f"Steps/s: {metrics['steps_per_second']:.2f} | "
            )

        accelerator.log(log_metrics)

    accelerator.wait_for_everyone()
    accelerator.end_training()
    accelerator.print("Training completed!")


if __name__ == "__main__":
    main()
@@ -177,6 +177,7 @@ def training_function(config, args):
             outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
             predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
+            print(f"===== {predictions}")
             metric.add_batch(
                 predictions=predictions,
                 references=references,
@@ -1,5 +1,5 @@
 accelerate # used to be installed in Amazon SageMaker environment
 evaluate
-datasets==2.3.2
+datasets
 schedulefree
 huggingface_hub>=0.20.0
examples/torch_native_parallelism/README.md (new file, 77 lines)

@@ -0,0 +1,77 @@
## Torch Native Parallelism

With recent versions of Torch, there have been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`
This folder contains various examples of such use-cases: such as composing multiple parallelism strategies, low-bit training etc.

### ND Parallelism

With `ParallelismConfig`, you can use 🤗 accelerate to train models with n-dimensional parallelism. This builds on top of 🤗 transformers, which we utilize for tensor parallelism sharding.
Accelerate then takes care of everything else, such as data parallelism, FSDP or context parallelism.
Script `nd_parallel.py` showcases this. We enable you to configure 4 different parallel dimensions (for now 👀):
- dp_replicate_size: how many replicas of the model to create, each replica is trained on a different subset of the data and averaged at the end of each step, same as DDP in Torch
- dp_shard_size: across how many devices is the model sharded, this is utilizing FSDP2 to shard the model across devices, so each device has a different part of the model
- tp_size: how many devices to use for tensor parallelism, this is utilizing the tensor parallelism from 🤗 transformers
- cp_size: how many devices to use for context parallelism, this will also shard the model, optimizer and gradients using `FSDP2` across
the same group of devices, to further optimize memory usage (this comes with no slowdown)

For example, with 8 nodes, you can run the script as such:
```bash
accelerate launch --num-processes 8 nd_parallel.py \
    --dp-replicate-size 2 \
    --dp-shard-size 2 \
    --tp-size 2 \
```
|
||||
|
||||
<Tip>
|
||||
Only use TP intra-node - therefore max TP size you should need is 8. You can also use a lower size, as FSDP (`--dp-shard-size`) can be faster on smaller models with
|
||||
shorter sequence lengths. If you cannot fit your model into memory, utilize `--dp-shard-size` as much as you can. Afterwards, to scale up and utilize all your resources, use `--dp-replicate-size`. This is only a general guideline, you can (and should) experiment with different parallelism configurations to find the best one for your model and hardware. You can learn more about the general strategies for parallelism in our [blog](https://huggingface.co/blog/accelerate-nd-parallel), or if you really want to dive deep, read the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
|
||||
</Tip>
|
||||
|
||||
This feature is also fully integrated into 🤗 transformers `Trainer`. To use it, simply launch your script with path to your accelerate configuration file. You can see a minimal example of such script in `nd_parallel_trainer.py`.
|
||||
We provide 2 pre-configured configuration files:
|
||||
|
||||
#### HSDP + TP (3D parallelism)
|
||||
|
||||
```bash
|
||||
accelerate launch --config-file configs/tp_hsdp.yaml nd_parallel_trainer.py
|
||||
```
|
||||
|
||||
#### Context parallelism (128k sequence length)
|
||||
|
||||
```bash
|
||||
accelerate launch --config-file configs/cp.yaml nd_parallel_trainer.py --sequence-length=128000
|
||||
```
|
||||
|
||||
### FSDP2 + ao Float8Linear
|
||||
|
||||
In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
|
||||
which replaces the default `torch.nn.Linear` with `Float8Linear`. We also utilize `TorchDynamoPlugin` together with regional compilation to compile the model,
|
||||
gaining even more speed and memory savings, as `ao` doesn't ship with any kernels by default, so we have to gain the performance from compiling the model.
|
||||
|
||||
Replacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. This highly depends on the model dimensions and sequence length used for training.
|
||||
You can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance).
|
||||
|
||||
In our example, we use a 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while loss function stays roughly the same. We can also see that the FLOPS rise by using FP8.
|
||||
|
||||
<div style="display: flex; gap: 25px;">
|
||||
<div style="text-align: center; width: 49%;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tps.png" alt="tps" style="width: 100%;">
|
||||
<p style="text-align: center; margin-top: 8px;">TPS per device, BF16 vs FP8</p>
|
||||
</div>
|
||||
<div style="text-align: center; width: 49%;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tflops.png" alt="tflops" style="width: 100%;">
|
||||
<p style="text-align: center; margin-top: 8px;">TFLOPS per device, BF16 vs FP8. We cannot really compare MFU as FP8 tensor cores are used as well.</p>
|
||||
</div>
|
||||
|
||||
<div style="text-align: center; width: 49%;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_loss.png" alt="loss" style="width: 100%; max-width: 900px;">
|
||||
<p style="text-align: center; margin-top: 8px;">Loss curve, BF16 vs FP8, it's hard to see the difference as the curves mostly overlap</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length and 1000 steps. To run the example, you can use the following command, where you can specify the precision to train in:
|
||||
|
||||
```bash
|
||||
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
|
||||
```
|
||||
|
||||
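For readers who prefer configuring this in code rather than through a YAML file, a minimal sketch mirroring `nd_parallel.py` (the sizes are illustrative and must multiply to the number of launched processes):

```python
# Minimal sketch: a 2x2x2 (dp_replicate x dp_shard x tp) setup built in code,
# following the ParallelismConfig usage shown in nd_parallel.py.
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2, cp_size=1)
fsdp2 = FullyShardedDataParallelPlugin(fsdp_version=2, auto_wrap_policy="transformer_based_wrap")
accelerator = Accelerator(mixed_precision="bf16", parallelism_config=pc, fsdp_plugin=fsdp2)
```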
29 examples/torch_native_parallelism/configs/cp.yaml (new file)
@@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
parallelism_config:
  parallelism_config_cp_size: 8
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 1
  parallelism_config_tp_size: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
29 examples/torch_native_parallelism/configs/tp_hsdp.yaml (new file)
@@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
parallelism_config:
  parallelism_config_cp_size: 1
  parallelism_config_dp_replicate_size: 2
  parallelism_config_dp_shard_size: 2
  parallelism_config_tp_size: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -22,13 +22,15 @@ import argparse
import torch
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin, TorchDynamoPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, get_model_flops_per_token, setup_tokenizer
from utils import PerformanceTracker, create_collate_fn, get_dataset, get_model_flops_per_token


WARMUP_STEPS = 10

MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


@@ -66,7 +68,6 @@ def main():

    fp8_config = Float8LinearConfig(
        enable_fsdp_float8_all_gather=True,  # extra saving by gathering parameters in fp8 and upcasting after
        force_recompute_fp8_weight_in_bwd=True,
    )

    kwargs = []
@@ -89,24 +90,22 @@ def main():
        torch_dtype=torch.bfloat16,
    )

    tokenizer = setup_tokenizer(MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    model, optimizer = accelerator.prepare(model, optimizer)

    dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
    dataset = get_dataset(tokenizer, args.sequence_length, accelerator)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
    dataloader = accelerator.prepare(dataloader)

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    accelerator.wait_for_everyone()

    model.train()

    total_num_steps = min(args.num_steps, len(dataloader))
    model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
    performance_tracker = PerformanceTracker(warmup_steps=10)

    accelerator.print(f"Starting training with {args.precision} precision for {total_num_steps} steps...")
    accelerator.print(f"Sequence length: {args.sequence_length}")
    accelerator.print("Warming up for 10 steps...")
    performance_tracker = PerformanceTracker(warmup_steps=5)

    for step, batch in enumerate(dataloader):
        if step >= total_num_steps:
@@ -118,35 +117,18 @@ def main():
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        batch_tokens = batch["input_ids"].shape[1]
        metrics = performance_tracker.step(batch_tokens)
        metrics = performance_tracker.step(batch["input_ids"].shape[1], model_flops_per_token)

        print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
        log_metrics = {"loss": loss.item()}

        if "warmup_completed" in metrics:
            accelerator.print("Warm up completed! Starting performance tracking...")
            accelerator.print("Warm up completed! Starting training")
        elif metrics:
            tps = metrics["tokens_per_second"]
            tflops = metrics["total_tokens"] * model_flops_per_token / (metrics["total_time"] * 1e12)

            # it's rather hard to get a good estimate of MFU as we train with FP8, so both FP8 and BF16 tensor cores are used, therefore we just report TFLOPS (Tera floating point operations per second)
            # Given H100 SXM, the theoretical peak flops are ~990 TFLOPS for bf16 and ~1980 TFLOPS for fp8 [https://resources.nvidia.com/en-us-gpu-resources/h100-datasheet-24306]
            # This is WITH sparsity, so we divide by 2 to get the answer w/o sparsity
            print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | TPS per device: {tps:.2f} | TFLOPS per device: {tflops:.2f}"
            log_metrics.update(
                {
                    "steps_per_second": metrics["steps_per_second"],
                    "tps_per_device": tps,
                    "tflops_per_device": tflops,
                }
            )
            print_msg += performance_tracker.get_print_message(metrics)

        if step % 10 == 0 or step == total_num_steps - 1:
            accelerator.print(print_msg)

        accelerator.log(log_metrics)
        accelerator.log(metrics)

    accelerator.wait_for_everyone()
    accelerator.end_training()
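The TFLOPS arithmetic that this hunk moves out of the example (and into `PerformanceTracker`, see the `utils.py` hunk later in this diff) is the same in both versions. A standalone restatement, with purely illustrative numbers and the common "6 x parameter count" flops-per-token rule stated as an assumption:

```python
# Same arithmetic as the tflops line above, shown standalone. Values are illustrative.
model_flops_per_token = 6 * 8e9   # rough 6 * n_params rule of thumb for an 8B model (assumption)
total_tokens = 80_000             # tokens processed on this device since warmup (illustrative)
total_time = 8.0                  # seconds since warmup (illustrative, ~10k tokens/s/device)
tflops_per_device = total_tokens * model_flops_per_token / (total_time * 1e12)  # ~480 TFLOPS
```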
173 examples/torch_native_parallelism/nd_parallel.py (new file)
@@ -0,0 +1,173 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""

import argparse
import warnings

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import (
    PerformanceTracker,
    create_collate_fn,
    get_dataset,
    get_model_flops_per_token,
    setup_tokenizer,
)


MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dp-replicate-size", type=int, default=1)
    parser.add_argument("--dp-shard-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--cp-size", type=int, default=1)
    parser.add_argument("--sequence-length", type=int, default=1024)
    parser.add_argument("--num-steps", type=int, default=1000)
    parser.add_argument("--save-dir", type=str, default="./outputs")
    parser.add_argument("--checkpoint-frequency", type=int, default=100)
    parser.add_argument("--model-name", type=str, default=MODEL_ID)

    return parser.parse_args()


def forward(model, batch, optimizer, accelerator: Accelerator):
    batch["position_ids"] = torch.arange(0, batch["input_ids"].size(1), device=batch["input_ids"].device).unsqueeze(0)
    # We need both labels and shift_labels, as the loss computation in the model is hidden behind `if labels is not None`, but the loss computation
    # itself prioritizes shift_labels (if provided) which are the correct ones (due to labels being wrong if cp enabled)
    buffers = [batch["input_ids"], batch["shift_labels"], batch["labels"], batch["position_ids"]]
    with accelerator.maybe_context_parallel(
        buffers=buffers, buffer_seq_dims=[1, 1, 1, 1], no_restore_buffers=set(buffers)
    ):
        # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
        # As for DP we have a different batch on each device and for CP we essentially have a different part of sequences on each device
        # I.e. with causal modelling and seq_len 1024, this dimension becomes another batch dimension of sorts
        loss_reduce_grp = (
            accelerator.torch_device_mesh["dp_cp"].get_group()
            if accelerator.parallelism_config.dp_cp_dim_names
            else None
        )
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)

    return loss


def train(args):
    parallelism_config = ParallelismConfig(
        dp_replicate_size=args.dp_replicate_size,
        dp_shard_size=args.dp_shard_size,
        tp_size=args.tp_size,
        cp_size=args.cp_size,
    )

    # FSDP needs extra configuration, so we properly shard the model
    fsdp2_plugin = None
    if parallelism_config.dp_shard_enabled or parallelism_config.cp_enabled:
        fsdp2_plugin = FullyShardedDataParallelPlugin(
            fsdp_version=2,
            auto_wrap_policy="transformer_based_wrap",
            transformer_cls_names_to_wrap=["Qwen3DecoderLayer"],
            state_dict_type="SHARDED_STATE_DICT",
        )

    accelerator = Accelerator(
        log_with=["wandb"], mixed_precision="bf16", parallelism_config=parallelism_config, fsdp_plugin=fsdp2_plugin
    )
    accelerator.init_trackers("nd_parallel_training")

    # If TP was enabled, we need to tell transformers to prepare the model for us
    model_kwargs = (
        {"tp_size": args.tp_size, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh}
        if args.tp_size > 1
        else {}
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        **model_kwargs,
    )
    tokenizer = setup_tokenizer(args.model_name)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
    dataset = get_dataset(tokenizer, args.sequence_length, accelerator)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    total_num_steps = min(args.num_steps, len(dataloader))
    model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
    performance_tracker = PerformanceTracker(warmup_steps=5)

    accelerator.print("Starting training...")
    for step, batch in enumerate(dataloader):
        if step >= total_num_steps:
            break

        loss = forward(model, batch, optimizer, accelerator)

        # We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
        metrics = performance_tracker.step(
            batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size,
            model_flops_per_token=model_flops_per_token,
        )

        print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
        if "warmup_completed" in metrics:
            accelerator.print("Warm up completed! Starting performance tracking...")
        elif metrics:
            print_msg += performance_tracker.get_print_message(metrics, with_memory=True)

        if step % 10 == 0 or step == total_num_steps - 1:
            accelerator.print(print_msg)

        if step % args.checkpoint_frequency == 0 and step > 0 and parallelism_config.dp_shard_enabled:
            accelerator.print(f"Saving checkpoint at step {step}...")
            accelerator.save_state(args.save_dir + f"/checkpoint-{step}")

        accelerator.log({"loss": loss.item()})

    accelerator.print("Training completed!")

    model.save_pretrained(args.save_dir + f"/{args.model_name}")
    accelerator.print(f"Model saved to {args.save_dir}/{args.model_name}")
    accelerator.end_training()


if __name__ == "__main__":
    set_seed(42)
    args = parse_args()
    if args.dp_shard_size == 1 and args.tp_size > 1:
        # We currently don't support saving with `save_state` when using only
        # tensor parallelism, fsdp must be enabled
        warnings.warn(
            "Accelerator.save_state() is not yet supported with pure tensor parallel training. Training will work, but intermediate checkpoints will not be saved."
        )
    train(args)
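One detail worth calling out from `nd_parallel.py`: tokens per device are computed by dividing the packed sequence length by the non-data-parallel dimensions, because TP and CP ranks cooperate on the same tokens. A small standalone illustration (sizes are illustrative, and `non_data_parallel_size` is assumed to be `tp_size * cp_size`, as the attribute name suggests):

```python
# Illustrative only: how the per-device token count in nd_parallel.py works out.
tp_size, cp_size = 2, 2
dp_replicate_size, dp_shard_size = 2, 1
non_data_parallel_size = tp_size * cp_size          # assumed definition, matching the attribute name

seq_len = 1024                                      # one packed sequence per batch
tokens_per_device = seq_len / non_data_parallel_size        # 256: TP/CP ranks share the same tokens
world_size = dp_replicate_size * dp_shard_size * tp_size * cp_size  # 8 launched processes in total
```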
82 examples/torch_native_parallelism/nd_parallel_trainer.py (new file)
@@ -0,0 +1,82 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

from accelerate.utils import ParallelismConfig
from utils import get_dataset


MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence-length", type=int, default=4096)
    parser.add_argument("--checkpoint-frequency", type=int, default=100)
    parser.add_argument("--model-name", type=str, default=MODEL_ID)
    parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
    parser.add_argument("--device-type", type=str, default="auto")
    return parser.parse_args()


def main():
    # If ParallelismConfig is not initialized with __init__, it reads from env vars
    # which were set by using config
    pc = ParallelismConfig()
    args = parse_args()

    if args.device_type == "auto":
        args.device_type = torch.accelerator.current_accelerator().type

    model_kwargs = {}
    if pc.tp_enabled:
        model_kwargs["tp_plan"] = "auto"
        model_kwargs["device_mesh"] = pc.build_device_mesh(args.device_type)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(args.model_name, use_cache=False, **model_kwargs)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    packed_dataset = get_dataset(tokenizer, args.sequence_length)

    training_args = TrainingArguments(
        output_dir=args.save_dir,
        parallelism_config=pc,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        logging_steps=5,
        save_steps=args.checkpoint_frequency,
        learning_rate=5e-5,
        remove_unused_columns=False,
        bf16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=packed_dataset,
    )

    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()
@@ -13,10 +13,11 @@
# limitations under the License.

"""
Common utilities for FSDP2 examples.
Common utilities for torch-native-parallelism examples.
"""

import time
from contextlib import nullcontext

import torch
from datasets import Dataset, load_dataset
@@ -25,12 +26,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator


def get_dataset(
    accelerator: Accelerator,
    tokenizer: AutoTokenizer,
    seq_len: int,
    processing_batch_size: int = 1000,
) -> Dataset:
def get_dataset(tokenizer: AutoTokenizer, seq_len: int, accelerator: Accelerator | None = None) -> Dataset:
    """
    Load and prepare TinyStories dataset.

@@ -38,11 +34,11 @@ def get_dataset(
        accelerator (Accelerator): Accelerate accelerator instance
        tokenizer (AutoTokenizer): Hugging Face tokenizer
        seq_len (int): Sequence length for the dataset
        processing_batch_size (int): Batch size for processing the dataset

    Returns:
        Dataset: Packed dataset
    """
    processing_ctx = accelerator.main_process_first if accelerator else nullcontext
    raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")

    def tokenize_function(examples):
@@ -56,10 +52,8 @@ def get_dataset(
        tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
        return tokenized_batch

    with accelerator.main_process_first():
        tokenized_dataset = raw_dataset.map(
            tokenize_function, batched=True, remove_columns=["text"], batch_size=processing_batch_size
        )
    with processing_ctx():
        tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    def create_packed_sequences(examples):
        all_tokens = []
@@ -69,6 +63,7 @@ def get_dataset(
        num_sequences = len(all_tokens) // (seq_len + 1)
        packed_input_ids = []
        packed_labels = []
        packed_position_ids = []

        for i in range(num_sequences):
            start_idx = i * (seq_len + 1)
@@ -76,15 +71,21 @@ def get_dataset(
            full_sequence = all_tokens[start_idx:end_idx]
            packed_input_ids.append(full_sequence[:-1])
            packed_labels.append(full_sequence[1:])
            packed_position_ids.append(torch.arange(0, seq_len))

        return {"input_ids": packed_input_ids, "labels": packed_labels}
        return {
            "input_ids": packed_input_ids,
            "shift_labels": packed_labels,
            "position_ids": packed_position_ids,
            "labels": packed_labels,
        }

    with accelerator.main_process_first():
    with processing_ctx():
        packed_dataset = tokenized_dataset.map(
            create_packed_sequences,
            batched=True,
            remove_columns=tokenized_dataset.column_names,
            batch_size=processing_batch_size,
            batch_size=1000,
        )

    return packed_dataset.shuffle(seed=42)
@@ -119,8 +120,8 @@ def create_collate_fn():

    def collate_fn(batch):
        input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
        labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
        return {"input_ids": input_ids, "labels": labels}
        shift_labels = torch.tensor([item["shift_labels"] for item in batch], dtype=torch.long)
        return {"input_ids": input_ids, "shift_labels": shift_labels, "labels": shift_labels}

    return collate_fn

@@ -139,7 +140,7 @@ class PerformanceTracker:
        self.is_in_warmup = True
        self.step_count = 0

    def step(self, batch_tokens: int) -> dict:
    def step(self, batch_tokens: int, model_flops_per_token: float | None = None) -> dict:
        """
        Update performance tracking with a new step.

@@ -158,20 +159,43 @@ class PerformanceTracker:
            return {"warmup_completed": True}

        if not self.is_in_warmup and self.start_time is not None:
            dct = {}
            self.num_tokens += batch_tokens
            total_time = time.perf_counter() - self.start_time
            steps_from_warmup = self.step_count - self.warmup_steps

            if total_time > 0 and steps_from_warmup > 0:
                return {
                memory_stats = gpu_memory_usage_all()
                dct = {
                    "tokens_per_second": self.num_tokens / total_time,
                    "steps_per_second": steps_from_warmup / total_time,
                    "total_tokens": self.num_tokens,
                    "total_time": total_time,
                    **memory_stats,
                }

                if model_flops_per_token is not None:
                    flops = model_flops_per_token * self.num_tokens
                    dct["tflops_per_device"] = flops / (total_time * 1e12)

            return dct

        return {}

    def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
        print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}"
        if "tflops_per_device" in metrics:
            print_msg += f" | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
        else:
            print_msg += "\n"
        if with_memory:
            print_msg += (
                f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
                f"alloc={metrics['peak_memory_alloc']:.1f}, "
                f"reserved={metrics['peak_memory_reserved']:.1f}"
            )
        return print_msg


def setup_tokenizer(model_id: str) -> AutoTokenizer:
    """Setup tokenizer with proper padding token."""
@@ -179,3 +203,21 @@ def setup_tokenizer(model_id: str) -> AutoTokenizer:
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def gpu_memory_usage_all(device=0):
    device_type = torch.accelerator.current_accelerator().type
    device = torch.device(f"{device_type}:{device}")
    torch_device_module = getattr(torch, device_type, torch.cuda)
    _BYTES_IN_GIB = 1024**3
    peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
    peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
    peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
    memory_stats = {
        "peak_memory_active": peak_memory_active,
        "peak_memory_alloc": peak_memory_alloc,
        "peak_memory_reserved": peak_memory_reserved,
    }
    torch_device_module.reset_peak_memory_stats(device)

    return memory_stats
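A brief usage sketch of the updated `PerformanceTracker` API introduced above (the warmup value and loop body are illustrative; `dataloader`, `model`, `optimizer`, `accelerator` and `training_step` are assumed to exist in the calling script):

```python
# Illustrative usage of the PerformanceTracker defined in utils.py after this change.
tracker = PerformanceTracker(warmup_steps=5)

for step, batch in enumerate(dataloader):
    loss = training_step(batch, model, optimizer, accelerator)   # whatever the example's step does
    metrics = tracker.step(batch["input_ids"].shape[1], model_flops_per_token=model_flops_per_token)
    if "warmup_completed" in metrics:
        print("warmup done, timing starts now")
    elif metrics:
        # metrics include tokens_per_second, steps_per_second, tflops_per_device and peak memory stats
        print(f"Step {step}" + tracker.get_print_message(metrics, with_memory=True))
```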
13 setup.py
@@ -41,7 +41,16 @@ extras["deepspeed"] = ["deepspeed"]
extras["rich"] = ["rich"]

extras["test_fp8"] = ["torchao"] # note: TE for now needs to be done via pulling down the docker image directly
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive", "mlflow", "matplotlib"]
extras["test_trackers"] = [
    "wandb",
    "comet-ml",
    "tensorboard",
    "dvclive",
    "mlflow",
    "matplotlib",
    "swanlab",
    "trackio",
]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]

extras["sagemaker"] = [
@@ -50,7 +59,7 @@ extras["sagemaker"] = [

setup(
    name="accelerate",
    version="1.8.0.dev0",
    version="1.11.0.dev0",
    description="Accelerate",
    long_description=open("README.md", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "1.8.0.dev0"
__version__ = "1.11.0.dev0"

from .accelerator import Accelerator
from .big_modeling import (
@@ -26,6 +26,7 @@ from .big_modeling import (
from .data_loader import skip_first_batches
from .inference import prepare_pippy
from .launchers import debug_launcher, notebook_launcher
from .parallelism_config import ParallelismConfig
from .state import PartialState
from .utils import (
    AutocastKwargs,
(File diff suppressed because it is too large.)
@@ -747,3 +747,43 @@ def _attach_layerwise_casting_hooks(
            non_blocking,
            _prefix=layer_name,
        )


def _attach_context_parallel_hooks(
    model: nn.Module,
):
    """
    Monkeypatch huggingface's `transformers` model to fix attention mask issues when using context parallelism.

    This function attaches forward_pre_hooks to each self_attn module of the model, where each hook checks the
    args/kwargs; if they contain an attention mask, it will remove this mask, check if it is a causal mask, and if
    so add a kwarg `is_causal=True`, otherwise raise an error. This is because context parallelism does not support
    attention masks. This function modifies the model in place.

    Args:
        model (`nn.Module`):
            The model to attach the hooks to.

    """

    def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
        if "attention_mask" in module_kwargs:
            module_kwargs["attention_mask"] = None
            module_kwargs["is_causal"] = True

        return module_args, module_kwargs

    for name, module in model.named_modules():
        # We hope (assume) that if user uses their own model (without this structure which transformers uses), they read the docs saying they can't pass in attention masks
        # Then these cases can happen:
        # 1) some modules end with a `self-attn` module, in which case we attach the hook, but
        #    there's no attention mask kwarg -> hook is a no-op
        # 2) some modules end with a `self-attn` module, in which case we attach the hook, and the
        #    attention mask kwarg is passed -> hook will remove the attention mask and add
        #    `is_causal=True` kwarg, which either crashes the training or fixes it
        #    (training would crash anyway as attention mask isn't supported)
        # 3) no modules end with a `self-attn` module, in which case we don't attach the hook, this is
        #    a no-op as well
        if name.endswith("self_attn"):
            # we want the hook to be executed first, to avoid any other hooks doing work on the attention mask
            module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True)
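For readers unfamiliar with the hook mechanics used here, a tiny self-contained illustration of `register_forward_pre_hook(..., with_kwargs=True, prepend=True)` on a toy module (this is not accelerate code, just a minimal sketch of the same pattern):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x, attention_mask=None, is_causal=False):
        # a real attention module would use these; we just report them
        return x, attention_mask, is_causal

def drop_mask(_module, args, kwargs):
    # same shape of hook as above: mutate kwargs, return (args, kwargs)
    if "attention_mask" in kwargs:
        kwargs["attention_mask"] = None
        kwargs["is_causal"] = True
    return args, kwargs

m = Toy()
m.register_forward_pre_hook(drop_mask, with_kwargs=True, prepend=True)
print(m(torch.ones(1), attention_mask=torch.ones(1)))  # mask is dropped, is_causal=True
```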
@@ -505,17 +505,48 @@ def get_cluster_input():
            error_message="Please enter yes or no.",
        )

        if fsdp_version == 2:
            fsdp_config["fsdp_cp_size"] = _ask_field(
                "What should be your FSDP's context parallel size? (Input 1 or leave blank for no context parallel) [1]: ",
                int,
                default=1,
                error_message="Please enter an integer.",
            )
    parallelism_config = {}

        if fsdp_version == 2 and fsdp_config.get("fsdp_cp_size", 1) != 1:
            fsdp_config["fsdp_cp_comm_strategy"] = _ask_options(
                "What should be your FSDP's context parallel communication strategy? [allgather]: ",
    if fsdp_config.get("fsdp_version", 1) == 2:
        use_parallelism_config = _ask_field(
            "Do you want to use the parallelism config? [yes/NO]: ",
            _convert_yes_no_to_bool,
            default=False,
            error_message="Please enter yes or no.",
        )

        if use_parallelism_config:
            prefix = "parallelism_config_"
            parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
                "What is the data parallelism replicate size? [1]: ",
                int,
                default=1,
                error_message="Please enter an integer.",
            )

            parallelism_config[prefix + "dp_shard_size"] = _ask_field(
                "What is the FSDP shard size? [1]: ",
                int,
                default=1,
                error_message="Please enter an integer.",
            )

            parallelism_config[prefix + "tp_size"] = _ask_field(
                "What is the tensor parallelism size? [1]: ",
                int,
                default=1,
                error_message="Please enter an integer.",
            )

            parallelism_config[prefix + "cp_size"] = _ask_field(
                "What is the context parallelism size? [1]: ",
                int,
                default=1,
                error_message="Please enter an integer.",
            )
            if parallelism_config[prefix + "cp_size"] > 1:
                parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
                    "What is the compute parallelism communication strategy?",
                    ["allgather", "alltoall"],
                    lambda x: ["allgather", "alltoall"][int(x)],
                    default=0,
@@ -790,8 +821,8 @@ def get_cluster_input():
        )
        fp8_config["fp8_format"] = _ask_options(
            "Which weight format should be used?",
            ["HYBRID", "E4M3"],
            lambda x: "HYBRID" if x == 0 else "E4M3",
            ["HYBRID", "E4M3", "E5M2"],
            lambda i: ["HYBRID", "E4M3", "E5M2"][i],
            default=0,
        )
        fp8_config["amax_history_length"] = _ask_field(
@@ -865,6 +896,7 @@ def get_cluster_input():
        fp8_config=fp8_config,
        deepspeed_config=deepspeed_config,
        fsdp_config=fsdp_config,
        parallelism_config=parallelism_config,
        megatron_lm_config=megatron_lm_config,
        ipex_config=ipex_config,
        mpirun_config=mpirun_config,
@@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
    deepspeed_config: dict = None
    # args for fsdp
    fsdp_config: dict = None
    # args for parallelism config
    parallelism_config: dict = None
    # args for megatron_lm
    megatron_lm_config: dict = None
    # args for ipex
@@ -229,6 +231,8 @@ class ClusterConfig(BaseConfig):
            self.mpirun_config = {}
        if self.fp8_config is None:
            self.fp8_config = {}
        if self.parallelism_config is None:
            self.parallelism_config = {}
        return super().__post_init__()
@@ -60,4 +60,4 @@ def update_command_parser(parser, parents):

def update_config_command(args):
    config_file = update_config(args)
    print(f"Sucessfully updated the configuration file at {config_file}.")
    print(f"Successfully updated the configuration file at {config_file}.")
@@ -182,13 +182,6 @@ def launch_command_parser(subparsers=None):
    hardware_args.add_argument(
        "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
    )
    hardware_args.add_argument(
        "--ipex",
        default=False,
        action="store_true",
        help="Whether or not this should launch a Intel PyTorch Extension (IPEX) training.",
    )

    # Resource selection arguments
    resource_args = parser.add_argument_group(
        "Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
@@ -269,6 +262,12 @@ def launch_command_parser(subparsers=None):
        action="store_true",
        help="Whether to use fsdp.",
    )
    paradigm_args.add_argument(
        "--use_parallelism_config",
        default=False,
        action="store_true",
        help="Whether to use the parallelism config to configure the N-d distributed training.",
    )
    paradigm_args.add_argument(
        "--use_megatron_lm",
        default=False,
@@ -494,13 +493,13 @@ def launch_command_parser(subparsers=None):
        "--deepspeed_exclusion_filter",
        default=None,
        type=str,
        help="DeepSpeed exclusion filter string when using mutli-node setup.",
        help="DeepSpeed exclusion filter string when using multi-node setup.",
    )
    deepspeed_args.add_argument(
        "--deepspeed_inclusion_filter",
        default=None,
        type=str,
        help="DeepSpeed inclusion filter string when using mutli-node setup.",
        help="DeepSpeed inclusion filter string when using multi-node setup.",
    )
    deepspeed_args.add_argument(
        "--deepspeed_multinode_launcher",
@@ -586,7 +585,7 @@ def launch_command_parser(subparsers=None):
        "--fsdp_use_orig_params",
        default="true",
        type=str,
        help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
        help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
        " (useful only when `use_fsdp` flag is passed).",
    )
    fsdp_args.add_argument(
@@ -610,18 +609,6 @@ def launch_command_parser(subparsers=None):
        type=str,
        help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
    )
    fsdp_args.add_argument(
        "--fsdp_cp_size",
        type=int,
        default=1,
        help="FSDP's context parallel size. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to 1 (CP not applied).",
    )
    fsdp_args.add_argument(
        "--fsdp_cp_comm_strategy",
        type=str,
        default="allgather",
        help="FSDP's context parallel communication strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to `allgather`.",
    )

    # megatron_lm args
    megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
@@ -704,8 +691,8 @@ def launch_command_parser(subparsers=None):
    fp8_args.add_argument(
        "--fp8_format",
        type=str,
        default="E4M3",
        choices=["E4M3", "HYBRID"],
        default="HYBRID",
        choices=["HYBRID", "E4M3", "E5M2"],
        help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
    )
    fp8_args.add_argument(
@@ -779,6 +766,45 @@ def launch_command_parser(subparsers=None):
        help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.",
    )

    # ParallelismConfig arguments
    parallelism_config_args = parser.add_argument_group(
        "ParallelismConfig Arguments",
        "Arguments related to the ParallelismConfig used for distributed training.",
    )
    parallelism_config_args.add_argument(
        "--parallelism_config_dp_replicate_size",
        type=int,
        default=1,
        help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
    )

    parallelism_config_args.add_argument(
        "--parallelism_config_dp_shard_size",
        type=int,
        default=1,
        help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
    )

    parallelism_config_args.add_argument(
        "--parallelism_config_tp_size",
        type=int,
        default=1,
        help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
    )

    parallelism_config_args.add_argument(
        "--parallelism_config_cp_size",
        type=int,
        default=1,
        help="The number of processes for context parallel training. Defaults to 1 (no context parallelism).",
    )
    parallelism_config_args.add_argument(
        "--parallelism_config_cp_comm_strategy",
        type=str,
        default="allgather",
        help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
    )

    # Other arguments of the training scripts
    parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")

@@ -1006,6 +1032,9 @@ def _validate_launch_command(args):
    if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
        raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")

    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
        raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")

    defaults = None
    warned = []
    mp_from_config_flag = False
@@ -1039,6 +1068,7 @@ def _validate_launch_command(args):
    args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
    args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
    args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
    args.use_parallelism_config = defaults.parallelism_config != {}
    if args.gpu_ids is None:
        if defaults.gpu_ids is not None:
            args.gpu_ids = defaults.gpu_ids
@@ -89,7 +89,7 @@ def convert_config_to_fsdp2(config: dict) -> dict:
    new_fsdp_config = {}

    if fsdp_config.get("fsdp_version", 1) == 2:
        logger.warning("Config already specfies FSDP2, skipping conversion...")
        logger.warning("Config already specifies FSDP2, skipping conversion...")
        logger.warning(
            "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
        )
@@ -32,6 +32,7 @@ from .utils import (
    find_batch_size,
    get_data_structure,
    initialize_tensors,
    is_datasets_available,
    is_torch_version,
    is_torchdata_stateful_dataloader_available,
    send_to_device,
@@ -74,7 +75,7 @@ class SeedableRandomSampler(RandomSampler):
    Same as a random sampler, except that in `__iter__` a seed can be used.

    Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
    and be fully reproducable on multiple iterations.
    and be fully reproducible on multiple iterations.

    If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
    (stored in `self.epoch`).
@@ -407,7 +408,7 @@ class DataLoaderStateMixin:
class DataLoaderAdapter:
    """
    A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
    compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
    compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
    """

    def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
@@ -450,8 +451,8 @@ class DataLoaderAdapter:
    @property
    def __class__(self):
        """
        In order to maintain backwards compatability with other code, we need to ensure `isinstance(obj, DataLoader)`
        returs true. This is because some downstream code assumes that the `DataLoader` is the base class of the
        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
        object.
        """
        return self.base_dataloader.__class__
@@ -565,7 +566,8 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            yield
            self.end()
            return

        batch_index = 0
        while True:
@@ -761,12 +763,12 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):

        # if a device mesh is provided extract each dimension (dp, fsdp, tp)
        # device mesh may hold any number of dimensions, however,
        # below code is for targetted support for dp, fsdp and tp
        # below code is for targeted support for dp, fsdp and tp

        # device mesh will be used only if there is tp involved
        # or any multi-dimensional parallelism involving tp
        # (dp, tp) (fsdp, tp) (dp, fsdp, tp)
        # otherwise the default behavour not using device mesh should be sufficient
        # otherwise the default behaviour not using device mesh should be sufficient
        # since multi dimensional parallelism devoid of tp would anyway need
        # different batches for each process irrespective of dp or fsdp
        self.submesh_tp = None
@@ -1061,7 +1063,7 @@ def prepare_data_loader(
            ignored otherwise.
        use_seedable_sampler (`bool`, *optional*, defaults to `False`):
            Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
            reproducability. Comes at a cost of potentially different performances due to different shuffling
            reproducibility. Comes at a cost of potentially different performances due to different shuffling
            algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
            `self.set_epoch`
        data_seed (`int`, *optional*, defaults to `None`):
@@ -1111,32 +1113,34 @@ def prepare_data_loader(
            # Given a device mesh (dp, tp) = (2, 3):
            # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
            # - Processes with the same DP rank will receive the same batch.
            submesh_tp_size = 1
            if "tp" in torch_device_mesh.mesh_dim_names:
                submesh_tp_size = torch_device_mesh["tp"].size()
            process_index = process_index // submesh_tp_size
            num_processes = num_processes // submesh_tp_size
        else:
            # when device mesh is used, specifically with TP or CP
            # when device mesh is used, specifically with TP
            # then there is need to update process_index and num_processes
            # to bring in the effect of generating same batch across TP/CP ranks
            # to bring in the effect of generating same batch across TP ranks
            # and different batch across FSDP and DP ranks.
            # Example:
            # if device mesh is (dp,fsdp,tp,cp) = (2, 2, 2, 3)
            # ranks would range from 0...23
            # from data angle ranks should look like 0 0 0 0 0 0 1 1 1 1 1 1 ...
            # if device mesh is (dp,fsdp,tp) = (2, 2, 3)
            # ranks would range from 0...11
            # from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
            # processes with same ranks/ids would receive the same batch
            # for CP the same as TP applies
            submesh_fsdp_size = 1
            submesh_dp_size = 1
            submesh_tp_size = 1
            submesh_cp_size = 1
            if "tp" in torch_device_mesh.mesh_dim_names:
                submesh_tp_size = torch_device_mesh["tp"].size()
            if "dp" in torch_device_mesh.mesh_dim_names:
                submesh_dp_size = torch_device_mesh["dp"].size()
            if "fsdp" in torch_device_mesh.mesh_dim_names:
                submesh_fsdp_size = torch_device_mesh["fsdp"].size()
            if "cp" in torch_device_mesh.mesh_dim_names:
                submesh_cp_size = torch_device_mesh["cp"].size()
            if "dp_replicate" in torch_device_mesh.mesh_dim_names:
                submesh_dp_size = torch_device_mesh["dp_replicate"].size()
            if "dp_shard" in torch_device_mesh.mesh_dim_names:
                submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
            process_index = process_index // (submesh_tp_size * submesh_cp_size)
            num_processes = submesh_fsdp_size * submesh_dp_size

@@ -1197,7 +1201,16 @@ def prepare_data_loader(
            dataloader.sampler.generator = generator
    # No change if no multiprocess
    if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
        if isinstance(new_dataset, IterableDataset):
        if is_datasets_available():
            from datasets import IterableDataset as DatasetsIterableDataset
        if (
            is_datasets_available()
            and isinstance(new_dataset, DatasetsIterableDataset)
            and not split_batches
            and new_dataset.n_shards > num_processes
        ):
            new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)
        elif isinstance(new_dataset, IterableDataset):
            if getattr(dataloader.dataset, "generator", None) is not None:
                synchronized_generator = dataloader.dataset.generator
            new_dataset = IterableDatasetShard(
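A small standalone illustration of the rank bookkeeping in the hunk above (mesh sizes are illustrative): TP and CP ranks must see the same batch, so the effective data-parallel index of a process is its global rank integer-divided by `tp_size * cp_size`.

```python
# Illustrative only: how global ranks collapse onto data-parallel indices for batching.
dp_replicate, dp_shard, tp, cp = 2, 2, 2, 3               # 24 processes total, as in the comment above
world_size = dp_replicate * dp_shard * tp * cp
data_ranks = [rank // (tp * cp) for rank in range(world_size)]
print(data_ranks)   # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, ...] -> 4 distinct batches per step
```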
189
src/accelerate/dist_checkpointing.py
Normal file
189
src/accelerate/dist_checkpointing.py
Normal file
@ -0,0 +1,189 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.import queue
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
from io import UnsupportedOperation
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.distributed.checkpoint.state_dict as dcs
|
||||
from torch.distributed.checkpoint.filesystem import (
|
||||
FileSystemWriter,
|
||||
SavePlan,
|
||||
SavePlanner,
|
||||
_generate_uuid,
|
||||
_split_by_size_and_type,
|
||||
)
|
||||
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
|
||||
from torch.distributed.checkpoint.storage import WriteResult
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
class AccelerateStorageWriter(FileSystemWriter):
|
||||
_DEFAULT_SUFFIX = ".distcp"
|
||||
_OPTIM_FILE_PATH = "optimizer_0"
|
||||
_MODEL_FILE_PATH = "pytorch_model_fsdp_0"
|
||||
|
||||
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
|
||||
self.optim_path = self.fs.concat_path(self.path, self._OPTIM_FILE_PATH)
|
||||
self.model_path = self.fs.concat_path(self.path, self._MODEL_FILE_PATH)
|
||||
self.fs.mkdir(self.optim_path)
|
||||
self.fs.mkdir(self.model_path)
|
||||
return super().prepare_local_plan(plan)
|
||||
|
||||
def write_data(
|
||||
self,
|
||||
plan: SavePlan,
|
||||
planner: SavePlanner,
|
||||
):
|
||||
storage_plan = plan.storage_data
|
||||
optim_file_count = 0
|
||||
model_file_count = 0
|
||||
|
||||
def gen_file(is_optimizer: bool = False) -> str:
|
||||
nonlocal optim_file_count, model_file_count
|
||||
if is_optimizer:
|
||||
optim_file_count += 1
|
||||
return f"{storage_plan.prefix}{optim_file_count}{self._DEFAULT_SUFFIX}"
|
||||
else:
|
||||
model_file_count += 1
|
||||
return f"{storage_plan.prefix}{model_file_count}{self._DEFAULT_SUFFIX}"
|
||||
|
||||
file_queue: queue.Queue = queue.Queue()
|
||||
|
||||
for bucket in _split_by_size_and_type(1, plan.items):
|
||||
optim_states = [wi for wi in bucket if "optim" in wi.index.fqn]
|
||||
model_states = [wi for wi in bucket if "model" in wi.index.fqn]
|
||||
|
||||
for state, path in zip([optim_states, model_states], [self.optim_path, self.model_path]):
|
||||
file_name = gen_file()
|
||||
path = self.fs.concat_path(path, file_name)
|
||||
file_queue.put((path, file_name, state))
|
||||
|
||||
return self._write_data(planner, file_queue)
|
||||
|
||||
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
|
||||
try:
|
||||
metadata = dataclasses.replace(metadata, version="1.0.0")
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
def _split_metadata(
|
||||
metadata: Metadata,
|
||||
) -> tuple[Metadata, Metadata]:
|
||||
result = []
|
||||
for to_get in ["model", "optim"]:
|
||||
result.append(
|
||||
Metadata(
|
||||
state_dict_metadata={
|
||||
k.removeprefix("state."): v for k, v in metadata.state_dict_metadata.items() if to_get in k
|
||||
},
|
||||
planner_data={
|
||||
k.removeprefix("state."): tuple([x for x in v if x != "state"])
|
||||
for k, v in metadata.planner_data.items()
|
||||
if to_get in k
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(result)
|
||||
|
||||
model_metadata, optim_metadata = _split_metadata(metadata)
|
||||
model_storage_md, optim_storage_md = {}, {}
|
||||
for wr_list in results:
|
||||
for wr in wr_list:
|
||||
new_index = dataclasses.asdict(wr.index)
|
||||
new_index["fqn"] = new_index["fqn"].removeprefix("state.")
|
||||
wr = WriteResult(
|
||||
index=MetadataIndex(**new_index),
|
||||
size_in_bytes=wr.size_in_bytes,
|
||||
storage_data=wr.storage_data,
|
||||
)
|
||||
if "optim" in wr.index.fqn:
|
||||
optim_storage_md.update({wr.index: wr.storage_data})
|
||||
else:
|
||||
model_storage_md.update({wr.index: wr.storage_data})
|
||||
|
||||
model_metadata.storage_data = model_storage_md
|
||||
optim_metadata.storage_data = optim_storage_md
|
||||
|
||||
model_metadata.storage_meta = StorageMeta(self.model_path, save_id=_generate_uuid())
|
||||
optim_metadata.storage_meta = StorageMeta(self.optim_path, save_id=_generate_uuid())
|
||||
|
||||
tmp_optim_path = cast(Path, self.fs.concat_path(self.optim_path, ".metadata.tmp"))
|
||||
tmp_model_path = cast(Path, self.fs.concat_path(self.model_path, ".metadata.tmp"))
|
||||
|
||||
for meta, tmp_path, final_path in zip(
|
||||
[model_metadata, optim_metadata],
|
||||
[tmp_model_path, tmp_optim_path],
|
||||
[self.model_path, self.optim_path],
|
||||
):
|
||||
with self.fs.create_stream(tmp_path, "wb") as metadata_file:
|
||||
pickle.dump(meta, metadata_file)
|
||||
if self.sync_files:
|
||||
try:
|
||||
os.fsync(metadata_file.fileno())
|
||||
except (AttributeError, UnsupportedOperation):
|
||||
os.sync()
|
||||
|
||||
metadata_path = self.fs.concat_path(final_path, ".metadata")
|
||||
if self.fs.exists(metadata_path):
|
||||
self.fs.rm_file(metadata_path)
|
||||
|
||||
self.fs.rename(tmp_path, metadata_path)
|
||||
|
||||
|
||||
def save_model_and_optimizer(
|
||||
accelerator: "Accelerator",
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
save_path: str,
|
||||
async_save: bool = False,
|
||||
) -> None:
|
||||
|
||||
if getattr(accelerator, "_async_save_handle", None) is not None:
|
||||
accelerator._async_save_handle.result()
|
||||
|
||||
options = dcs.StateDictOptions()
|
||||
|
||||
import time
|
||||
|
||||
accelerator.print(f"{time.asctime()} - Preparing state dict...")
|
||||
model_sd, optimizer_sd = dcs.get_state_dict(model, optimizer, options=options)
|
||||
accelerator.print(f"{time.asctime()} - Prepared state dict...")
|
||||
|
||||
accelerator.print(f"{time.asctime()} - Saving state dict...")
|
||||
stateful = {
|
||||
"model": model_sd,
|
||||
"optimizer": optimizer_sd,
|
||||
}
|
||||
|
||||
save_fn = dcp.save if not async_save else dcp.async_save
|
||||
|
||||
potential_handle = save_fn(
|
||||
state_dict=stateful,
|
||||
storage_writer=AccelerateStorageWriter(save_path),
|
||||
)
|
||||
accelerator.print(f"{time.asctime()} - Finished saving state dict...")
|
||||
|
||||
if async_save:
|
||||
accelerator._async_save_handle = potential_handle
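A minimal usage sketch for the helper above, assuming a prepared model and optimizer; the helper itself is defined in this file, so the call below only illustrates the intended flow:

```python
import torch

from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# Kick off a non-blocking checkpoint; the returned handle is stored on the
# accelerator and awaited automatically the next time the helper is called.
save_model_and_optimizer(accelerator, model, optimizer, "ckpt/step_100", async_save=True)

# ... training continues while the checkpoint is written in the background ...

# Before exiting, wait for the last pending async save (attribute name as used above).
if getattr(accelerator, "_async_save_handle", None) is not None:
    accelerator._async_save_handle.result()
```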
|
||||
@ -714,9 +714,20 @@ class CpuOffload(ModelHook):
|
||||
return module.to("cpu")
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
if self.prev_module_hook is not None:
|
||||
self.prev_module_hook.offload()
|
||||
clear_device_cache()
|
||||
if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
|
||||
prev_module = self.prev_module_hook.model
|
||||
prev_device = next(prev_module.parameters()).device
|
||||
|
||||
# Only offload the previous module if it is not already on CPU.
|
||||
if prev_device != torch.device("cpu"):
|
||||
self.prev_module_hook.offload()
|
||||
clear_device_cache()
|
||||
|
||||
# If the current device is already the self.execution_device, we can skip the transfer.
|
||||
current_device = next(module.parameters()).device
|
||||
if current_device == self.execution_device:
|
||||
return args, kwargs
|
||||
|
||||
module.to(self.execution_device)
|
||||
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
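For context, a hedged sketch of how these hooks are typically chained through the public `cpu_offload_with_hook` helper (assumes a CUDA device is available; the two linear layers are placeholders):

```python
import torch

from accelerate import cpu_offload_with_hook

model_1 = torch.nn.Linear(16, 16)
model_2 = torch.nn.Linear(16, 16)

# Each hook offloads the previous module back to CPU right before the next one runs.
model_1, hook_1 = cpu_offload_with_hook(model_1, execution_device="cuda:0")
model_2, hook_2 = cpu_offload_with_hook(model_2, execution_device="cuda:0", prev_module_hook=hook_1)

x = torch.randn(1, 16)
out = model_2(model_1(x))  # model_1 is sent back to CPU before model_2 executes

# Offload the last module manually once the pipeline is done.
hook_2.offload()
```

With the change above, the previous module is only offloaded when it is not already on CPU, and the transfer to the execution device is skipped when the module is already there.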
|
||||
|
||||
|
||||
@ -60,8 +60,8 @@ def notebook_launcher(
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
To use this function absolutely zero calls to a CUDA device must be made in the notebook session before calling. If
|
||||
any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability.
|
||||
To use this function absolutely zero calls to a device must be made in the notebook session before calling. If any
|
||||
have been made, you will need to restart the notebook and make sure no cells use any device capability.
|
||||
|
||||
Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
|
||||
of those calls have been made.
|
||||
@ -76,11 +76,11 @@ def notebook_launcher(
|
||||
Tuple of arguments to pass to the function (it will receive `*args`).
|
||||
num_processes (`int`, *optional*):
|
||||
The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
|
||||
the number of GPUs available otherwise.
|
||||
the number of devices available otherwise.
|
||||
mixed_precision (`str`, *optional*, defaults to `"no"`):
|
||||
If `fp16` or `bf16`, will use mixed precision training on multi-GPU.
|
||||
If `fp16` or `bf16`, will use mixed precision training on multi-device.
|
||||
use_port (`str`, *optional*, defaults to `"29500"`):
|
||||
The port to use to communicate between processes when launching a multi-GPU training.
|
||||
The port to use to communicate between processes when launching a multi-device training.
|
||||
master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
|
||||
The address to use for communication between processes.
|
||||
node_rank (`int`, *optional*, defaults to 0):
|
||||
@ -105,7 +105,7 @@ def notebook_launcher(
|
||||
Example:
|
||||
|
||||
```python
|
||||
# Assume this is defined in a Jupyter Notebook on an instance with two GPUs
|
||||
# Assume this is defined in a Jupyter Notebook on an instance with two devices
|
||||
from accelerate import notebook_launcher
|
||||
|
||||
|
||||
@ -158,27 +158,27 @@ def notebook_launcher(
|
||||
else:
|
||||
if num_processes is None:
|
||||
raise ValueError(
|
||||
"You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call."
|
||||
"You have to specify the number of devices you would like to use, add `num_processes=...` to your call."
|
||||
)
|
||||
if node_rank >= num_nodes:
|
||||
raise ValueError("The node_rank must be less than the number of nodes.")
|
||||
if num_processes > 1:
|
||||
# Multi-GPU launch
|
||||
# Multi-device launch
|
||||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
from torch.multiprocessing import start_processes
|
||||
from torch.multiprocessing.spawn import ProcessRaisedException
|
||||
|
||||
if len(AcceleratorState._shared_state) > 0:
|
||||
raise ValueError(
|
||||
"To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized "
|
||||
"To launch a multi-device training from your notebook, the `Accelerator` should only be initialized "
|
||||
"inside your training function. Restart your notebook and make sure no cells initializes an "
|
||||
"`Accelerator`."
|
||||
)
|
||||
# Check for specific libraries known to initialize CUDA that users constantly use
|
||||
# Check for specific libraries known to initialize device that users constantly use
|
||||
problematic_imports = are_libraries_initialized("bitsandbytes")
|
||||
if len(problematic_imports) > 0:
|
||||
err = (
|
||||
"Could not start distributed process. Libraries known to initialize CUDA upon import have been "
|
||||
"Could not start distributed process. Libraries known to initialize device upon import have been "
|
||||
"imported already. Please keep these imports inside your training function to try and help with this:"
|
||||
)
|
||||
for lib_name in problematic_imports:
|
||||
@ -203,24 +203,26 @@ def notebook_launcher(
|
||||
# process here (the other ones will be set be the launcher).
|
||||
with patch_environment(**patched_env):
|
||||
# First dummy launch
|
||||
device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
||||
distributed_type = "MULTI_XPU" if device_type == "xpu" else "MULTI_GPU"
|
||||
if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
|
||||
launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
|
||||
launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)
|
||||
try:
|
||||
start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
|
||||
except ProcessRaisedException as e:
|
||||
err = "An issue was found when verifying a stable environment for the notebook launcher."
|
||||
if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
|
||||
if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
|
||||
raise RuntimeError(
|
||||
f"{err}"
|
||||
"This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
|
||||
"Please review your imports and test them when running the `notebook_launcher()` to identify "
|
||||
"which one is problematic and causing CUDA to be initialized."
|
||||
f"which one is problematic and causing {device_type.upper()} to be initialized."
|
||||
) from e
|
||||
else:
|
||||
raise RuntimeError(f"{err} The following error was raised: {e}") from e
|
||||
# Now the actual launch
|
||||
launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
|
||||
print(f"Launching training on {num_processes} GPUs.")
|
||||
launcher = PrepareForLaunch(function, distributed_type=distributed_type)
|
||||
print(f"Launching training on {num_processes} {device_type.upper()}s.")
|
||||
try:
|
||||
if rdzv_conf is None:
|
||||
rdzv_conf = {}
|
||||
@ -244,23 +246,25 @@ def notebook_launcher(
|
||||
launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
|
||||
elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
|
||||
except ProcessRaisedException as e:
|
||||
if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
|
||||
if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
|
||||
raise RuntimeError(
|
||||
"CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. "
|
||||
f"{device_type.upper()} has been initialized before the `notebook_launcher` could create a forked subprocess. "
|
||||
"This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
|
||||
"Please review your imports and test them when running the `notebook_launcher()` to identify "
|
||||
"which one is problematic and causing CUDA to be initialized."
|
||||
f"which one is problematic and causing {device_type.upper()} to be initialized."
|
||||
) from e
|
||||
else:
|
||||
raise RuntimeError(f"An issue was found when launching the training: {e}") from e
|
||||
|
||||
else:
|
||||
# No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
|
||||
# No need for a distributed launch otherwise as it's either CPU, GPU, XPU or MPS.
|
||||
if is_mps_available():
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
print("Launching training on MPS.")
|
||||
elif torch.cuda.is_available():
|
||||
print("Launching training on one GPU.")
|
||||
elif torch.xpu.is_available():
|
||||
print("Launching training on one XPU.")
|
||||
else:
|
||||
print("Launching training on CPU.")
|
||||
function(*args)
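A hedged sketch of the debug-mode flow described in the tip above; the training function body and process count are placeholders:

```python
import os

from accelerate import notebook_launcher


def training_function():
    # Keep device-initializing imports and the Accelerator inside this function.
    from accelerate import Accelerator

    accelerator = Accelerator()
    accelerator.print(f"Hello from process {accelerator.process_index}")


# Run the dummy "stable environment" launch before the real one.
os.environ["ACCELERATE_DEBUG_MODE"] = "1"
notebook_launcher(training_function, args=(), num_processes=2)
```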
|
||||
|
||||
322
src/accelerate/parallelism_config.py
Normal file
@ -0,0 +1,322 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelismConfig:
|
||||
"""
|
||||
A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
|
||||
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
|
||||
|
||||
Args:
|
||||
dp_replicate_size (`int`, defaults to `1`):
|
||||
The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
|
||||
group will not be used.
|
||||
dp_shard_size (`int`, defaults to `1`):
|
||||
The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
|
||||
be greater than 1, as composing DDP + TP is currently not supported.
|
||||
tp_size (`int`, defaults to `1`):
|
||||
The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
|
||||
used.
|
||||
cp_size (`int`, defaults to `1`):
|
||||
The size of the context parallel group. If `cp_size` is set to `1`, the context parallel group will not be used.
|
||||
tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
The handler for the tensor parallel group.
cp_handler (`~utils.TorchContextParallelConfig`, defaults to `None`):
The handler for the context parallel group.
|
||||
|
||||
You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
|
||||
together:
|
||||
- `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
|
||||
- `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
|
||||
- `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration; to use pure DP, use
|
||||
`DistributedDataParallelKwargs` instead.
|
||||
|
||||
"""
|
||||
|
||||
dp_replicate_size: int = None
|
||||
dp_shard_size: int = None
|
||||
tp_size: int = None
|
||||
cp_size: int = None
|
||||
|
||||
# we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
|
||||
tp_handler: Union[None, TorchTensorParallelConfig] = None
|
||||
cp_handler: Union[None, TorchContextParallelConfig] = None
|
||||
|
||||
device_mesh = None
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"ParallelismConfig(\n "
|
||||
f"\tdp_replicate_size={self.dp_replicate_size},\n"
|
||||
f"\tdp_shard_size={self.dp_shard_size},\n"
|
||||
f"\ttp_size={self.tp_size},\n"
|
||||
f"\tcp_size={self.cp_size},\n"
|
||||
f"\ttotal_size={self.total_size}\n"
|
||||
f"\ttp_handler={self.tp_handler},\n"
|
||||
f"\tcp_handler={self.cp_handler})\n"
|
||||
)
|
||||
|
||||
def to_json(self):
|
||||
import copy
|
||||
|
||||
_non_serializable_fields = ["device_mesh"]
|
||||
|
||||
return copy.deepcopy(
|
||||
{
|
||||
k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
|
||||
for k, v in self.__dict__.items()
|
||||
if k not in _non_serializable_fields
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def dp_dim_names(self):
|
||||
"""Names of enabled dimensions across which data parallelism is applied."""
|
||||
dims = []
|
||||
if self.dp_replicate_enabled:
|
||||
dims += ["dp_replicate"]
|
||||
if self.dp_shard_enabled:
|
||||
dims += ["dp_shard"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def non_dp_dim_names(self):
|
||||
"""Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
|
||||
dims = []
|
||||
if self.tp_enabled:
|
||||
dims += ["tp"]
|
||||
if self.cp_enabled:
|
||||
dims += ["cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def dp_shard_cp_dim_names(self):
|
||||
"""Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
|
||||
dims = []
|
||||
if self.dp_shard_enabled:
|
||||
dims += ["dp_shard"]
|
||||
if self.cp_enabled:
|
||||
dims += ["cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def dp_cp_dim_names(self):
|
||||
"""Names of enabled dimensions across which loss should be averaged"""
|
||||
dims = []
|
||||
if self.dp_replicate_enabled:
|
||||
dims += ["dp_replicate"]
|
||||
if self.dp_shard_enabled:
|
||||
dims += ["dp_shard"]
|
||||
if self.cp_enabled:
|
||||
dims += ["cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def fsdp_dim_names(self):
|
||||
"""Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
|
||||
dims = []
|
||||
if self.dp_replicate_enabled:
|
||||
dims += ["dp_replicate"]
|
||||
dims += ["dp_shard_cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def total_size(self):
|
||||
"""The total size of the parallelism configuration, which is the product of all sizes."""
|
||||
return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size
|
||||
|
||||
@property
|
||||
def non_data_parallel_size(self):
|
||||
"""The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
|
||||
return self.tp_size * self.cp_size
|
||||
|
||||
@property
|
||||
def data_parallel_size(self):
|
||||
"""The size of the data parallel dimensions, which is the product of data parallel replication and"""
|
||||
return self.dp_replicate_size * self.dp_shard_size
|
||||
|
||||
@property
|
||||
def dp_replicate_enabled(self):
|
||||
"""True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
|
||||
return self.dp_replicate_size > 1
|
||||
|
||||
@property
|
||||
def dp_shard_enabled(self):
|
||||
"""True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
|
||||
return self.dp_shard_size > 1
|
||||
|
||||
@property
|
||||
def tp_enabled(self):
|
||||
"""True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
|
||||
return self.tp_size > 1
|
||||
|
||||
@property
|
||||
def cp_enabled(self):
|
||||
"""True if context parallelism is enabled, i.e. `cp_size > 1`."""
|
||||
return self.cp_size > 1
|
||||
|
||||
@property
|
||||
def active_mesh_dims(self):
|
||||
"""Names of all active mesh dimensions."""
|
||||
return self.dp_dim_names + self.non_dp_dim_names
|
||||
|
||||
def build_device_mesh(self, device_type: str):
|
||||
"""Builds a device mesh for the given device type based on the parallelism configuration.
|
||||
This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).
|
||||
|
||||
Args:
|
||||
device_type (`str`): The type of device for which to build the mesh, e.g. `"cuda"`.
|
||||
"""
|
||||
if is_torch_version(">=", "2.2.0"):
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
else:
|
||||
raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
|
||||
|
||||
mesh = self._get_mesh()
|
||||
if len(mesh) == 0:
|
||||
return None
|
||||
mesh_dim_names, mesh_shape = mesh
|
||||
device_mesh = init_device_mesh(
|
||||
device_type,
|
||||
mesh_shape,
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
)
|
||||
if self.dp_dim_names:
|
||||
device_mesh[self.dp_dim_names]._flatten("dp")
|
||||
if self.dp_shard_cp_dim_names:
|
||||
device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
|
||||
if self.dp_cp_dim_names:
|
||||
device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")
|
||||
|
||||
return device_mesh
|
||||
|
||||
def get_device_mesh(self, device_type: Optional[str] = None):
|
||||
if self.device_mesh is None:
|
||||
if device_type is not None:
|
||||
self.device_mesh = self.build_device_mesh(device_type)
|
||||
else:
|
||||
raise ("You need to pass a device_type e.g cuda to build the device mesh")
|
||||
else:
|
||||
if device_type is not None:
|
||||
if self.device_mesh.device_type != device_type:
|
||||
raise ValueError(
|
||||
f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
|
||||
)
|
||||
return self.device_mesh
|
||||
|
||||
def _get_mesh(self) -> tuple[tuple[str, ...], tuple[int, ...]]:
"""Generate the mesh dimension names and mesh shape for torch.distributed.init_device_mesh()."""
|
||||
|
||||
# Build mesh dimensions dictionary
|
||||
mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}
|
||||
|
||||
# Apply canonical ordering
|
||||
mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"]
|
||||
sorted_items = sorted(
|
||||
mesh_dims.items(),
|
||||
key=lambda x: (mesh_order.index(x[0])),
|
||||
)
|
||||
return tuple(zip(*sorted_items))
|
||||
|
||||
def __post_init__(self):
|
||||
# Basic size validation
|
||||
if self.dp_replicate_size is None:
|
||||
self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
|
||||
if self.dp_shard_size is None:
|
||||
self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
|
||||
if self.tp_size is None:
|
||||
self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
|
||||
if self.cp_size is None:
|
||||
self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.tp_handler is None:
|
||||
self.tp_handler = TorchTensorParallelConfig()
|
||||
|
||||
if self.cp_size > 1:
|
||||
if self.cp_handler is None:
|
||||
self.cp_handler = TorchContextParallelConfig()
|
||||
|
||||
if self.dp_replicate_size < 1:
|
||||
raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
|
||||
if self.dp_shard_size < 1:
|
||||
raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
|
||||
if self.tp_size < 1:
|
||||
raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
|
||||
if self.cp_size < 1:
|
||||
raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
|
||||
|
||||
if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
|
||||
raise ValueError(
|
||||
"Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
|
||||
"Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
|
||||
"or set dp_replicate_size == 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
|
||||
)
|
||||
self._sizes = {
|
||||
"dp_replicate": self.dp_replicate_size,
|
||||
"dp_shard": self.dp_shard_size,
|
||||
"tp": self.tp_size,
|
||||
"cp": self.cp_size,
|
||||
}
|
||||
|
||||
def _set_size(self, parallelism: str, size: int):
|
||||
assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
|
||||
self._sizes[parallelism] = size
|
||||
setattr(self, f"{parallelism}_size", size)
|
||||
|
||||
def _validate_accelerator(self, accelerator: "Accelerator"):
|
||||
_warnings = set()
|
||||
if not accelerator.multi_device and self.total_size == 1:
|
||||
# No distributed setup, valid parallelism config
|
||||
return
|
||||
|
||||
# We need this to ensure DDP works
|
||||
if self.total_size == 1:
|
||||
self._set_size("dp_replicate", accelerator.num_processes)
|
||||
|
||||
if self.total_size != accelerator.num_processes:
|
||||
raise ValueError(
|
||||
f"ParallelismConfig total_size ({self.total_size}) does not match "
|
||||
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
|
||||
f"dp_shard_size/tp_size/cp_size."
|
||||
)
|
||||
|
||||
if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
|
||||
raise ValueError(
|
||||
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
|
||||
)
|
||||
|
||||
for parallelism, size in self._sizes.items():
|
||||
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
|
||||
_warnings.add(
|
||||
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
|
||||
)
|
||||
|
||||
if _warnings and accelerator.is_main_process:
|
||||
warnings.warn(
|
||||
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
|
||||
UserWarning,
|
||||
)
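To make the sizing rules concrete, a small sketch using the class above (the assertions are pure arithmetic on the dataclass; building the mesh additionally requires `torch.distributed` to be initialized with `total_size` ranks):

```python
from accelerate.parallelism_config import ParallelismConfig

# HSDP + TP on 8 processes: 2 replicas x 2 shards x 2 tensor-parallel ranks.
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)

assert pc.total_size == 8
assert pc.active_mesh_dims == ["dp_replicate", "dp_shard", "tp"]
assert pc.fsdp_dim_names == ["dp_replicate", "dp_shard_cp"]

# Under a live 8-rank job this also creates the flattened "dp", "dp_shard_cp"
# and "dp_cp" joint meshes used for FSDP sharding and loss averaging.
# device_mesh = pc.build_device_mesh("cuda")
```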
|
||||
@ -132,7 +132,7 @@ class PartialState:
|
||||
Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
|
||||
`True` and force the execution on the CPU.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Additional keyword arguments to pass to the relevent `init_process_group` function. Valid `kwargs` can be
|
||||
Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
|
||||
found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
|
||||
|
||||
**Available attributes:**
|
||||
@ -187,7 +187,7 @@ class PartialState:
|
||||
dist_information = None
|
||||
if use_sagemaker_dp is None:
|
||||
use_sagemaker_dp = (
|
||||
os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
|
||||
os.environ.get("ACCELERATE_USE_SAGEMAKER", "false").lower() == "true"
|
||||
and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
|
||||
)
|
||||
|
||||
@ -195,14 +195,14 @@ class PartialState:
|
||||
original_backend = kwargs.pop("backend", None)
|
||||
backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
|
||||
if original_backend is not None and backend != original_backend:
|
||||
raise ValueError(f"Your assigned backend {original_backend} is not avaliable, please use {backend}")
|
||||
raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
|
||||
self.backend = backend
|
||||
self.distributed_type = distributed_type
|
||||
use_deepspeed = False
|
||||
if not cpu and self.backend != "xla":
|
||||
if int(os.environ.get("LOCAL_RANK", -1)) != -1:
|
||||
# Deal with spawning deepspeed
|
||||
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
|
||||
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
|
||||
if not is_deepspeed_available():
|
||||
raise ImportError(
|
||||
"DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
|
||||
@ -213,12 +213,6 @@ class PartialState:
|
||||
if self.backend == "tccl":
|
||||
local_rank = os.environ.get("LOCAL_RANK", -1)
|
||||
torch.sdaa.set_device(f"sdaa:{local_rank}")
|
||||
if (
|
||||
self.backend == "nccl"
|
||||
and os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
|
||||
and os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true"
|
||||
):
|
||||
self.backend = "cuda:nccl,cpu:gloo"
|
||||
dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
|
||||
# We need to flag to `use_deepspeed` to be True to override `distributed_type` later
|
||||
use_deepspeed = True
|
||||
@ -230,6 +224,16 @@ class PartialState:
|
||||
if self.backend == "tccl":
|
||||
local_rank = os.environ.get("LOCAL_RANK", -1)
|
||||
torch.sdaa.set_device(f"sdaa:{local_rank}")
|
||||
if (
|
||||
self.backend == "nccl"
|
||||
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
|
||||
and (
|
||||
os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
|
||||
or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
|
||||
or True
|
||||
)
|
||||
):
|
||||
self.backend = "cuda:nccl,cpu:gloo"
|
||||
torch.distributed.init_process_group(backend=self.backend, **kwargs)
|
||||
|
||||
# XPU and CPU require special env configs to be set
|
||||
@ -397,7 +401,7 @@ class PartialState:
|
||||
DistributedType.DEEPSPEED,
|
||||
DistributedType.FSDP,
|
||||
):
|
||||
torch.distributed.barrier()
|
||||
torch.distributed.barrier(device_ids=[self.local_process_index])
|
||||
elif self.distributed_type == DistributedType.XLA:
|
||||
xm.rendezvous("accelerate.utils.wait_for_everyone")
|
||||
|
||||
@ -866,6 +870,8 @@ class AcceleratorState:
|
||||
- **device** (`torch.device`) -- The device to use.
|
||||
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
|
||||
in use.
|
||||
- **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
|
||||
current training environment. This is used to configure the distributed training environment.
|
||||
- **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
|
||||
- **local_process_index** (`int`) -- The index of the current process on the current server.
|
||||
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
|
||||
@ -896,6 +902,7 @@ class AcceleratorState:
|
||||
fsdp_plugin=None,
|
||||
torch_tp_plugin=None,
|
||||
megatron_lm_plugin=None,
|
||||
parallelism_config=None,
|
||||
_from_accelerator: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@ -910,6 +917,8 @@ class AcceleratorState:
|
||||
self.deepspeed_plugins = None
|
||||
self.use_ipex = None
|
||||
self.torch_tp_plugin = torch_tp_plugin
|
||||
self.parallelism_config = parallelism_config
|
||||
self.device_mesh = None
|
||||
mixed_precision = (
|
||||
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
|
||||
if mixed_precision is None
|
||||
@ -941,8 +950,13 @@ class AcceleratorState:
|
||||
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
|
||||
"before using any functionality from the `accelerate` library."
|
||||
)
|
||||
# deepspeed handles mixed_precision using deepspeed_config
|
||||
self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
|
||||
# deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
|
||||
# if we're using fp8.
|
||||
if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
|
||||
self._mixed_precision = "no"
|
||||
else:
|
||||
self._mixed_precision = mixed_precision
|
||||
|
||||
if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
|
||||
if mixed_precision == "bf16":
|
||||
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
|
||||
@ -953,7 +967,7 @@ class AcceleratorState:
|
||||
os.environ["XLA_USE_BF16"] = str(1)
|
||||
os.environ["XLA_DOWNCAST_BF16"] = str(0)
|
||||
self.downcast_bfloat = False
|
||||
elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
|
||||
elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true" and not cpu:
|
||||
self.distributed_type = DistributedType.DEEPSPEED
|
||||
if not isinstance(deepspeed_plugin, dict):
|
||||
deepspeed_plugin.set_mixed_precision(mixed_precision)
|
||||
@ -974,19 +988,35 @@ class AcceleratorState:
|
||||
DistributedType.MULTI_XPU,
|
||||
DistributedType.MULTI_HPU,
|
||||
]:
|
||||
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
|
||||
# TODO: Siro - remove when axolotl fixes their side
|
||||
if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true":
|
||||
if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
|
||||
raise ValueError(
|
||||
"`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism, as we also shard the model across the device mesh to save more memory"
|
||||
)
|
||||
if (
|
||||
self.parallelism_config is not None
|
||||
and self.parallelism_config.cp_enabled
|
||||
and fsdp_plugin.fsdp_version == 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
|
||||
)
|
||||
if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
|
||||
self.parallelism_config is not None and self.parallelism_config.cp_enabled
|
||||
):
|
||||
self.distributed_type = DistributedType.FSDP
|
||||
if self._mixed_precision != "no":
|
||||
if self._mixed_precision != "no" and fsdp_plugin is not None:
|
||||
fsdp_plugin.set_mixed_precision(self._mixed_precision)
|
||||
self.fsdp_plugin = fsdp_plugin
|
||||
if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" and self.distributed_type not in [
|
||||
if os.environ.get(
|
||||
"ACCELERATE_USE_MEGATRON_LM", "false"
|
||||
).lower() == "true" and self.distributed_type not in [
|
||||
DistributedType.MULTI_XPU,
|
||||
]:
|
||||
self.distributed_type = DistributedType.MEGATRON_LM
|
||||
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
|
||||
self.megatron_lm_plugin = megatron_lm_plugin
|
||||
if self.torch_tp_plugin is not None:
|
||||
self.distributed_type = DistributedType.TP
|
||||
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
|
||||
if is_ipex_available():
|
||||
# check if user disables it explicitly
|
||||
@ -1032,7 +1062,7 @@ class AcceleratorState:
|
||||
|
||||
@property
|
||||
def mixed_precision(self):
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
|
||||
config = self.deepspeed_plugin.deepspeed_config
|
||||
if config.get("fp16", {}).get("enabled", False):
|
||||
mixed_precision = "fp16"
|
||||
@ -1055,7 +1085,7 @@ class AcceleratorState:
|
||||
"""
|
||||
Destroys the process group. If one is not specified, the default process group is destroyed.
|
||||
|
||||
If `self.fork_lauched` is `True` and `group` is `None`, nothing happens.
|
||||
If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
|
||||
"""
|
||||
PartialState().destroy_process_group(group)
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ from .testing import (
|
||||
require_torchvision,
|
||||
require_tpu,
|
||||
require_transformer_engine,
|
||||
require_transformer_engine_mxfp8,
|
||||
require_xpu,
|
||||
run_first,
|
||||
skip,
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
def main():
|
||||
accelerator = Accelerator()
|
||||
B, S, D = 2, 3, 4
|
||||
rank_data = torch.ones((B, S, D), device="cuda") * (accelerator.process_index + 1)
|
||||
all_rank_data = [torch.empty_like(rank_data) for _ in range(accelerator.num_processes)]
|
||||
torch.distributed.all_gather(all_rank_data, rank_data)
|
||||
|
||||
dataloader = DataLoader(all_rank_data, batch_size=B, shuffle=False)
|
||||
dataloader = accelerator.prepare(dataloader)
|
||||
for batch in dataloader:
|
||||
all_rank_batch = [torch.empty_like(batch) for _ in range(accelerator.num_processes)]
|
||||
torch.distributed.all_gather(all_rank_batch, batch)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
for rank_idx in range(accelerator.num_processes):
|
||||
torch.testing.assert_close(
|
||||
all_rank_batch[0],
|
||||
all_rank_batch[rank_idx],
|
||||
msg=f"Rank {rank_idx} batch {all_rank_batch[rank_idx]} differs from rank 0 batch {all_rank_batch[0]}",
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -34,8 +34,7 @@ from accelerate.state import AcceleratorState
|
||||
from accelerate.utils.deepspeed import get_active_deepspeed_plugin
|
||||
|
||||
|
||||
MAX_GPU_BATCH_SIZE = 16
|
||||
EVAL_BATCH_SIZE = 32
|
||||
EVAL_BATCH_SIZE = 16
|
||||
|
||||
|
||||
class NoiseModel(torch.nn.Module):
|
||||
@ -318,11 +317,11 @@ def main():
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
default=2,
|
||||
default=3,
|
||||
help="Number of train epochs.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
|
||||
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 8}
|
||||
single_model_training(config, args)
|
||||
AcceleratorState._reset_state(True)
|
||||
multiple_model_training(config, args)
|
||||
|
||||
@ -69,7 +69,7 @@ class TorchTracemalloc:
|
||||
self.begin = torch.npu.memory_allocated()
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
|
||||
torch.xpu.reset_peak_memory_stats() # reset the peak gauge to zero
|
||||
self.begin = torch.xpu.memory_allocated()
|
||||
elif is_hpu_available():
|
||||
# torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
|
||||
|
||||
@ -25,7 +25,8 @@ from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
|
||||
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
|
||||
|
||||
|
||||
@ -83,7 +84,7 @@ def training_function(config, args):
|
||||
accelerator_kwargs = {}
|
||||
# need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail
|
||||
if args.tp_size is not None:
|
||||
accelerator_kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=args.tp_size)
|
||||
accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=args.tp_size)
|
||||
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
@ -79,10 +79,6 @@ def mock_training(accelerator, model):
|
||||
|
||||
def check_weights(operation, state_1, state_2):
|
||||
for weight_1, weight_2 in zip(state_1.values(), state_2.values()):
|
||||
if str(weight_1.device) != torch_device:
|
||||
weight_1 = weight_1.to(torch_device)
|
||||
if str(weight_2.device) != torch_device:
|
||||
weight_2 = weight_2.to(torch_device)
|
||||
if operation == "same":
|
||||
assert torch.allclose(weight_1, weight_2)
|
||||
else:
|
||||
@ -91,7 +87,7 @@ def check_weights(operation, state_1, state_2):
|
||||
|
||||
def check_safetensors_weights(path, model):
|
||||
safe_state_dict = load_file(path / "model.safetensors")
|
||||
safe_loaded_model = TinyModel()
|
||||
safe_loaded_model = TinyModel().to(torch_device)
|
||||
check_weights("diff", model.state_dict(), safe_loaded_model.state_dict())
|
||||
safe_loaded_model.load_state_dict(safe_state_dict)
|
||||
check_weights("same", model.state_dict(), safe_loaded_model.state_dict())
|
||||
@ -99,7 +95,7 @@ def check_safetensors_weights(path, model):
|
||||
|
||||
def check_pytorch_weights(path, model):
|
||||
nonsafe_state_dict = torch.load(path / "pytorch_model.bin", weights_only=True)
|
||||
nonsafe_loaded_model = TinyModel()
|
||||
nonsafe_loaded_model = TinyModel().to(torch_device)
|
||||
check_weights("diff", model.state_dict(), nonsafe_loaded_model.state_dict())
|
||||
nonsafe_loaded_model.load_state_dict(nonsafe_state_dict)
|
||||
check_weights("same", model.state_dict(), nonsafe_loaded_model.state_dict())
|
||||
|
||||
@ -50,7 +50,7 @@ def test_gather_object(state):
|
||||
assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
|
||||
|
||||
|
||||
def test_gather_non_contigous(state):
|
||||
def test_gather_non_contiguous(state):
|
||||
# Skip this test because the 'is_contiguous' function of XLA tensor always returns True.
|
||||
if state.distributed_type == DistributedType.XLA:
|
||||
return
|
||||
@ -160,8 +160,8 @@ def main():
|
||||
test_gather(state)
|
||||
state.print("testing gather_object")
|
||||
test_gather_object(state)
|
||||
state.print("testing gather non-contigous")
|
||||
test_gather_non_contigous(state)
|
||||
state.print("testing gather non-contiguous")
|
||||
test_gather_non_contiguous(state)
|
||||
state.print("testing broadcast")
|
||||
test_broadcast(state)
|
||||
state.print("testing pad_across_processes")
|
||||
|
||||
@ -35,10 +35,12 @@ from accelerate.utils import (
|
||||
gather,
|
||||
gather_object,
|
||||
is_bf16_available,
|
||||
is_cuda_available,
|
||||
is_datasets_available,
|
||||
is_fp16_available,
|
||||
is_hpu_available,
|
||||
is_ipex_available,
|
||||
is_mps_available,
|
||||
is_pytest_available,
|
||||
is_xpu_available,
|
||||
set_seed,
|
||||
@ -534,7 +536,7 @@ def training_check(use_seedable_sampler=False):
|
||||
accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")
|
||||
|
||||
# FP32 wrapper check
|
||||
if torch.cuda.is_available():
|
||||
if is_cuda_available() or is_mps_available():
|
||||
# Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
|
||||
print("Keep fp32 wrapper check.")
|
||||
AcceleratorState._reset_state()
|
||||
@ -625,7 +627,7 @@ def training_check(use_seedable_sampler=False):
|
||||
msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
|
||||
)
|
||||
|
||||
# IPEX support is only for CPU
|
||||
# IPEX CPU tests
|
||||
if is_ipex_available():
|
||||
print("ipex BF16 training check.")
|
||||
AcceleratorState._reset_state()
|
||||
|
||||
@ -61,6 +61,7 @@ from ..utils import (
|
||||
is_pytest_available,
|
||||
is_schedulefree_available,
|
||||
is_sdaa_available,
|
||||
is_swanlab_available,
|
||||
is_tensorboard_available,
|
||||
is_timm_available,
|
||||
is_torch_version,
|
||||
@ -68,7 +69,9 @@ from ..utils import (
|
||||
is_torchao_available,
|
||||
is_torchdata_stateful_dataloader_available,
|
||||
is_torchvision_available,
|
||||
is_trackio_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformer_engine_mxfp8_available,
|
||||
is_transformers_available,
|
||||
is_triton_available,
|
||||
is_wandb_available,
|
||||
@ -249,6 +252,10 @@ def require_fp8(test_case):
|
||||
return unittest.skipUnless(fp8_is_available, "test requires FP8 support")(test_case)
|
||||
|
||||
|
||||
def require_fsdp2(test_case):
|
||||
return unittest.skipUnless(is_torch_version(">=", "2.5.0"), "test requires FSDP2 (torch >= 2.5.0)")(test_case)
|
||||
|
||||
|
||||
def require_mlu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires MLU. These tests are skipped when there are no MLU available.
|
||||
@ -454,6 +461,13 @@ def require_wandb(test_case):
|
||||
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
|
||||
|
||||
|
||||
def require_trackio(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires trackio installed. These tests are skipped when trackio isn't installed
|
||||
"""
|
||||
return unittest.skipUnless(is_trackio_available(), "test requires trackio")(test_case)
|
||||
|
||||
|
||||
def require_comet_ml(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed
|
||||
@ -482,6 +496,13 @@ def require_dvclive(test_case):
|
||||
return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case)
|
||||
|
||||
|
||||
def require_swanlab(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed
|
||||
"""
|
||||
return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
|
||||
|
||||
|
||||
def require_pandas(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed
|
||||
@ -520,6 +541,16 @@ def require_transformer_engine(test_case):
|
||||
return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)
|
||||
|
||||
|
||||
def require_transformer_engine_mxfp8(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires transformers engine MXFP8 block scaling available. These tests are skipped
|
||||
when transformers engine MXFP8 block scaling isn't available
|
||||
"""
|
||||
return unittest.skipUnless(
|
||||
is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torchao(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed
|
||||
@ -536,7 +567,8 @@ def require_matplotlib(test_case):
|
||||
|
||||
|
||||
_atleast_one_tracker_available = (
|
||||
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
|
||||
any([is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()])
|
||||
and not is_comet_ml_available()
|
||||
)
|
||||
|
||||
|
||||
@ -566,7 +598,7 @@ def require_torchdata_stateful_dataloader(test_case):
|
||||
def run_first(test_case):
|
||||
"""
|
||||
Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator are
|
||||
garanteed to run first.
|
||||
guaranteed to run first.
|
||||
|
||||
This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
|
||||
single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
|
||||
@ -585,7 +617,7 @@ def run_first(test_case):
|
||||
class TempDirTestCase(unittest.TestCase):
|
||||
"""
|
||||
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
|
||||
data at the start of a test, and then destroyes it at the end of the TestCase.
|
||||
data at the start of a test, and then destroys it at the end of the TestCase.
|
||||
|
||||
Useful for when a class or API requires a single constant folder throughout its use, such as Weights and Biases
|
||||
|
||||
|
||||
@ -34,7 +34,9 @@ from .utils import (
|
||||
is_comet_ml_available,
|
||||
is_dvclive_available,
|
||||
is_mlflow_available,
|
||||
is_swanlab_available,
|
||||
is_tensorboard_available,
|
||||
is_trackio_available,
|
||||
is_wandb_available,
|
||||
listify,
|
||||
)
|
||||
@ -63,6 +65,12 @@ if is_clearml_available():
|
||||
if is_dvclive_available():
|
||||
_available_trackers.append(LoggerType.DVCLIVE)
|
||||
|
||||
if is_swanlab_available():
|
||||
_available_trackers.append(LoggerType.SWANLAB)
|
||||
|
||||
if is_trackio_available():
|
||||
_available_trackers.append(LoggerType.TRACKIO)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -103,7 +111,7 @@ class GeneralTracker:
|
||||
(`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal
|
||||
tracking mechanism used by a tracker class (such as the `run` for wandb)
|
||||
|
||||
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevent logging, init, and
|
||||
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and
|
||||
other functions should occur on the main process or across all processes (by default will use `True`)
|
||||
"""
|
||||
|
||||
@ -133,7 +141,7 @@ class GeneralTracker:
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Lazy initialization of the tracker inside Accelerator to avoid initalizing PartialState before
|
||||
Lazy initialization of the tracker inside Accelerator to avoid initializing PartialState before
|
||||
InitProcessGroupKwargs.
|
||||
"""
|
||||
pass
|
||||
@ -332,7 +340,16 @@ class WandBTracker(GeneralTracker):
|
||||
"""
|
||||
import wandb
|
||||
|
||||
wandb.config.update(values, allow_val_change=True)
|
||||
if os.environ.get("WANDB_MODE") == "offline":
|
||||
# In offline mode, restart wandb with config included
|
||||
if hasattr(self, "run") and self.run:
|
||||
self.run.finish()
|
||||
|
||||
init_kwargs = self.init_kwargs.copy()
|
||||
init_kwargs["config"] = values
|
||||
self.run = wandb.init(project=self.run_name, **init_kwargs)
|
||||
else:
|
||||
wandb.config.update(values, allow_val_change=True)
|
||||
logger.debug("Stored initial configuration hyperparameters to WandB")
|
||||
|
||||
@on_main_process
|
||||
@ -411,6 +428,83 @@ class WandBTracker(GeneralTracker):
|
||||
logger.debug("WandB run closed")
|
||||
|
||||
|
||||
class TrackioTracker(GeneralTracker):
|
||||
"""
|
||||
A `Tracker` class that supports `trackio`. Should be initialized at the start of your script.
|
||||
|
||||
Args:
|
||||
run_name (`str`):
|
||||
The name of the experiment run. Will be used as the `project` name when instantiating trackio.
|
||||
**kwargs (additional keyword arguments, *optional*):
|
||||
Additional key word arguments passed along to the `trackio.init` method. Refer to this
|
||||
[init](https://github.com/gradio-app/trackio/blob/814809552310468b13f84f33764f1369b4e5136c/trackio/__init__.py#L22)
|
||||
to see all supported key word arguments.
|
||||
"""
|
||||
|
||||
name = "trackio"
|
||||
requires_logging_directory = False
|
||||
main_process_only = False
|
||||
|
||||
def __init__(self, run_name: str, **kwargs):
|
||||
super().__init__()
|
||||
self.run_name = run_name
|
||||
self.init_kwargs = kwargs
|
||||
|
||||
@on_main_process
|
||||
def start(self):
|
||||
import trackio
|
||||
|
||||
self.run = trackio.init(project=self.run_name, **self.init_kwargs)
|
||||
logger.debug(f"Initialized trackio project {self.run_name}")
|
||||
logger.debug(
|
||||
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
|
||||
)
|
||||
|
||||
@property
|
||||
def tracker(self):
|
||||
return self.run
|
||||
|
||||
@on_main_process
|
||||
def store_init_configuration(self, values: dict):
|
||||
"""
|
||||
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
|
||||
|
||||
Args:
|
||||
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
|
||||
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
|
||||
`str`, `float`, `int`, or `None`.
|
||||
"""
|
||||
import trackio
|
||||
|
||||
trackio.config.update(values, allow_val_change=True)
|
||||
logger.debug("Stored initial configuration hyperparameters to trackio")
|
||||
|
||||
@on_main_process
|
||||
def log(self, values: dict, step: Optional[int] = None, **kwargs):
|
||||
"""
|
||||
Logs `values` to the current run.
|
||||
|
||||
Args:
|
||||
values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
|
||||
Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
|
||||
`str` to `float`/`int`.
|
||||
step (`int`, *optional*):
|
||||
The run step. If included, the log will be affiliated with this step.
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the `trackio.log` method.
|
||||
"""
|
||||
self.run.log(values, **kwargs)
|
||||
logger.debug("Successfully logged to trackio")
|
||||
|
||||
@on_main_process
|
||||
def finish(self):
|
||||
"""
|
||||
Closes `trackio` run
|
||||
"""
|
||||
self.run.finish()
|
||||
logger.debug("trackio run closed")
|
||||
|
||||
|
||||
class CometMLTracker(GeneralTracker):
|
||||
"""
|
||||
A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
|
||||
@ -1061,6 +1155,106 @@ class DVCLiveTracker(GeneralTracker):
|
||||
self.live.end()
|
||||
|
||||
|
||||
class SwanLabTracker(GeneralTracker):
|
||||
"""
|
||||
A `Tracker` class that supports `swanlab`. Should be initialized at the start of your script.
|
||||
|
||||
Args:
|
||||
run_name (`str`):
|
||||
The name of the experiment run.
|
||||
**kwargs (additional keyword arguments, *optional*):
|
||||
Additional key word arguments passed along to the `swanlab.init` method.
|
||||
"""
|
||||
|
||||
name = "swanlab"
|
||||
requires_logging_directory = False
|
||||
main_process_only = False
|
||||
|
||||
def __init__(self, run_name: str, **kwargs):
|
||||
super().__init__()
|
||||
self.run_name = run_name
|
||||
self.init_kwargs = kwargs
|
||||
|
||||
@on_main_process
|
||||
def start(self):
|
||||
import swanlab
|
||||
|
||||
self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
|
||||
swanlab.config["FRAMEWORK"] = "🤗Accelerate" # add accelerate logo in config
|
||||
logger.debug(f"Initialized SwanLab project {self.run_name}")
|
||||
logger.debug(
|
||||
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
|
||||
)
|
||||
|
||||
@property
|
||||
def tracker(self):
|
||||
return self.run
|
||||
|
||||
@on_main_process
|
||||
def store_init_configuration(self, values: dict):
|
||||
"""
|
||||
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
|
||||
|
||||
Args:
|
||||
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
|
||||
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
|
||||
`str`, `float`, `int`, or `None`.
|
||||
"""
|
||||
import swanlab
|
||||
|
||||
swanlab.config.update(values, allow_val_change=True)
|
||||
logger.debug("Stored initial configuration hyperparameters to SwanLab")
|
||||
|
||||
@on_main_process
|
||||
def log(self, values: dict, step: Optional[int] = None, **kwargs):
|
||||
"""
|
||||
Logs `values` to the current run.
|
||||
|
||||
Args:
|
||||
data : Dict[str, DataType]
|
||||
Data must be a dict. The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/". The value must be a
|
||||
`float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
|
||||
step : int, optional
|
||||
The step number of the current data, if not provided, it will be automatically incremented.
|
||||
If step is duplicated, the data will be ignored.
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the `swanlab.log` method. Likes:
|
||||
print_to_console : bool, optional
|
||||
Whether to print the data to the console, the default is False.
|
||||
"""
|
||||
self.run.log(values, step=step, **kwargs)
|
||||
logger.debug("Successfully logged to SwanLab")
|
||||
|
||||
@on_main_process
|
||||
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
|
||||
"""
|
||||
Logs `images` to the current run.
|
||||
|
||||
Args:
|
||||
values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
|
||||
Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
|
||||
step (`int`, *optional*):
|
||||
The run step. If included, the log will be affiliated with this step.
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the `swanlab.log` method. Likes:
|
||||
print_to_console : bool, optional
|
||||
Whether to print the data to the console, the default is False.
|
||||
"""
|
||||
import swanlab
|
||||
|
||||
for k, v in values.items():
|
||||
self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)
|
||||
logger.debug("Successfully logged images to SwanLab")
|
||||
|
||||
@on_main_process
|
||||
def finish(self):
|
||||
"""
|
||||
Closes `swanlab` writer
|
||||
"""
|
||||
self.run.finish()
|
||||
logger.debug("SwanLab run closed")
|
||||
|
||||
|
||||
LOGGER_TYPE_TO_CLASS = {
|
||||
"aim": AimTracker,
|
||||
"comet_ml": CometMLTracker,
|
||||
@ -1069,6 +1263,8 @@ LOGGER_TYPE_TO_CLASS = {
|
||||
"wandb": WandBTracker,
|
||||
"clearml": ClearMLTracker,
|
||||
"dvclive": DVCLiveTracker,
|
||||
"swanlab": SwanLabTracker,
|
||||
"trackio": TrackioTracker,
|
||||
}
|
||||
|
||||
|
||||
@ -1090,9 +1286,12 @@ def filter_trackers(
|
||||
- `"all"`
|
||||
- `"tensorboard"`
|
||||
- `"wandb"`
|
||||
- `"trackio"`
|
||||
- `"aim"`
|
||||
- `"comet_ml"`
|
||||
- `"mlflow"`
|
||||
- `"dvclive"`
|
||||
- `"swanlab"`
|
||||
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
|
||||
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
|
||||
logging_dir (`str`, `os.PathLike`, *optional*):
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..parallelism_config import ParallelismConfig
|
||||
from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
|
||||
from .constants import (
|
||||
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
|
||||
@ -60,7 +61,9 @@ from .dataclasses import (
|
||||
SageMakerDistributedType,
|
||||
TensorInformation,
|
||||
TERecipeKwargs,
|
||||
TorchContextParallelConfig,
|
||||
TorchDynamoPlugin,
|
||||
TorchTensorParallelConfig,
|
||||
TorchTensorParallelPlugin,
|
||||
add_model_config_to_megatron_parser,
|
||||
)
|
||||
@ -121,6 +124,7 @@ from .imports import (
|
||||
is_sagemaker_available,
|
||||
is_schedulefree_available,
|
||||
is_sdaa_available,
|
||||
is_swanlab_available,
|
||||
is_tensorboard_available,
|
||||
is_timm_available,
|
||||
is_torch_xla_available,
|
||||
@ -128,7 +132,9 @@ from .imports import (
|
||||
is_torchdata_available,
|
||||
is_torchdata_stateful_dataloader_available,
|
||||
is_torchvision_available,
|
||||
is_trackio_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformer_engine_mxfp8_available,
|
||||
is_transformers_available,
|
||||
is_triton_available,
|
||||
is_wandb_available,
|
||||
@ -281,6 +287,7 @@ from .other import (
|
||||
is_port_in_use,
|
||||
load,
|
||||
merge_dicts,
|
||||
model_has_dtensor,
|
||||
recursive_getattr,
|
||||
save,
|
||||
wait_for_everyone,
|
||||
|
||||
@ -314,7 +314,7 @@ def _replace_with_bnb_layers(
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
|
||||
"""
|
||||
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
|
||||
import bitsandbytes as bnb
|
||||
|
||||
@ -44,7 +44,6 @@ FSDP_PYTORCH_VERSION = (
|
||||
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
|
||||
)
|
||||
FSDP2_PYTORCH_VERSION = "2.6.0"
|
||||
CONTEXT_PARALLEL_PYTORCH_VERSION = "2.7.0"
|
||||
FSDP_MODEL_NAME = "pytorch_model_fsdp"
|
||||
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
|
||||
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
|
||||
@ -52,7 +51,9 @@ ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
|
||||
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
|
||||
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
|
||||
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
|
||||
|
||||
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
|
||||
BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
|
||||
|
||||
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
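For reference, a self-contained sketch of how this operator table backs version gates such as `is_torch_version(">=", FSDP2_PYTORCH_VERSION)`; the helper name below is illustrative.

```python
import operator as op

from packaging import version

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

def compare_version_str(current: str, operation: str, reference: str) -> bool:
    # e.g. compare_version_str(torch.__version__, ">=", CONTEXT_PARALLEL_PYTORCH_VERSION)
    return STR_OPERATION_TO_FUNC[operation](version.parse(current), version.parse(reference))

assert compare_version_str("2.7.0", ">=", "2.6.0")      # new enough for FSDP2
assert not compare_version_str("2.5.1", ">=", "2.7.0")  # too old for context parallel
```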
|
||||
|
||||
|
||||
@ -32,8 +32,9 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_a
|
||||
import torch
|
||||
|
||||
from .constants import (
|
||||
BETA_CP_AVAILABLE_PYTORCH_VERSION,
|
||||
BETA_TP_AVAILABLE_PYTORCH_VERSION,
|
||||
CONTEXT_PARALLEL_PYTORCH_VERSION,
|
||||
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
|
||||
FSDP2_PYTORCH_VERSION,
|
||||
FSDP_AUTO_WRAP_POLICY,
|
||||
FSDP_BACKWARD_PREFETCH,
|
||||
@ -59,6 +60,7 @@ if TYPE_CHECKING:
|
||||
# Mock imports for type checking
|
||||
from torchao.float8 import Float8LinearConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -185,7 +187,9 @@ class DistributedDataParallelKwargs(KwargsHandler):
|
||||
|
||||
comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO
|
||||
comm_wrapper: Literal[
|
||||
DDPCommunicationHookType.NO, DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16
|
||||
DDPCommunicationHookType.NO,
|
||||
DDPCommunicationHookType.FP16,
|
||||
DDPCommunicationHookType.BF16,
|
||||
] = DDPCommunicationHookType.NO
|
||||
comm_state_option: dict = field(default_factory=dict)
|
||||
|
||||
@ -193,7 +197,10 @@ class DistributedDataParallelKwargs(KwargsHandler):
|
||||
return {k: v for k, v in super().to_dict().items() if k not in ignore_keys}
|
||||
|
||||
def register_comm_hook(self, model):
|
||||
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks, powerSGD_hook
|
||||
from torch.distributed.algorithms.ddp_comm_hooks import (
|
||||
default_hooks,
|
||||
powerSGD_hook,
|
||||
)
|
||||
|
||||
hook_map: dict[DDPCommunicationHookType, Callable] = {
|
||||
DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook,
|
||||
@ -216,7 +223,11 @@ class DistributedDataParallelKwargs(KwargsHandler):
|
||||
if hook:
|
||||
state = (
|
||||
powerSGD_hook.PowerSGDState(None, **self.comm_state_option)
|
||||
if self.comm_hook in (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BATCHED_POWER_SGD)
|
||||
if self.comm_hook
|
||||
in (
|
||||
DDPCommunicationHookType.POWER_SGD,
|
||||
DDPCommunicationHookType.BATCHED_POWER_SGD,
|
||||
)
|
||||
else None
|
||||
)
|
||||
model.register_comm_hook(
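A hedged usage sketch for the handler above: under a multi-GPU launch, the hook is registered when `prepare()` wraps the model in DDP. `DistributedDataParallelKwargs` and `DDPCommunicationHookType` come from `accelerate.utils`; the PowerSGD variant mentioned in the comment assumes `comm_state_option` takes `PowerSGDState` keyword arguments, as the code above suggests.

```python
import torch
from accelerate import Accelerator
from accelerate.utils import DDPCommunicationHookType, DistributedDataParallelKwargs

# Compress gradient all-reduce to fp16; swap in POWER_SGD plus
# comm_state_option={"matrix_approximation_rank": 2} for stronger compression.
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = accelerator.prepare(torch.nn.Linear(8, 8))  # hook is attached to the DDP wrapper when one is created
```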
|
||||
@ -290,7 +301,7 @@ class InitProcessGroupKwargs(KwargsHandler):
|
||||
# Literals
|
||||
Backend = Literal["MSAMP", "TE"]
|
||||
OptLevel = Literal["O1", "O2"]
|
||||
FP8Format = Literal["E4M3", "HYBRID"]
|
||||
FP8Format = Literal["HYBRID", "E4M3", "E5M2"]
|
||||
AmaxComputeAlgorithm = Literal["max", "most_recent"]
|
||||
|
||||
|
||||
@ -343,8 +354,8 @@ class TERecipeKwargs(KwargsHandler):
|
||||
interval (`int`, *optional*, default to 1):
|
||||
The interval to use for how often the scaling factor is recomputed.
|
||||
fp8_format (`str`, *optional*, default to "HYBRID"):
|
||||
The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training,
|
||||
`E4M3` for evaluation)
|
||||
The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` for
|
||||
training, `E4M3` or `E5M2` for evaluation)
|
||||
amax_history_len (`int`, *optional*, default to 1024):
|
||||
The length of the history to use for the scaling factor computation
|
||||
amax_compute_algo (`str`, *optional*, default to "most_recent"):
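A hedged sketch of wiring this recipe into training, assuming TransformerEngine is installed and `TERecipeKwargs` is exported from `accelerate.utils` as the import block above suggests. HYBRID uses E4M3 for forward tensors and E5M2 for gradients, which is why it is the usual training choice.

```python
from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs

fp8_recipe = TERecipeKwargs(
    fp8_format="HYBRID",      # E4M3 forward / E5M2 backward
    amax_history_len=1024,
    amax_compute_algo="max",
)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_recipe])
```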
|
||||
@ -360,6 +371,7 @@ class TERecipeKwargs(KwargsHandler):
|
||||
amax_history_len: int = None
|
||||
amax_compute_algo: AmaxComputeAlgorithm = None
|
||||
override_linear_precision: tuple[bool, bool, bool] = None
|
||||
use_mxfp8_block_scaling: bool = None
|
||||
|
||||
def __post_init__(self):
|
||||
env_prefix = "ACCELERATE_FP8_"
|
||||
@ -388,6 +400,8 @@ class TERecipeKwargs(KwargsHandler):
|
||||
dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
|
||||
wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
|
||||
self.override_linear_precision = (fprop, dgrad, wgrad)
|
||||
if self.use_mxfp8_block_scaling is None:
|
||||
self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -583,7 +597,6 @@ class DistributedType(str, enum.Enum):
|
||||
MULTI_XPU = "MULTI_XPU"
|
||||
DEEPSPEED = "DEEPSPEED"
|
||||
FSDP = "FSDP"
|
||||
TP = "TP"
|
||||
XLA = "XLA"
|
||||
MEGATRON_LM = "MEGATRON_LM"
|
||||
MULTI_HPU = "MULTI_HPU"
|
||||
@ -617,8 +630,10 @@ class FP8BackendType(str, enum.Enum):
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
|
||||
NO = "NO"
|
||||
TE = "TE"
|
||||
MSAMP = "MSAMP"
|
||||
AO = "AO"
|
||||
|
||||
|
||||
class ComputeEnvironment(str, enum.Enum):
|
||||
@ -668,7 +683,7 @@ class DynamoBackend(str, BaseEnum):
|
||||
more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
|
||||
- **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read
|
||||
more](https://github.com/intel/intel-extension-for-pytorch).
|
||||
- **TVM** -- Uses Apach TVM for inference optimizations. [Read more](https://tvm.apache.org/)
|
||||
- **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
|
||||
- **HPU_BACKEND** -- Uses HPU backend for inference optimizations.
|
||||
|
||||
"""
|
||||
@ -700,18 +715,24 @@ class LoggerType(BaseEnum):
|
||||
- **ALL** -- all available trackers in the environment that are supported
|
||||
- **TENSORBOARD** -- TensorBoard as an experiment tracker
|
||||
- **WANDB** -- wandb as an experiment tracker
|
||||
- **TRACKIO** -- trackio as an experiment tracker
|
||||
- **COMETML** -- comet_ml as an experiment tracker
|
||||
- **MLFLOW** -- mlflow as an experiment tracker
|
||||
- **CLEARML** -- clearml as an experiment tracker
|
||||
- **DVCLIVE** -- dvclive as an experiment tracker
|
||||
- **SWANLAB** -- swanlab as an experiment tracker
|
||||
"""
|
||||
|
||||
ALL = "all"
|
||||
AIM = "aim"
|
||||
TENSORBOARD = "tensorboard"
|
||||
WANDB = "wandb"
|
||||
TRACKIO = "trackio"
|
||||
COMETML = "comet_ml"
|
||||
MLFLOW = "mlflow"
|
||||
CLEARML = "clearml"
|
||||
DVCLIVE = "dvclive"
|
||||
SWANLAB = "swanlab"
|
||||
|
||||
|
||||
class PrecisionType(str, BaseEnum):
|
||||
@ -783,9 +804,9 @@ class DataLoaderConfiguration:
|
||||
all workers.
|
||||
use_seedable_sampler (`bool`, defaults to `False`):
|
||||
Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures
|
||||
training results are fully reproducable using a different sampling technique. While seed-to-seed results
|
||||
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
|
||||
also be ran with [`~utils.set_seed`] for the best results.
|
||||
training results are fully reproducible using a different sampling technique. While seed-to-seed results
|
||||
may differ, on average the differences are negligible when using multiple different seeds to compare.
|
||||
Should also be ran with [`~utils.set_seed`] for the best results.
|
||||
data_seed (`int`, defaults to `None`):
|
||||
The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
|
||||
will use the current default seed from torch.
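A short sketch of the reproducibility combination described above; `DataLoaderConfiguration` and `set_seed` come from `accelerate.utils`, and the dataset here is a stand-in.

```python
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, set_seed

set_seed(42)
dl_config = DataLoaderConfiguration(use_seedable_sampler=True, data_seed=42)
accelerator = Accelerator(dataloader_config=dl_config)

loader = DataLoader(list(range(1000)), batch_size=8, shuffle=True)
loader = accelerator.prepare(loader)  # shuffling now comes from SeedableRandomSampler
```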
|
||||
@ -828,8 +849,8 @@ class DataLoaderConfiguration:
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])."
|
||||
"Ensures training results are fully reproducable using a different sampling technique. "
|
||||
"While seed-to-seed results may differ, on average the differences are neglible when using"
|
||||
"Ensures training results are fully reproducible using a different sampling technique. "
|
||||
"While seed-to-seed results may differ, on average the differences are negligible when using"
|
||||
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
|
||||
},
|
||||
)
|
||||
@ -935,7 +956,7 @@ class GradientAccumulationPlugin(KwargsHandler):
|
||||
sync_with_dataloader (`bool`, *optional*, defaults to `True`):
|
||||
Whether to synchronize setting the gradients when at the end of the dataloader.
|
||||
sync_each_batch (`bool`, *optional*):
|
||||
Whether to synchronize setting the gradients at each data batch. Seting to `True` may reduce memory
|
||||
Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
|
||||
requirements when using gradient accumulation with distributed training, at expense of speed.
|
||||
|
||||
Example:
|
||||
@ -948,7 +969,10 @@ class GradientAccumulationPlugin(KwargsHandler):
|
||||
```
|
||||
"""
|
||||
|
||||
num_steps: int = field(default=None, metadata={"help": "The number of steps to accumulate gradients for."})
|
||||
num_steps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of steps to accumulate gradients for."},
|
||||
)
|
||||
adjust_scheduler: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
@ -999,12 +1023,22 @@ class TorchDynamoPlugin(KwargsHandler):
|
||||
metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"},
|
||||
)
|
||||
mode: str = field(
|
||||
default=None, metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"}
|
||||
default=None,
|
||||
metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"},
|
||||
)
|
||||
fullgraph: bool = field(
|
||||
default=None,
|
||||
metadata={"help": "Whether it is ok to break model into several subgraphs"},
|
||||
)
|
||||
fullgraph: bool = field(default=None, metadata={"help": "Whether it is ok to break model into several subgraphs"})
|
||||
dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"})
|
||||
options: Any = field(default=None, metadata={"help": "A dictionary of options to pass to the backend."})
|
||||
disable: bool = field(default=False, metadata={"help": "Turn torch.compile() into a no-op for testing"})
|
||||
options: Any = field(
|
||||
default=None,
|
||||
metadata={"help": "A dictionary of options to pass to the backend."},
|
||||
)
|
||||
disable: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Turn torch.compile() into a no-op for testing"},
|
||||
)
|
||||
|
||||
use_regional_compilation: bool = field(
|
||||
default=None,
|
||||
@ -1183,7 +1217,7 @@ class DeepSpeedPlugin:
|
||||
|
||||
if self.zero3_save_16bit_model is None:
|
||||
self.zero3_save_16bit_model = (
|
||||
os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true"
|
||||
os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false").lower() == "true"
|
||||
)
|
||||
if self.enable_msamp is None:
|
||||
self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP"
|
||||
@ -1236,13 +1270,13 @@ class DeepSpeedPlugin:
|
||||
"stage": self.zero_stage,
|
||||
"offload_optimizer": {
|
||||
"device": self.offload_optimizer_device,
|
||||
"nvme_path": self.offload_optimizer_nvme_path
|
||||
if self.offload_optimizer_device == "nvme"
|
||||
else None,
|
||||
"nvme_path": (
|
||||
self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None
|
||||
),
|
||||
},
|
||||
"offload_param": {
|
||||
"device": self.offload_param_device,
|
||||
"nvme_path": self.offload_param_nvme_path if self.offload_param_device == "nvme" else None,
|
||||
"nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None),
|
||||
},
|
||||
"stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model,
|
||||
},
|
||||
@ -1255,7 +1289,13 @@ class DeepSpeedPlugin:
|
||||
self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout
|
||||
if self.zero3_init_flag is None:
|
||||
self.zero3_init_flag = (
|
||||
str_to_bool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1
|
||||
str_to_bool(
|
||||
os.environ.get(
|
||||
"ACCELERATE_DEEPSPEED_ZERO3_INIT",
|
||||
str(self.hf_ds_config.is_zero3()),
|
||||
)
|
||||
)
|
||||
== 1
|
||||
)
|
||||
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
|
||||
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
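The same two flags can be supplied through the environment exactly as parsed above, or set on the plugin directly; a hedged sketch that assumes DeepSpeed is installed:

```python
import os

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Case-insensitive after this change: "True"/"true" both enable 16-bit ZeRO-3 saving.
os.environ["ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = "True"
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"

ds_plugin = DeepSpeedPlugin(zero_stage=3, gradient_accumulation_steps=4)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```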
|
||||
@ -1272,7 +1312,10 @@ class DeepSpeedPlugin:
|
||||
)
|
||||
if self.msamp_opt_level not in ["O1", "O2"]:
|
||||
raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].")
|
||||
self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
|
||||
self.deepspeed_config["msamp"] = {
|
||||
"enabled": True,
|
||||
"opt_level": self.msamp_opt_level,
|
||||
}
|
||||
|
||||
def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
|
||||
mismatches = [] if mismatches is None else mismatches
|
||||
@ -1317,7 +1360,11 @@ class DeepSpeedPlugin:
|
||||
for key, value in config.items():
|
||||
if isinstance(value, dict):
|
||||
self.deepspeed_config_process(
|
||||
prefix=prefix + key + ".", mismatches=mismatches, config=value, must_match=must_match, **kwargs
|
||||
prefix=prefix + key + ".",
|
||||
mismatches=mismatches,
|
||||
config=value,
|
||||
must_match=must_match,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs)
|
||||
@ -1344,7 +1391,10 @@ class DeepSpeedPlugin:
|
||||
|
||||
if mixed_precision == "fp8" and self.enable_msamp:
|
||||
if "msamp" not in ds_config:
|
||||
ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
|
||||
ds_config["msamp"] = {
|
||||
"enabled": True,
|
||||
"opt_level": self.msamp_opt_level,
|
||||
}
|
||||
|
||||
if mixed_precision != "no":
|
||||
diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16"
|
||||
@ -1376,9 +1426,15 @@ class DeepSpeedPlugin:
|
||||
del ds_config["train_batch_size"]
|
||||
|
||||
if compare_versions("transformers", "<", "4.46"):
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, unset_hf_deepspeed_config
|
||||
from transformers.deepspeed import (
|
||||
HfDeepSpeedConfig,
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
else:
|
||||
from transformers.integrations import HfDeepSpeedConfig, unset_hf_deepspeed_config
|
||||
from transformers.integrations import (
|
||||
HfDeepSpeedConfig,
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
|
||||
unset_hf_deepspeed_config()
|
||||
self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa
|
||||
@ -1497,10 +1553,12 @@ class FullyShardedDataParallelPlugin:
|
||||
backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
|
||||
Backward prefetch strategy to use. Should be either a `str` or an instance of
|
||||
`torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
|
||||
mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
|
||||
mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
|
||||
A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
|
||||
should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
|
||||
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
|
||||
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
|
||||
should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
|
||||
`reduce_dtype`, and `buffer_dtype`.
|
||||
auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
|
||||
A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
|
||||
of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
|
||||
@ -1509,8 +1567,9 @@ class FullyShardedDataParallelPlugin:
|
||||
Whether to offload parameters to CPU. Should be either a `bool` or an instance of
|
||||
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
|
||||
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
|
||||
ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`):
|
||||
A list of modules to ignore when wrapping with FSDP.
|
||||
ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
|
||||
A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
|
||||
using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used.
|
||||
state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
|
||||
State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
|
||||
`sharded_state_dict`.
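A hedged sketch that exercises the new string-valued options documented above (FSDP2 path, recent PyTorch assumed); the regex and dtype strings are illustrative.

```python
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    mixed_precision_policy="bf16",              # expands to param/reduce/buffer dtypes
    ignored_modules=r"model\.vision_tower.*",   # matched with re.fullmatch on module names
    state_dict_type="SHARDED_STATE_DICT",
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```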
|
||||
@ -1547,11 +1606,6 @@ class FullyShardedDataParallelPlugin:
|
||||
min_num_params (`Optional[int]`, defaults to `None`):
|
||||
The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy`
|
||||
is `size_based_wrap`.
|
||||
cp_size (`int`, defaults to `1`):
|
||||
The size of the context parallel group. Only applicable when `fsdp_version` is set to 2, else error will be
|
||||
raised. Defaults to 1 (CP not applied).
|
||||
cp_comm_strategy (`str`, defaults to `allgather`):
|
||||
The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2.
|
||||
"""
|
||||
|
||||
fsdp_version: int = field(
|
||||
@ -1581,7 +1635,12 @@ class FullyShardedDataParallelPlugin:
|
||||
},
|
||||
)
|
||||
mixed_precision_policy: Optional[
|
||||
Union[dict, "torch.distributed.fsdp.MixedPrecision", "torch.distributed.fsdp.MixedPrecisionPolicy"]
|
||||
Union[
|
||||
dict,
|
||||
str,
|
||||
"torch.distributed.fsdp.MixedPrecision",
|
||||
"torch.distributed.fsdp.MixedPrecisionPolicy",
|
||||
]
|
||||
] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@ -1599,13 +1658,17 @@ class FullyShardedDataParallelPlugin:
|
||||
},
|
||||
)
|
||||
)
|
||||
cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload", "torch.distributed.fsdp.CPUOffloadPolicy"] = field(
|
||||
cpu_offload: Union[
|
||||
bool,
|
||||
"torch.distributed.fsdp.CPUOffload",
|
||||
"torch.distributed.fsdp.CPUOffloadPolicy",
|
||||
] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
|
||||
},
|
||||
)
|
||||
ignored_modules: Optional[Iterable[torch.nn.Module]] = field(
|
||||
ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(
|
||||
default=None,
|
||||
metadata={"help": "A list of modules to ignore when wrapping with FSDP."},
|
||||
)
|
||||
@ -1626,7 +1689,10 @@ class FullyShardedDataParallelPlugin:
|
||||
metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."},
|
||||
)
|
||||
optim_state_dict_config: Optional[
|
||||
Union["torch.distributed.fsdp.FullOptimStateDictConfig", "torch.distributed.fsdp.ShardedOptimStateDictConfig"]
|
||||
Union[
|
||||
"torch.distributed.fsdp.FullOptimStateDictConfig",
|
||||
"torch.distributed.fsdp.ShardedOptimStateDictConfig",
|
||||
]
|
||||
] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@ -1697,24 +1763,9 @@ class FullyShardedDataParallelPlugin:
|
||||
"help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`."
|
||||
},
|
||||
)
|
||||
cp_size: int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The size of the context parallel group. Only applicable when `fsdp_version` is set to 2, else error will be raised. Defaults to 1 (CP not applied)"
|
||||
},
|
||||
)
|
||||
cp_comm_strategy: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2. Defaults to `allgather`."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from torch.distributed.fsdp import (
|
||||
BackwardPrefetch,
|
||||
ShardingStrategy,
|
||||
)
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
|
||||
_fsdp2_warnings = set()
|
||||
|
||||
@ -1748,7 +1799,8 @@ class FullyShardedDataParallelPlugin:
|
||||
# Fallback to `reshard_after_forward` in FSDP1 if `sharding_strategy` is not set
|
||||
if self.reshard_after_forward is None and self.sharding_strategy is None:
|
||||
reshard_after_forward = os.environ.get(
|
||||
env_prefix + "RESHARD_AFTER_FORWARD", "true" if self.fsdp_version == 2 else "FULL_SHARD"
|
||||
env_prefix + "RESHARD_AFTER_FORWARD",
|
||||
"true" if self.fsdp_version == 2 else "FULL_SHARD",
|
||||
)
|
||||
if self.fsdp_version == 2:
|
||||
self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
|
||||
@ -1805,7 +1857,10 @@ class FullyShardedDataParallelPlugin:
|
||||
raise ValueError(
|
||||
f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
|
||||
if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
|
||||
self.auto_wrap_policy = transformer_auto_wrap_policy
|
||||
@ -1849,6 +1904,9 @@ class FullyShardedDataParallelPlugin:
|
||||
str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
|
||||
)
|
||||
|
||||
if self.ignored_modules is None:
|
||||
self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None)
|
||||
|
||||
if self.cpu_ram_efficient_loading is None:
|
||||
self.cpu_ram_efficient_loading = (
|
||||
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
|
||||
@ -1871,29 +1929,11 @@ class FullyShardedDataParallelPlugin:
|
||||
)
|
||||
os.environ[env_var] = str(self.cpu_ram_efficient_loading)
|
||||
|
||||
if self.cp_size is None:
|
||||
self.cp_size = int(os.environ.get(env_prefix + "CP_SIZE", "1"))
|
||||
|
||||
if self.cp_size > 1 and self.fsdp_version != 2:
|
||||
raise ValueError(
|
||||
f"cp_size set to {self.cp_size}. This is not supported with FSDP1, please set to 1 or use `fsdp_version=2`"
|
||||
)
|
||||
|
||||
if self.cp_size > 1 and not is_torch_version(">=", CONTEXT_PARALLEL_PYTORCH_VERSION):
|
||||
raise ValueError(
|
||||
f"cp_size set to {self.cp_size}. This is not supported with PyTorch < {CONTEXT_PARALLEL_PYTORCH_VERSION}, please set to None or upgrade your PyTorch version."
|
||||
)
|
||||
|
||||
if self.cp_comm_strategy is None:
|
||||
self.cp_comm_strategy = os.environ.get(env_prefix + "CP_COMM_STRATEGY", "allgather")
|
||||
|
||||
# No need to further check versions, as that check is done in the `context_parallel_size` check
|
||||
if self.cp_comm_strategy not in ["allgather", "alltoall"]:
|
||||
raise ValueError(
|
||||
f"cp_comm_strategy set to {self.cp_comm_strategy}. Must be one of ['allgather', 'alltoall']."
|
||||
)
|
||||
|
||||
if isinstance(self.mixed_precision_policy, dict):
|
||||
if isinstance(self.mixed_precision_policy, str):
|
||||
# override is True since self.mixed_precision_policy is not None
|
||||
# has to be overwritten with the correct mixed precision object
|
||||
self.set_mixed_precision(self.mixed_precision_policy, override=True)
|
||||
elif isinstance(self.mixed_precision_policy, dict):
|
||||
self.set_mixed_precision(self.mixed_precision_policy)
|
||||
if self.mixed_precision_policy is not None:
|
||||
self.validate_mixed_precision_policy()
|
||||
@ -1918,7 +1958,12 @@ class FullyShardedDataParallelPlugin:
|
||||
# Create a function that will be used to initialize the parameters of the model
|
||||
# when using `sync_module_states`
|
||||
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
|
||||
|
||||
if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None:
|
||||
_fsdp2_warnings.add(
|
||||
"FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0"
|
||||
"Setting ignored_modules to None."
|
||||
)
|
||||
self.ignored_modules = None
|
||||
# Single warning for all deprecation warnings due to FSDP2 conversion
|
||||
if _fsdp2_warnings:
|
||||
logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings))
|
||||
@ -1942,7 +1987,8 @@ class FullyShardedDataParallelPlugin:
|
||||
|
||||
if self.state_dict_type is None:
|
||||
self.state_dict_type = os.environ.get(
|
||||
"FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT"
|
||||
"FSDP_STATE_DICT_TYPE",
|
||||
"FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
|
||||
)
|
||||
if isinstance(self.state_dict_type, str):
|
||||
if self.state_dict_type.isdigit():
|
||||
@ -1969,10 +2015,13 @@ class FullyShardedDataParallelPlugin:
|
||||
|
||||
def set_auto_wrap_policy(self, model):
|
||||
"""
|
||||
Given `model`, creates an `auto_wrap_policy` baesd on the passed in policy and if we can use the
|
||||
Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the
|
||||
`transformer_cls_to_wrap`
|
||||
"""
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
|
||||
# First base off of `_no_split_modules`
|
||||
no_split_modules = getattr(model, "_no_split_modules", None)
|
||||
@ -2109,33 +2158,57 @@ class TorchTensorParallelPlugin:
|
||||
metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
|
||||
)
|
||||
|
||||
# torch_device_mesh is fo type "torch.distributed.DeviceMesh"
|
||||
# torch_device_mesh is of type "torch.distributed.DeviceMesh"
|
||||
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchContextParallelConfig:
|
||||
"""
|
||||
This class holds the configuration for context parallelism in PyTorch.
|
||||
"""
|
||||
|
||||
cp_comm_strategy: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.tp_size, int):
|
||||
raise ValueError(f"`tp_size` set to {self.tp_size}, please set to an `int`.")
|
||||
|
||||
if self.tp_size <= 1:
|
||||
raise ValueError("`tp_size` must be greater than 1.")
|
||||
|
||||
if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
|
||||
if not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION):
|
||||
raise ValueError(
|
||||
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
|
||||
f"Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. "
|
||||
"Please upgrade your PyTorch version."
|
||||
)
|
||||
if self.cp_comm_strategy is None:
|
||||
self.cp_comm_strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather")
|
||||
if self.cp_comm_strategy not in ["allgather", "alltoall"]:
|
||||
raise ValueError(
|
||||
f"Invalid cp_comm_strategy: {self.cp_comm_strategy}. Must be one of 'allgather' or 'alltoall'."
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# support for other devices has to be investigated
|
||||
if is_hpu_available(init_hccl=True):
|
||||
device = "hpu"
|
||||
else:
|
||||
device = "cuda"
|
||||
|
||||
mesh_dim_name = "tp"
|
||||
@dataclass
|
||||
class TorchTensorParallelConfig:
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize your torch tensor parallelism.
|
||||
"""
|
||||
|
||||
# device mesh is not used for model sharding
|
||||
# it is only used for preparing data loader
|
||||
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
|
||||
enable_async_tp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
|
||||
raise ValueError(
|
||||
f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. "
|
||||
"Please upgrade your PyTorch version."
|
||||
)
|
||||
|
||||
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
|
||||
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
|
||||
|
||||
if self.enable_async_tp:
|
||||
warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.")
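For reference, a hedged sketch that just constructs the two config objects introduced above; the field names come from this diff, while how they plug into `ParallelismConfig`/`Accelerator` follows the surrounding parallelism-config plumbing and may differ by version.

```python
from accelerate.utils import TorchContextParallelConfig, TorchTensorParallelConfig

cp_config = TorchContextParallelConfig(cp_comm_strategy="alltoall")  # or "allgather" (default)
tp_config = TorchTensorParallelConfig(enable_async_tp=False)         # async TP currently warns and is ignored
```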
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -2190,7 +2263,7 @@ class MegatronLMPlugin:
|
||||
lr_warmup_fraction (`float`, defaults to `None`):
|
||||
Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.
|
||||
min_lr (`float`, defaults to `0`):
|
||||
Minumum value for learning rate. The scheduler clip values below this threshold.
|
||||
Minimum value for learning rate. The scheduler clip values below this threshold.
|
||||
consumed_samples (`List`, defaults to `None`):
|
||||
Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.
|
||||
no_wd_decay_cond (`Optional`, defaults to `None`):
|
||||
@ -2239,7 +2312,8 @@ class MegatronLMPlugin:
|
||||
pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."})
|
||||
num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."})
|
||||
gradient_clipping: float = field(
|
||||
default=None, metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"}
|
||||
default=None,
|
||||
metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"},
|
||||
)
|
||||
sequence_parallelism: bool = field(
|
||||
default=None,
|
||||
@ -2254,7 +2328,8 @@ class MegatronLMPlugin:
|
||||
metadata={"help": "enable distributed optimizer"},
|
||||
)
|
||||
pipeline_model_parallel_split_rank: int = field(
|
||||
default=None, metadata={"help": "Rank where encoder and decoder should be split."}
|
||||
default=None,
|
||||
metadata={"help": "Rank where encoder and decoder should be split."},
|
||||
)
|
||||
num_layers_per_virtual_pipeline_stage: int = field(
|
||||
default=None, metadata={"help": "Number of layers per virtual pipeline stage."}
|
||||
@ -2315,7 +2390,7 @@ class MegatronLMPlugin:
|
||||
)
|
||||
min_lr: float = field(
|
||||
default=0,
|
||||
metadata={"help": "Minumum value for learning rate. The scheduler clip values below this threshold."},
|
||||
metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."},
|
||||
)
|
||||
consumed_samples: list[int] = field(
|
||||
default=None,
|
||||
@ -2351,10 +2426,12 @@ class MegatronLMPlugin:
|
||||
metadata={"help": "Whether to set all logging options."},
|
||||
)
|
||||
eval_iters: int = field(
|
||||
default=100, metadata={"help": "Number of iterations to run for evaluation validation/test for."}
|
||||
default=100,
|
||||
metadata={"help": "Number of iterations to run for evaluation validation/test for."},
|
||||
)
|
||||
eval_interval: int = field(
|
||||
default=1000, metadata={"help": "Interval between running evaluation on validation set."}
|
||||
default=1000,
|
||||
metadata={"help": "Interval between running evaluation on validation set."},
|
||||
)
|
||||
return_logits: bool = field(
|
||||
default=False,
|
||||
@ -2721,7 +2798,8 @@ class BnbQuantizationConfig:
|
||||
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
|
||||
|
||||
llm_int8_threshold: float = field(
|
||||
default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
|
||||
default=6.0,
|
||||
metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"},
|
||||
)
|
||||
|
||||
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
|
||||
|
||||
@ -261,22 +261,36 @@ class DeepSpeedEngineWrapper:
|
||||
def __init__(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def backward(self, loss, **kwargs):
|
||||
def backward(self, loss, sync_gradients=True, **kwargs):
|
||||
# Set gradient accumulation boundary based on Accelerate's sync_gradients state
|
||||
# This tells DeepSpeed whether this is the final micro-batch before gradient sync
|
||||
self.engine.set_gradient_accumulation_boundary(is_boundary=sync_gradients)
|
||||
|
||||
# runs backpropagation and handles mixed precision
|
||||
self.engine.backward(loss, **kwargs)
|
||||
|
||||
# Deepspeed's `engine.step` performs the following operations:
|
||||
# - gradient accumulation check
|
||||
# - gradient clipping
|
||||
# - optimizer step
|
||||
# - zero grad
|
||||
# - checking overflow
|
||||
# - lr_scheduler step (only if engine.lr_scheduler is not None)
|
||||
self.engine.step()
|
||||
# Only perform step and related operations at gradient accumulation boundaries
|
||||
if sync_gradients:
|
||||
# Deepspeed's `engine.step` performs the following operations:
|
||||
# - gradient accumulation check
|
||||
# - gradient clipping
|
||||
# - optimizer step
|
||||
# - zero grad
|
||||
# - checking overflow
|
||||
# - lr_scheduler step (only if engine.lr_scheduler is not None)
|
||||
self.engine.step()
|
||||
# and this plugin overrides the above calls with no-ops when Accelerate runs under
|
||||
# Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple
|
||||
# training loop that works transparently under many training regimes.
|
||||
|
||||
def get_global_grad_norm(self):
|
||||
"""Get the global gradient norm from DeepSpeed engine."""
|
||||
grad_norm = self.engine.get_global_grad_norm()
|
||||
# Convert to scalar if it's a tensor
|
||||
if hasattr(grad_norm, "item"):
|
||||
return grad_norm.item()
|
||||
return grad_norm
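A training-loop sketch of how the boundary flag lines up with gradient accumulation; the model and data are placeholders and the loop assumes a DeepSpeed-enabled launch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.randn(96, 16), torch.randn(96, 1)), batch_size=8)

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # under DeepSpeed, forwards sync_gradients to engine.backward()
        optimizer.step()            # the wrapper only really steps at the accumulation boundary
        optimizer.zero_grad()
```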
|
||||
|
||||
|
||||
class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
|
||||
"""
|
||||
|
||||
@ -149,7 +149,7 @@ def check_cuda_p2p_ib_support():
|
||||
Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
|
||||
the 3090.
|
||||
|
||||
Noteably uses `nvidia-smi` instead of torch to not initialize CUDA.
|
||||
Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
|
||||
"""
|
||||
try:
|
||||
device_names, device_count = get_gpu_info()
|
||||
|
||||
@ -14,12 +14,14 @@
|
||||
import copy
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -179,10 +181,9 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, a
|
||||
else nullcontext()
|
||||
)
|
||||
sd_options = _prepare_sd_options(fsdp_plugin)
|
||||
|
||||
with ctx:
|
||||
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
|
||||
if type(model) is not FSDP and accelerator.process_index != 0:
|
||||
if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2:
|
||||
if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1:
|
||||
raise ValueError(
|
||||
"Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
|
||||
@ -192,7 +193,12 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, a
|
||||
weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
|
||||
input_model_file = os.path.join(input_dir, weights_name)
|
||||
logger.info(f"Loading model from {input_model_file}")
|
||||
state_dict = torch.load(input_model_file, weights_only=True)
|
||||
# we want an empty state dict for FSDP2 as we use `broadcast_from_rank0`
|
||||
load_model = not accelerator.is_fsdp2 or accelerator.is_main_process
|
||||
if load_model:
|
||||
state_dict = torch.load(input_model_file, weights_only=True)
|
||||
else:
|
||||
state_dict = {}
|
||||
logger.info(f"Model loaded from {input_model_file}")
|
||||
elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
|
||||
weights_name = (
|
||||
@ -299,13 +305,15 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
|
||||
optim_state = torch.load(input_optimizer_file, weights_only=True)
|
||||
logger.info(f"Optimizer state loaded from {input_optimizer_file}")
|
||||
else:
|
||||
from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
|
||||
|
||||
ckpt_dir = (
|
||||
os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
|
||||
if f"{OPTIMIZER_NAME}" not in input_dir
|
||||
else input_dir
|
||||
)
|
||||
logger.info(f"Loading Optimizer from {ckpt_dir}")
|
||||
optim_state = {"optimizer": optimizer.state_dict()}
|
||||
optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer)}
|
||||
dist_cp.load(
|
||||
optim_state,
|
||||
checkpoint_id=ckpt_dir,
|
||||
@ -498,10 +506,10 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
|
||||
|
||||
if accelerator.is_main_process:
|
||||
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
|
||||
full_param = full_param.detach().cuda()
|
||||
mesh = sharded_param.device_mesh
|
||||
dist.broadcast(full_param, src=0, group=mesh.get_group())
|
||||
sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
|
||||
device_mesh = sharded_param.device_mesh
|
||||
full_param = full_param.detach().to(device_mesh.device_type)
|
||||
dist.broadcast(full_param, src=0, group=dist.group.WORLD)
|
||||
sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
|
||||
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
||||
model,
|
||||
param_name,
|
||||
@ -512,10 +520,10 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
|
||||
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
|
||||
else:
|
||||
for param_name, sharded_param in meta_sharded_sd.items():
|
||||
full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
|
||||
mesh = sharded_param.device_mesh
|
||||
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
|
||||
sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
|
||||
device_mesh = sharded_param.device_mesh
|
||||
full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
|
||||
dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
|
||||
sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
|
||||
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
||||
model,
|
||||
param_name,
|
||||
@ -544,6 +552,11 @@ def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping:
|
||||
indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically
|
||||
correct and weights wouldn't get updated.
|
||||
"""
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
accessor_mapping = {}
|
||||
|
||||
accessor_mapping[DTensor] = "_local_tensor"
|
||||
try:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
|
||||
@ -611,16 +624,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
fsdp2_plugin.set_auto_wrap_policy(model)
|
||||
|
||||
original_sd = model.state_dict()
|
||||
|
||||
mesh = getattr(accelerator.state, "torch_device_mesh", None)
|
||||
mesh = getattr(accelerator, "torch_device_mesh", None)
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||
"mesh": mesh["fsdp_cp"] if mesh else None,
|
||||
"mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
|
||||
}
|
||||
if fsdp2_plugin.ignored_modules is not None:
|
||||
fsdp2_kwargs["ignored_params"] = get_parameters_from_modules(
|
||||
fsdp2_plugin.ignored_modules, model, accelerator.device
|
||||
)
|
||||
|
||||
model_has_params4bit = False
|
||||
for name, param in model.named_parameters():
|
||||
@ -634,7 +650,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
|
||||
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
|
||||
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU
|
||||
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
|
||||
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
|
||||
|
||||
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
|
||||
@ -782,5 +798,32 @@ def fsdp2_canonicalize_names(named_params: dict) -> dict:
|
||||
k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k: v for k, v in named_params.items()
|
||||
}
|
||||
named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
|
||||
named_params = {k.replace("_cp_wrapped_model.", ""): v for k, v in named_params.items()}
|
||||
return named_params
|
||||
|
||||
|
||||
def get_parameters_from_modules(
|
||||
modules: Union[Iterable[torch.nn.Module], str], model, device
|
||||
) -> set[torch.nn.Parameter]:
|
||||
"""Converts modules to parameters where modules can be a string or list of torch.nn.Module
|
||||
|
||||
Args:
|
||||
modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
|
||||
|
||||
Returns:
|
||||
`List[torch.nn.Parameter]`: List of parameters
|
||||
"""
|
||||
if modules is None:
|
||||
return None
|
||||
parameters = []
|
||||
# code taken from accelerate while preparing kwargs for FSDP
|
||||
if isinstance(modules, str):
|
||||
reg = re.compile(modules)
|
||||
mapped_modules = []
|
||||
for name, module in model.named_modules():
|
||||
if reg.fullmatch(name):
|
||||
module.to(device)
|
||||
mapped_modules.append(module)
|
||||
modules = mapped_modules
|
||||
for module in modules:
|
||||
parameters.extend(list(module.parameters()))
|
||||
return set(parameters)
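A small usage sketch for the helper above; the `accelerate.utils.fsdp_utils` import path is assumed from the file being patched.

```python
import torch

from accelerate.utils.fsdp_utils import get_parameters_from_modules

model = torch.nn.ModuleDict({"vision": torch.nn.Linear(4, 4), "text": torch.nn.Linear(4, 4)})

# Regex form: names are matched with re.fullmatch against model.named_modules().
ignored = get_parameters_from_modules("vision", model, "cpu")
assert ignored == set(model["vision"].parameters())

# Module-list form works the same way.
ignored = get_parameters_from_modules([model["text"]], model, "cpu")
```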
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from functools import lru_cache, wraps
|
||||
|
||||
@ -113,6 +114,14 @@ def is_transformer_engine_available():
|
||||
return _is_package_available("transformer_engine", "transformer-engine")
|
||||
|
||||
|
||||
def is_transformer_engine_mxfp8_available():
|
||||
if _is_package_available("transformer_engine", "transformer-engine"):
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
return te.fp8.check_mxfp8_support()[0]
|
||||
return False
|
||||
|
||||
|
||||
def is_lomo_available():
|
||||
return _is_package_available("lomo_optim")
|
||||
|
||||
@ -173,7 +182,7 @@ def is_bf16_available(ignore_tpu=False):
|
||||
if is_xpu_available():
|
||||
return torch.xpu.is_bf16_supported()
|
||||
if is_mps_available():
|
||||
return False
|
||||
return torch.backends.mps.is_macos_or_newer(14, 0)
|
||||
return True
|
||||
|
||||
|
||||
@ -281,6 +290,14 @@ def is_comet_ml_available():
|
||||
return _is_package_available("comet_ml")
|
||||
|
||||
|
||||
def is_swanlab_available():
|
||||
return _is_package_available("swanlab")
|
||||
|
||||
|
||||
def is_trackio_available():
|
||||
return sys.version_info >= (3, 10) and _is_package_available("trackio")
|
||||
|
||||
|
||||
def is_boto3_available():
|
||||
return _is_package_available("boto3")
|
||||
|
||||
@ -397,7 +414,12 @@ def is_npu_available(check_device=False):
|
||||
if importlib.util.find_spec("torch_npu") is None:
|
||||
return False
|
||||
|
||||
import torch_npu # noqa: F401
|
||||
# NOTE: importing torch_npu may raise error in some envs
|
||||
# e.g. inside cpu-only container with torch_npu installed
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if check_device:
|
||||
try:
|
||||
|
||||
@ -89,9 +89,9 @@ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
|
||||
value = getattr(args, arg)
|
||||
if value is not None:
|
||||
if arg == "fp8_override_linear_precision":
|
||||
current_env[prefix + "FP8_OVERRIDE_FPROP"] = value[0]
|
||||
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = value[1]
|
||||
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = value[2]
|
||||
current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
|
||||
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
|
||||
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
|
||||
else:
|
||||
current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
|
||||
return current_env
|
||||
@ -328,8 +328,8 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
|
||||
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
|
||||
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
|
||||
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
|
||||
current_env["FSDP_CP_SIZE"] = str(args.fsdp_cp_size)
|
||||
current_env["FSDP_CP_COMM_STRATEGY"] = str(args.fsdp_cp_comm_strategy)
|
||||
if getattr(args, "fsdp_ignored_modules", None) is not None:
|
||||
current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)
|
||||
|
||||
if args.use_megatron_lm:
|
||||
prefix = "MEGATRON_LM_"
|
||||
@ -349,6 +349,20 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
|
||||
current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
|
||||
if args.enable_cpu_affinity:
|
||||
current_env["ACCELERATE_CPU_AFFINITY"] = "1"
|
||||
|
||||
if not args.use_parallelism_config:
|
||||
return current_env
|
||||
|
||||
prefix = "PARALLELISM_CONFIG_"
|
||||
if args.use_parallelism_config:
|
||||
current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
|
||||
current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size)
|
||||
current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size)
|
||||
current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size)
|
||||
current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size)
|
||||
if args.parallelism_config_cp_size > 1:
|
||||
current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy)
|
||||
|
||||
return current_env
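A hypothetical illustration of the mapping above: the launcher flags are inferred from the `args.parallelism_config_*` attribute names and may be spelled differently in your CLI, but the resulting environment is exactly what the code sets.

```python
# accelerate launch --use_parallelism_config \
#     --parallelism_config_dp_replicate_size 2 --parallelism_config_dp_shard_size 2 \
#     --parallelism_config_tp_size 2 --parallelism_config_cp_size 2 \
#     --parallelism_config_cp_comm_strategy alltoall train.py
expected_env = {
    "ACCELERATE_USE_PARALLELISM_CONFIG": "true",
    "PARALLELISM_CONFIG_DP_REPLICATE_SIZE": "2",
    "PARALLELISM_CONFIG_DP_SHARD_SIZE": "2",
    "PARALLELISM_CONFIG_TP_SIZE": "2",
    "PARALLELISM_CONFIG_CP_SIZE": "2",
    "PARALLELISM_CONFIG_CP_COMM_STRATEGY": "alltoall",  # only set when cp_size > 1
}
```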
|
||||
|
||||
|
||||
|
||||
@ -873,7 +873,7 @@ def finish_mpu_init():
|
||||
_set_random_seed(args.seed, args.data_parallel_random_init)
|
||||
|
||||
|
||||
# intialize megatron setup
|
||||
# initialize megatron setup
|
||||
def initialize(accelerator, extra_args_provider=None, args_defaults={}):
|
||||
accelerator.print("Initializing Megatron-LM")
|
||||
assert torch.cuda.is_available(), "Megatron requires CUDA."
|
||||
@ -1344,7 +1344,7 @@ class MegatronEngine(torch.nn.Module):
|
||||
padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
|
||||
prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)
|
||||
|
||||
# We need the sizes of these tensors for the boradcast
|
||||
# We need the sizes of these tensors for the broadcast
|
||||
sizes_list = [
|
||||
prompts_tokens_tensor.size(0), # Batch size
|
||||
prompts_tokens_tensor.size(1),
|
||||
@ -1353,7 +1353,7 @@ class MegatronEngine(torch.nn.Module):
|
||||
# First, broadcast the sizes.
|
||||
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)
|
||||
|
||||
# Now that we have the sizes, we can boradcast the tokens
|
||||
# Now that we have the sizes, we can broadcast the tokens
|
||||
# and length tensors.
|
||||
sizes = sizes_tensor.tolist()
|
||||
context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
|
||||
|
||||
@ -121,7 +121,7 @@ def find_executable_batch_size(
|
||||
):
|
||||
"""
|
||||
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
|
||||
CUDNN, the batch size is cut in half and passed to `function`
|
||||
CUDNN, the batch size is multiplied by 0.9 and passed to `function`
|
||||
|
||||
`function` must take in a `batch_size` parameter as its first argument.
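Usage is unchanged by the new retry factor; a short sketch in which the training body is a placeholder:

```python
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    print(f"trying batch_size={batch_size}")
    # build dataloaders/model with `batch_size` and run the loop here;
    # a CUDA OOM raised in this body triggers a retry at int(batch_size * 0.9)

train()  # the decorator injects batch_size, starting at 128
```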
|
||||
|
||||
@ -153,7 +153,7 @@ def find_executable_batch_size(
|
||||
|
||||
def reduce_batch_size_fn():
|
||||
nonlocal batch_size
|
||||
batch_size = batch_size // 2
|
||||
batch_size = int(batch_size * 0.9)
|
||||
return batch_size
|
||||
|
||||
def decorator(*args, **kwargs):
|
||||
|
||||
@ -169,7 +169,7 @@ def dtype_byte_size(dtype: torch.dtype):
|
||||
return 1 / 2
|
||||
elif dtype == CustomDtype.FP8:
|
||||
return 1
|
||||
elif is_torch_version(">=", "2.1.0") and dtype == torch.float8_e4m3fn:
|
||||
elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
return 1
|
||||
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
||||
if bit_search is None:
|
||||
@ -222,6 +222,8 @@ def set_module_tensor_to_device(
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
fp16_statistics: Optional[torch.HalfTensor] = None,
|
||||
tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
|
||||
non_blocking: bool = False,
|
||||
clear_cache: bool = True,
|
||||
):
|
||||
"""
|
||||
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
||||
@ -245,6 +247,10 @@ def set_module_tensor_to_device(
|
||||
A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
|
||||
execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
|
||||
device for all others, instead of duplicating memory.
|
||||
non_blocking (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the device transfer will be asynchronous with respect to the host, if possible.
|
||||
clear_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clear the device cache after setting the tensor on the device.
|
||||
"""
|
||||
# Recurse if needed
|
||||
if "." in tensor_name:
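A hedged sketch of the two new flags on a CUDA machine; everything other than the flag names is illustrative.

```python
import torch

from accelerate.utils import set_module_tensor_to_device

model = torch.nn.Linear(4, 4)
set_module_tensor_to_device(
    model, "weight", "cuda:0",
    value=torch.randn(4, 4),
    non_blocking=True,   # asynchronous copy when the source allows it
    clear_cache=False,   # skip clear_device_cache() after this tensor
)
torch.cuda.synchronize()  # make sure queued copies have landed before use
```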
|
||||
@ -295,9 +301,9 @@ def set_module_tensor_to_device(
|
||||
|
||||
if dtype is None:
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
|
||||
value = value.to(old_value.dtype)
|
||||
value = value.to(old_value.dtype, non_blocking=non_blocking)
|
||||
elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
value = value.to(dtype)
|
||||
value = value.to(dtype, non_blocking=non_blocking)
|
||||
|
||||
device_quantization = None
|
||||
with torch.no_grad():
|
||||
@ -305,8 +311,8 @@ def set_module_tensor_to_device(
|
||||
# # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
|
||||
if (
|
||||
param is not None
|
||||
and param.device.type != "cuda"
|
||||
and torch.device(device).type == "cuda"
|
||||
and param.device.type not in ("cuda", "xpu")
|
||||
and torch.device(device).type in ("cuda", "xpu")
|
||||
and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
|
||||
):
|
||||
device_quantization = device
|
||||
@ -326,15 +332,15 @@ def set_module_tensor_to_device(
|
||||
if "xpu" in str(device) and not is_xpu_available():
|
||||
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
|
||||
if value is None:
|
||||
new_value = old_value.to(device)
|
||||
new_value = old_value.to(device, non_blocking=non_blocking)
|
||||
if dtype is not None and device in ["meta", torch.device("meta")]:
|
||||
if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
new_value = new_value.to(dtype)
|
||||
new_value = new_value.to(dtype, non_blocking=non_blocking)
|
||||
|
||||
if not is_buffer:
|
||||
module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
new_value = value.to(device)
|
||||
new_value = value.to(device, non_blocking=non_blocking)
|
||||
else:
|
||||
new_value = torch.tensor(value, device=device)
|
||||
if device_quantization is not None:
|
||||
@ -347,24 +353,30 @@ def set_module_tensor_to_device(
|
||||
if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
|
||||
if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
|
||||
# downcast to fp16 if any - needed for 8bit serialization
|
||||
new_value = new_value.to(torch.float16)
|
||||
new_value = new_value.to(torch.float16, non_blocking=non_blocking)
|
||||
# quantize module that are going to stay on the cpu so that we offload quantized weights
|
||||
if device == "cpu" and param_cls.__name__ == "Int8Params":
|
||||
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
|
||||
new_value.CB = new_value.CB.to("cpu")
|
||||
new_value.SCB = new_value.SCB.to("cpu")
|
||||
else:
|
||||
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
|
||||
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
|
||||
device, non_blocking=non_blocking
|
||||
)
|
||||
elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
|
||||
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
|
||||
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(
|
||||
device, non_blocking=non_blocking
|
||||
)
|
||||
elif param_cls.__name__ in ["AffineQuantizedTensor"]:
|
||||
new_value = new_value.to(device)
|
||||
new_value = new_value.to(device, non_blocking=non_blocking)
|
||||
else:
|
||||
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
|
||||
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(
|
||||
device, non_blocking=non_blocking
|
||||
)
|
||||
|
||||
module._parameters[tensor_name] = new_value
|
||||
if fp16_statistics is not None:
|
||||
module._parameters[tensor_name].SCB = fp16_statistics.to(device)
|
||||
module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)
|
||||
del fp16_statistics
|
||||
# as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
|
||||
if (
|
||||
@ -390,8 +402,9 @@ def set_module_tensor_to_device(
|
||||
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
|
||||
if not getattr(module.weight, "quant_state", None) and device_index is not None:
|
||||
module.weight = module.weight.cuda(device_index)
|
||||
|
||||
# clean pre and post forward hook
|
||||
if device != "cpu":
|
||||
if clear_cache and device != "cpu":
|
||||
clear_device_cache()
|
||||
|
||||
# When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
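Taken together, the hunks above thread `non_blocking` through every `.to(...)` call and make the cache flush optional via `clear_cache`. A hedged usage sketch, assuming a CUDA device is available and that the caller synchronizes once at the end instead of per tensor:

```python
import torch
from accelerate.utils import set_module_tensor_to_device

model = torch.nn.Linear(4, 4)
new_weight = torch.randn(4, 4)

# Queue the copy asynchronously and skip the per-call cache flush; the caller
# is then responsible for synchronizing (and clearing the cache if needed).
set_module_tensor_to_device(
    model, "weight", "cuda:0", value=new_weight, non_blocking=True, clear_cache=False
)
torch.cuda.synchronize()
```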
@ -1594,6 +1607,14 @@ def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, tor
|
||||
model (`torch.nn.Module`): The model to check the device map against.
|
||||
device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.
|
||||
"""
|
||||
all_module_names = dict(model.named_modules())
|
||||
invalid_keys = [k for k in device_map if k != "" and k not in all_module_names]
|
||||
|
||||
if invalid_keys:
|
||||
warnings.warn(
|
||||
f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning
|
||||
)
|
||||
|
||||
all_model_tensors = [name for name, _ in model.state_dict().items()]
|
||||
for module_name in device_map.keys():
|
||||
if module_name == "":
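The added guard only warns, it does not raise, when a `device_map` key names a module that does not exist. A small sketch of how that surfaces, using a toy two-layer model:

```python
import warnings

import torch
from accelerate.utils.modeling import check_device_map

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # "2" does not name a submodule of a two-layer Sequential, so it is reported.
    check_device_map(model, {"0": "cpu", "1": "cpu", "2": 0})

print([str(w.message) for w in caught])
```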
@ -2076,7 +2097,6 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
|
||||
DistributedType.MULTI_HPU,
|
||||
DistributedType.FSDP,
|
||||
DistributedType.XLA,
|
||||
DistributedType.TP,
|
||||
]:
|
||||
return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)
|
||||
else:
|
||||
@ -2116,6 +2136,10 @@ def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
|
||||
return torch.amp.GradScaler("hpu", **kwargs)
|
||||
elif is_xpu_available():
|
||||
return torch.amp.GradScaler("xpu", **kwargs)
|
||||
elif is_mps_available():
|
||||
if not is_torch_version(">=", "2.8.0"):
|
||||
raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
|
||||
return torch.amp.GradScaler("mps", **kwargs)
|
||||
else:
|
||||
if is_torch_version(">=", "2.3"):
|
||||
return torch.amp.GradScaler("cuda", **kwargs)
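With the branch added above, `get_grad_scaler` also covers MPS, gated on PyTorch 2.8. A simplified sketch of that dispatch (the real helper additionally handles HPU, XPU and the FSDP-specific scalers, which this version ignores):

```python
import torch
from packaging.version import parse


def pick_grad_scaler(**kwargs):
    # MPS scalers only exist on recent PyTorch builds; otherwise fall back to CUDA.
    if torch.backends.mps.is_available():
        if parse(torch.__version__) < parse("2.8.0"):
            raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
        return torch.amp.GradScaler("mps", **kwargs)
    return torch.amp.GradScaler("cuda", **kwargs)
```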
@ -32,6 +32,7 @@ from .imports import (
|
||||
is_torch_distributed_available,
|
||||
is_torch_xla_available,
|
||||
)
|
||||
from .versions import is_torch_version
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@ -316,8 +317,8 @@ def _gpu_gather(tensor):
|
||||
state = PartialState()
|
||||
gather_op = torch.distributed.all_gather_into_tensor
|
||||
|
||||
# FIXME: the below 2 lines are added to work-aound a bug related to INT64 collectives in oneCCL. Remove them once pytorch-2.9 is released.
|
||||
if state.device.type == "xpu":
|
||||
# NOTE: need manually synchronize to workaourd a INT64 collectives bug in oneCCL before torch 2.9.0
|
||||
if state.device.type == "xpu" and is_torch_version("<=", "2.8"):
|
||||
torch.xpu.synchronize()
|
||||
|
||||
def _gpu_gather_one(tensor):
|
||||
@ -519,7 +520,7 @@ def gather_tensor_shape(tensor):
|
||||
|
||||
def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
|
||||
"""
|
||||
Copys a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
|
||||
Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
|
||||
each worker doesn't need to know its shape when used (and tensor can be `None`)
|
||||
|
||||
Args:
|
||||
@ -731,7 +732,7 @@ def reduce(tensor, reduction="mean", scale=1.0):
|
||||
reduction (`str`, *optional*, defaults to `"mean"`):
|
||||
A reduction method. Can be of "mean", "sum", or "none"
|
||||
scale (`float`, *optional*):
|
||||
A default scaling value to be applied after the reduce, only valied on XLA.
|
||||
A default scaling value to be applied after the reduce, only valid on XLA.
|
||||
|
||||
Returns:
|
||||
The same data structure as `data` with all the tensors reduced.
|
||||
@ -787,7 +788,7 @@ def convert_to_fp32(tensor):
|
||||
|
||||
class ConvertOutputsToFp32:
|
||||
"""
|
||||
Decorator to apply to a function outputing tensors (like a model forward pass) that ensures the outputs in FP16
|
||||
Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
|
||||
precision will be convert back to FP32.
|
||||
|
||||
Args:
|
||||
|
||||
@ -194,6 +194,26 @@ def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
|
||||
module.compile(**compile_kwargs)
|
||||
|
||||
|
||||
def model_has_dtensor(model: torch.nn.Module) -> bool:
|
||||
"""
|
||||
Check if the model has DTensor parameters.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
The model to check.
|
||||
|
||||
Returns:
|
||||
`bool`: Whether the model has DTensor parameters.
|
||||
"""
|
||||
if is_torch_version(">=", "2.5.0"):
|
||||
from torch.distributed.tensor import DTensor
|
||||
else:
|
||||
# from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
|
||||
from torch.distributed._tensor import DTensor
|
||||
|
||||
return any(isinstance(p, DTensor) for p in model.parameters())
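`model_has_dtensor` just probes parameter types, so it is cheap to call before deciding how to save or unwrap a model. A short sketch, assuming torch >= 2.5 so `DTensor` lives in `torch.distributed.tensor`:

```python
import torch
from accelerate.utils.other import model_has_dtensor

model = torch.nn.Linear(16, 16)
print(model_has_dtensor(model))  # False: plain nn.Parameter weights only

# After sharding (e.g. fully_shard / parallelize_module) the parameters become
# DTensors and the check flips to True; not executed here because it needs an
# initialized process group and device mesh.
```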
def extract_model_from_parallel(
|
||||
model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
|
||||
):
|
||||
|
||||
@ -16,7 +16,7 @@ from types import MethodType
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .imports import is_fp8_available, is_hpu_available
|
||||
from .imports import is_hpu_available, is_transformer_engine_available
|
||||
from .operations import GatheredParameters
|
||||
|
||||
|
||||
@ -27,11 +27,15 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
|
||||
"""
|
||||
Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
if not is_transformer_engine_available():
|
||||
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
|
||||
|
||||
if is_hpu_available():
|
||||
import intel_transformer_engine as te
|
||||
|
||||
if not hasattr(te, "LayerNorm"):
|
||||
# HPU does not have a LayerNorm implementation in TE
|
||||
te.LayerNorm = nn.LayerNorm
|
||||
else:
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
@ -56,9 +60,11 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
|
||||
# Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
|
||||
elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
|
||||
with GatheredParameters([module.weight, module.bias], modifier_rank=0):
|
||||
has_bias = module.bias is not None
|
||||
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
|
||||
te_module.weight.copy_(module.weight)
|
||||
te_module.bias.copy_(module.bias)
|
||||
if has_bias:
|
||||
te_module.bias.copy_(module.bias)
|
||||
|
||||
setattr(model, name, te_module)
|
||||
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
|
||||
@ -90,7 +96,7 @@ def has_transformer_engine_layers(model):
|
||||
"""
|
||||
Returns whether a given model has some `transformer_engine` layer or not.
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
if not is_transformer_engine_available():
|
||||
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
|
||||
|
||||
if is_hpu_available():
|
||||
@ -114,7 +120,7 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
|
||||
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
|
||||
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
if not is_transformer_engine_available():
|
||||
raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
|
||||
|
||||
if is_hpu_available():
|
||||
@ -137,19 +143,39 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
|
||||
"""
|
||||
Applies FP8 context manager to the model's forward method
|
||||
"""
|
||||
if not is_fp8_available():
|
||||
if not is_transformer_engine_available():
|
||||
raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
|
||||
|
||||
if is_hpu_available():
|
||||
import intel_transformer_engine.recipe as te_recipe
|
||||
|
||||
is_fp8_block_scaling_available = False
|
||||
message = "MXFP8 block scaling is not available on HPU."
|
||||
|
||||
else:
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
|
||||
|
||||
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
|
||||
if "fp8_format" in kwargs:
|
||||
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
|
||||
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
|
||||
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
|
||||
use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
|
||||
|
||||
if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
|
||||
raise ValueError(f"MXFP8 block scaling is not available: {message}")
|
||||
|
||||
if use_mxfp8_block_scaling:
|
||||
if "amax_compute_algo" in kwargs:
|
||||
raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
|
||||
if "amax_history_len" in kwargs:
|
||||
raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
|
||||
fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
|
||||
else:
|
||||
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
|
||||
|
||||
new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
|
||||
|
||||
if hasattr(model.forward, "__func__"):
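The recipe selection above now branches between delayed scaling and MXFP8 block scaling. A condensed sketch of that decision using only the `transformer_engine` names that appear in the hunk (the kwarg validation for the `amax_*` options is left out):

```python
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te


def build_fp8_recipe(use_mxfp8_block_scaling: bool = False, **kwargs):
    # Prefer MXFP8 block scaling when the hardware supports it, otherwise fall
    # back to the classic DelayedScaling recipe.
    if use_mxfp8_block_scaling:
        available, message = te.fp8.check_mxfp8_support()
        if not available:
            raise ValueError(f"MXFP8 block scaling is not available: {message}")
        return te_recipe.MXFP8BlockScaling(**kwargs)
    return te_recipe.DelayedScaling(**kwargs)
```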
tests/deepspeed/test_deepspeed_gradient_accumulation.py (new file, 240 lines)
@ -0,0 +1,240 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModel
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from accelerate.accelerator import Accelerator
|
||||
from accelerate.test_utils.testing import AccelerateTestCase, require_deepspeed
|
||||
from accelerate.test_utils.training import RegressionDataset
|
||||
from accelerate.utils import patch_environment
|
||||
from accelerate.utils.dataclasses import DeepSpeedPlugin
|
||||
|
||||
|
||||
set_seed(42)
|
||||
|
||||
GPT2_TINY = "hf-internal-testing/tiny-random-gpt2"
|
||||
ZERO2 = "zero2"
|
||||
ZERO3 = "zero3"
|
||||
FP16 = "fp16"
|
||||
|
||||
|
||||
@require_deepspeed
|
||||
class DeepSpeedGradientAccumulationTest(AccelerateTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
self._test_file_path = inspect.getfile(self.__class__)
|
||||
path = Path(self._test_file_path).resolve()
|
||||
self.test_file_dir_str = str(path.parents[0])
|
||||
|
||||
self.ds_config_file = dict(
|
||||
zero2=f"{self.test_file_dir_str}/ds_config_zero2.json",
|
||||
zero3=f"{self.test_file_dir_str}/ds_config_zero3.json",
|
||||
)
|
||||
|
||||
# Load config files
|
||||
with open(self.ds_config_file[ZERO2], encoding="utf-8") as f:
|
||||
config_zero2 = json.load(f)
|
||||
with open(self.ds_config_file[ZERO3], encoding="utf-8") as f:
|
||||
config_zero3 = json.load(f)
|
||||
config_zero3["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = False
|
||||
|
||||
self.ds_config_dict = dict(zero2=config_zero2, zero3=config_zero3)
|
||||
|
||||
self.dist_env = dict(
|
||||
ACCELERATE_USE_DEEPSPEED="true",
|
||||
MASTER_ADDR="localhost",
|
||||
MASTER_PORT="10999",
|
||||
RANK="0",
|
||||
LOCAL_RANK="0",
|
||||
WORLD_SIZE="1",
|
||||
)
|
||||
|
||||
def test_gradient_accumulation_boundary_integration(self):
|
||||
"""Test that gradient accumulation boundaries are automatically handled by DeepSpeed integration."""
|
||||
gradient_accumulation_steps = 4
|
||||
|
||||
deepspeed_plugin = DeepSpeedPlugin(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
gradient_clipping=1.0,
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="cpu",
|
||||
offload_param_device="cpu",
|
||||
zero3_save_16bit_model=False,
|
||||
zero3_init_flag=False,
|
||||
)
|
||||
|
||||
with patch_environment(**self.dist_env):
|
||||
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
|
||||
|
||||
# Setup simple training components
|
||||
train_set = RegressionDataset(length=80)
|
||||
train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
|
||||
model = AutoModel.from_pretrained(GPT2_TINY)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
|
||||
|
||||
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
|
||||
|
||||
model.train()
|
||||
|
||||
# Test gradient accumulation with accumulate context manager
|
||||
batch_data = next(iter(train_dataloader))
|
||||
# Create proper input format for GPT2 model (RegressionDataset returns {"x": scalar, "y": scalar})
|
||||
# We need to create dummy input_ids for the GPT2 model
|
||||
batch_size = batch_data["x"].shape[0] if isinstance(batch_data["x"], torch.Tensor) else 1
|
||||
|
||||
# Create dummy input_ids for GPT2 model and move to same device as model
|
||||
device = next(model.parameters()).device
|
||||
input_ids = torch.randint(0, 1000, (batch_size, 10), device=device) # batch_size x sequence_length
|
||||
inputs = {"input_ids": input_ids}
|
||||
|
||||
# Track sync_gradients values to verify correct gradient accumulation behavior
|
||||
sync_values = []
|
||||
|
||||
# Simulate gradient accumulation steps
|
||||
for micro_step in range(gradient_accumulation_steps):
|
||||
with accelerator.accumulate(model):
|
||||
sync_values.append(accelerator.sync_gradients)
|
||||
outputs = model(**inputs)
|
||||
# Use the last hidden state and create a simple loss
|
||||
prediction = outputs.last_hidden_state.mean()
|
||||
loss = prediction.sum() # Simple scalar loss
|
||||
|
||||
# This should automatically handle gradient accumulation boundaries
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Verify gradient accumulation pattern was correct
|
||||
# Should be False for first 3 steps, True for the last step
|
||||
expected_sync = [False, False, False, True]
|
||||
self.assertEqual(sync_values, expected_sync)
|
||||
|
||||
# Reset step counter for accelerator
|
||||
accelerator.step = 0
|
||||
|
||||
def test_clip_grad_norm_returns_deepspeed_grad_norm(self):
|
||||
"""Test that clip_grad_norm_ works with DeepSpeed and returns gradient norm when available."""
|
||||
deepspeed_plugin = DeepSpeedPlugin(
|
||||
gradient_accumulation_steps=1,
|
||||
gradient_clipping=1.0,
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="cpu",
|
||||
offload_param_device="cpu",
|
||||
zero3_save_16bit_model=False,
|
||||
zero3_init_flag=False,
|
||||
)
|
||||
|
||||
with patch_environment(**self.dist_env):
|
||||
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
|
||||
|
||||
# Setup simple model
|
||||
model = AutoModel.from_pretrained(GPT2_TINY)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
|
||||
|
||||
# Create a simple dataloader for prepare to work
|
||||
train_set = RegressionDataset(length=16)
|
||||
train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
|
||||
|
||||
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
|
||||
|
||||
# Perform a forward and backward pass to generate gradients
|
||||
batch_data = next(iter(train_dataloader))
|
||||
batch_size = len(batch_data["x"]) if isinstance(batch_data["x"], torch.Tensor) else 1
|
||||
|
||||
# Create dummy input_ids for GPT2 model and move to same device as model
|
||||
device = next(model.parameters()).device
|
||||
input_ids = torch.randint(0, 1000, (batch_size, 10), device=device)
|
||||
inputs = {"input_ids": input_ids}
|
||||
|
||||
# Forward pass
|
||||
outputs = model(**inputs)
|
||||
prediction = outputs.last_hidden_state.mean()
|
||||
loss = prediction.sum()
|
||||
|
||||
# Backward pass to generate gradients
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Test that gradient clipping works and returns a value
|
||||
grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
# After backward pass, we should get a valid gradient norm (either from DeepSpeed or fallback)
|
||||
self.assertIsInstance(grad_norm, (int, float, type(None)))
|
||||
if grad_norm is not None:
|
||||
self.assertGreaterEqual(grad_norm, 0.0)
|
||||
|
||||
def test_accelerator_backward_passes_sync_gradients(self):
|
||||
"""Test that Accelerator.backward() passes sync_gradients to DeepSpeed wrapper."""
|
||||
deepspeed_plugin = DeepSpeedPlugin(
|
||||
gradient_accumulation_steps=2,
|
||||
gradient_clipping=1.0,
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="cpu",
|
||||
offload_param_device="cpu",
|
||||
zero3_save_16bit_model=False,
|
||||
zero3_init_flag=False,
|
||||
)
|
||||
|
||||
with patch_environment(**self.dist_env):
|
||||
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
|
||||
|
||||
# Setup simple model and data
|
||||
model = AutoModel.from_pretrained(GPT2_TINY)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
|
||||
train_set = RegressionDataset(length=16)
|
||||
train_dataloader = DataLoader(train_set, batch_size=8, shuffle=True)
|
||||
|
||||
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
|
||||
|
||||
# Track sync_gradients values during backward calls
|
||||
sync_values = []
|
||||
|
||||
# Test two gradient accumulation steps
|
||||
batch_data = next(iter(train_dataloader))
|
||||
# Create proper input format for GPT2 model
|
||||
batch_size = len(batch_data["x"]) if isinstance(batch_data["x"], torch.Tensor) else 1
|
||||
|
||||
# Create dummy input_ids for GPT2 model and move to same device as model
|
||||
device = next(model.parameters()).device
|
||||
input_ids = torch.randint(0, 1000, (batch_size, 10), device=device)
|
||||
inputs = {"input_ids": input_ids}
|
||||
|
||||
# First step - should have sync_gradients=False
|
||||
with accelerator.accumulate(model):
|
||||
sync_values.append(accelerator.sync_gradients)
|
||||
outputs = model(**inputs)
|
||||
prediction = outputs.last_hidden_state.mean()
|
||||
loss = prediction # Simple loss
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Second step - should have sync_gradients=True
|
||||
with accelerator.accumulate(model):
|
||||
sync_values.append(accelerator.sync_gradients)
|
||||
outputs = model(**inputs)
|
||||
prediction = outputs.last_hidden_state.mean()
|
||||
loss = prediction # Simple loss
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Verify sync_gradients pattern was correct
|
||||
self.assertEqual(len(sync_values), 2)
|
||||
self.assertFalse(sync_values[0]) # First step: not syncing
|
||||
self.assertTrue(sync_values[1]) # Second step: syncing
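Outside the test harness the same behaviour reduces to the usual `accumulate` pattern; a minimal sketch with placeholder model and data:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
dataloader = torch.utils.data.DataLoader(torch.randn(64, 10), batch_size=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(model):
        loss = model(batch).mean()
        accelerator.backward(loss)        # boundary handling happens here
        if accelerator.sync_gradients:    # True only on every 4th micro-step
            optimizer.step()
            optimizer.zero_grad()
```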
@ -29,17 +29,15 @@ from accelerate.test_utils.testing import (
|
||||
get_launch_command,
|
||||
path_in_accelerate_package,
|
||||
require_fp16,
|
||||
require_fsdp2,
|
||||
require_multi_device,
|
||||
require_non_cpu,
|
||||
require_non_torch_xla,
|
||||
require_torch_min_version,
|
||||
run_first,
|
||||
slow,
|
||||
)
|
||||
from accelerate.utils import is_bf16_available, is_fp16_available, is_hpu_available, patch_environment, set_seed
|
||||
from accelerate.utils.constants import (
|
||||
CONTEXT_PARALLEL_PYTORCH_VERSION,
|
||||
FSDP2_PYTORCH_VERSION,
|
||||
FSDP2_STATE_DICT_TYPE,
|
||||
FSDP_AUTO_WRAP_POLICY,
|
||||
FSDP_BACKWARD_PREFETCH,
|
||||
@ -48,7 +46,6 @@ from accelerate.utils.constants import (
|
||||
)
|
||||
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
|
||||
from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@ -65,10 +62,6 @@ if is_fp16_available():
|
||||
if is_bf16_available():
|
||||
dtypes.append(BF16)
|
||||
|
||||
FSDP_VERSIONS = [1]
|
||||
if is_torch_version(">=", FSDP2_PYTORCH_VERSION):
|
||||
FSDP_VERSIONS.append(2)
|
||||
|
||||
|
||||
@require_non_cpu
|
||||
@require_non_torch_xla
|
||||
@ -91,6 +84,7 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
1: self.fsdp1_env,
|
||||
2: self.fsdp2_env,
|
||||
}
|
||||
|
||||
self.current_fsdp_version = 1
|
||||
|
||||
def test_sharding_strategy(self):
|
||||
@ -322,6 +316,9 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
AcceleratorState._reset_state(True)
|
||||
|
||||
env = self.fsdp_envs[fsdp_version].copy()
|
||||
with patch_environment(**env):
|
||||
plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)
|
||||
assert plugin.mixed_precision_policy == mp_policy
|
||||
with patch_environment(**env):
|
||||
plugin = FullyShardedDataParallelPlugin(
|
||||
mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, **{extra_arg: dtype}}
|
||||
@ -404,25 +401,26 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
assert fsdp_plugin.cpu_ram_efficient_loading is False
|
||||
assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"
|
||||
|
||||
def test_cp(self):
|
||||
if (fsdp_version := self.current_fsdp_version) != 2:
|
||||
return
|
||||
|
||||
env = self.fsdp_envs[fsdp_version].copy()
|
||||
for cp_comm_strategy in ["allgather", "alltoall"]:
|
||||
env["FSDP_CP_COMM_STRATEGY"] = cp_comm_strategy
|
||||
env["FSDP_CP_SIZE"] = "2"
|
||||
with patch_environment(**env):
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin()
|
||||
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
|
||||
|
||||
env = self.fsdp_envs[fsdp_version].copy()
|
||||
env["FSDP_CP_SIZE"] = "2"
|
||||
with patch_environment(**env):
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(cp_comm_strategy=cp_comm_strategy)
|
||||
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
|
||||
def test_ignored_modules_regex(self):
|
||||
# Check that FSDP's ignored_modules can be a string, in which case it is treated as a regex
|
||||
env = self.fsdp_envs[1].copy()
|
||||
env["FSDP_IGNORED_MODULES"] = ".*\\.q_proj$"
|
||||
with patch_environment(**env):
|
||||
accelerator = Accelerator()
|
||||
model = AutoModel.from_pretrained(LLAMA_TESTING)
|
||||
model = accelerator.prepare(model)
|
||||
if self.current_fsdp_version == 1:
|
||||
# model has 2 layers
|
||||
layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}
|
||||
assert model._ignored_modules == layers_to_ignore
|
||||
else:
|
||||
params_to_ignore = {model.layers[0].self_attn.q_proj.weight, model.layers[1].self_attn.q_proj.weight}
|
||||
assert model._ignored_params == params_to_ignore
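For reference, the same behaviour can be requested in code rather than through the `FSDP_IGNORED_MODULES` environment variable; this assumes the plugin exposes `ignored_modules` as a constructor argument and that the script runs under `accelerate launch`:

```python
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Regex string: every q_proj submodule (FSDP1) or its parameters (FSDP2)
# is left out of sharding, matching the test above.
fsdp_plugin = FullyShardedDataParallelPlugin(ignored_modules=r".*\.q_proj$")
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```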
@require_fsdp2
|
||||
@require_non_cpu
|
||||
@require_non_torch_xla
|
||||
class FSDP2PluginIntegration(FSDPPluginIntegration):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -469,6 +467,7 @@ class FSDPIntegrationTest(TempDirTestCase):
|
||||
}
|
||||
self.n_train = 160
|
||||
self.n_val = 160
|
||||
|
||||
self.current_fsdp_version = 1
|
||||
|
||||
@require_fp16
|
||||
@ -624,24 +623,13 @@ class FSDPIntegrationTest(TempDirTestCase):
|
||||
with patch_environment(omp_num_threads=1):
|
||||
execute_subprocess_async(cmd_config)
|
||||
|
||||
# TODO: Should probably be moved to a separate test file
|
||||
@require_torch_min_version(version=CONTEXT_PARALLEL_PYTORCH_VERSION)
|
||||
def test_dist_dataloader(self):
|
||||
if (fsdp_version := self.current_fsdp_version) != 2:
|
||||
return
|
||||
|
||||
self.test_file_path = self.test_scripts_folder / "test_distributed_dataloader.py"
|
||||
cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, fsdp_version=fsdp_version)
|
||||
|
||||
cmd_config = cmd.copy()
|
||||
cmd_config.extend(["--use_fsdp", "--fsdp_cp_size=2"])
|
||||
|
||||
cmd_config.append(self.test_file_path)
|
||||
|
||||
with patch_environment(omp_num_threads=1):
|
||||
execute_subprocess_async(cmd_config)
|
||||
|
||||
|
||||
@require_fsdp2
|
||||
@run_first
|
||||
# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
|
||||
@require_non_torch_xla
|
||||
@require_multi_device
|
||||
@slow
|
||||
class FSDP2IntegrationTest(FSDPIntegrationTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@ -17,6 +17,7 @@ import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import time
|
||||
from unittest import skip
|
||||
from unittest.mock import patch
|
||||
|
||||
import psutil
|
||||
@ -478,6 +479,7 @@ class AcceleratorTester(AccelerateTestCase):
|
||||
@require_cuda_or_xpu
|
||||
@slow
|
||||
@require_bnb
|
||||
@skip("Passing locally but not on CI. Also no one will try to train an offloaded bnb model")
|
||||
def test_accelerator_bnb_cpu_error(self):
|
||||
"""Tests that the accelerator can be used with the BNB library. This should fail as we are trying to load a model
|
||||
that is loaded between cpu and gpu"""
|
||||
|
||||
@ -625,7 +625,7 @@ class ToFSDP2Tester(unittest.TestCase):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
to_fsdp2_command(args)
|
||||
|
||||
assert "Config already specfies FSDP2, skipping conversion..." in cm.output[0]
|
||||
assert "Config already specifies FSDP2, skipping conversion..." in cm.output[0]
|
||||
|
||||
# Has to be the last test because it overwrites the config file
|
||||
def test_fsdp2_overwrite(self):
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from unittest import skip
|
||||
|
||||
import torch
|
||||
from torch.utils.benchmark import Timer
|
||||
@ -34,8 +35,8 @@ else:
|
||||
backend = "inductor"
|
||||
|
||||
|
||||
@require_non_hpu
|
||||
@require_huggingface_suite
|
||||
@skip("Don't work with torch 2.8")
|
||||
class RegionalCompilationTester(unittest.TestCase):
|
||||
def _get_model_and_inputs(self):
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
@ -109,6 +110,7 @@ class RegionalCompilationTester(unittest.TestCase):
|
||||
release_memory(model, full_compilation_model, regional_compilation_model)
|
||||
|
||||
@slow
|
||||
@require_non_hpu
|
||||
@require_non_cpu
|
||||
@require_huggingface_suite
|
||||
def test_regional_compilation_inference_speedup(self):
|
||||
|
||||
@ -15,6 +15,7 @@ fsdp_config:
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_transformer_layer_cls_to_wrap: BertLayer
|
||||
fsdp_use_orig_params: true
|
||||
fsdp_ignored_modules: null
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
tests/test_dataclasses.py (new file, 250 lines)
@ -0,0 +1,250 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from accelerate.utils import patch_environment
|
||||
from accelerate.utils.constants import (
|
||||
BETA_CP_AVAILABLE_PYTORCH_VERSION,
|
||||
BETA_TP_AVAILABLE_PYTORCH_VERSION,
|
||||
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
|
||||
)
|
||||
from accelerate.utils.imports import is_transformers_available
|
||||
from accelerate.utils.versions import compare_versions, is_torch_version
|
||||
|
||||
|
||||
def _should_skip_cp_test(cp_size):
|
||||
"""Check if CP test should be skipped based on cp_size and torch version."""
|
||||
return cp_size > 1 and not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION)
|
||||
|
||||
|
||||
def _should_skip_tp_test(tp_size):
|
||||
"""Check if TP test should be skipped based on tp_size, torch version, and transformers availability."""
|
||||
if tp_size <= 1:
|
||||
return False
|
||||
|
||||
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
|
||||
return True
|
||||
|
||||
if not is_transformers_available():
|
||||
return True
|
||||
|
||||
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TestParallelismConfig:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_init_device_mesh(self):
|
||||
def mock_init_mesh(device_type, mesh_shape, mesh_dim_names):
|
||||
mesh = Mock()
|
||||
mesh.size.return_value = 1
|
||||
for dim in mesh_shape:
|
||||
mesh.size.return_value *= dim
|
||||
mesh.shape = mesh_shape
|
||||
mesh.mesh_dim_names = mesh_dim_names
|
||||
|
||||
# mock device_mesh._flatten
|
||||
mesh.flattened_dims = []
|
||||
|
||||
def mock_getitem(key):
|
||||
submesh = Mock()
|
||||
|
||||
def mock_flatten(name):
|
||||
mesh.flattened_dims.append((key, name))
|
||||
|
||||
submesh._flatten = Mock(side_effect=mock_flatten)
|
||||
return submesh
|
||||
|
||||
mesh.__getitem__ = Mock(side_effect=mock_getitem)
|
||||
|
||||
return mesh
|
||||
|
||||
with patch("torch.distributed.device_mesh.init_device_mesh", side_effect=mock_init_mesh):
|
||||
yield mock_init_mesh
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
|
||||
[
|
||||
(8, 1, 1, 1, (8,), ("dp_replicate",)), # DDP
|
||||
(1, 8, 1, 1, (8,), ("dp_shard",)), # FSDP
|
||||
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")), # HSDP
|
||||
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")), # FSDP + TP
|
||||
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")), # HSDP + TP
|
||||
(1, 1, 8, 1, (8,), ("tp",)), # TP only
|
||||
(1, 1, 1, 4, (4,), ("cp",)), # CP only
|
||||
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")), # FSDP + CP
|
||||
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")), # FSDP + CP + TP
|
||||
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")), # HSDP + CP + TP
|
||||
],
|
||||
)
|
||||
def test_get_mesh(
|
||||
self,
|
||||
dp_replicate_size,
|
||||
dp_shard_size,
|
||||
tp_size,
|
||||
cp_size,
|
||||
expected_shape,
|
||||
expected_dim_names,
|
||||
):
|
||||
# Skip tests based on version requirements
|
||||
if _should_skip_cp_test(cp_size):
|
||||
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
|
||||
if _should_skip_tp_test(tp_size):
|
||||
pytest.skip(
|
||||
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
|
||||
)
|
||||
|
||||
config = ParallelismConfig(
|
||||
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
|
||||
)
|
||||
mesh_dim_names, mesh_shape = config._get_mesh()
|
||||
assert mesh_shape == expected_shape
|
||||
assert mesh_dim_names == expected_dim_names
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
|
||||
[
|
||||
(8, 1, 1, 1, (8,), ("dp_replicate",)),
|
||||
(1, 8, 1, 1, (8,), ("dp_shard",)),
|
||||
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")),
|
||||
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")),
|
||||
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")),
|
||||
(1, 1, 8, 1, (8,), ("tp",)),
|
||||
(1, 1, 1, 4, (4,), ("cp",)),
|
||||
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")),
|
||||
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")),
|
||||
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")),
|
||||
],
|
||||
)
|
||||
def test_build_device_mesh(
|
||||
self,
|
||||
dp_replicate_size,
|
||||
dp_shard_size,
|
||||
tp_size,
|
||||
cp_size,
|
||||
expected_shape,
|
||||
expected_dim_names,
|
||||
):
|
||||
"""Test build_device_mesh creates correct mesh and applies flattening."""
|
||||
# Skip tests based on version requirements
|
||||
if _should_skip_cp_test(cp_size):
|
||||
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
|
||||
if _should_skip_tp_test(tp_size):
|
||||
pytest.skip(
|
||||
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
|
||||
)
|
||||
|
||||
config = ParallelismConfig(
|
||||
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
|
||||
)
|
||||
device_mesh = config.build_device_mesh("cpu")
|
||||
|
||||
# Check mesh shape and dimension names match expected
|
||||
assert device_mesh.shape == expected_shape
|
||||
assert device_mesh.mesh_dim_names == expected_dim_names
|
||||
|
||||
# Check that correct flattening operations were called
|
||||
expected_flattened = []
|
||||
if config.dp_dim_names:
|
||||
expected_flattened.append((config.dp_dim_names, "dp"))
|
||||
if config.dp_shard_cp_dim_names:
|
||||
expected_flattened.append((config.dp_shard_cp_dim_names, "dp_shard_cp"))
|
||||
if config.dp_cp_dim_names:
|
||||
expected_flattened.append((config.dp_cp_dim_names, "dp_cp"))
|
||||
|
||||
assert device_mesh.flattened_dims == expected_flattened
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dp_replicate_size, dp_shard_size, tp_size, cp_size",
|
||||
[
|
||||
(8, 1, 1, 1),
|
||||
(1, 8, 1, 1),
|
||||
(2, 4, 1, 1),
|
||||
(1, 4, 2, 1),
|
||||
(2, 2, 2, 1),
|
||||
(1, 1, 8, 1),
|
||||
(1, 1, 1, 4),
|
||||
(1, 4, 1, 2),
|
||||
(1, 2, 2, 2),
|
||||
(2, 2, 2, 2),
|
||||
],
|
||||
)
|
||||
def test_from_env(
|
||||
self,
|
||||
dp_replicate_size,
|
||||
dp_shard_size,
|
||||
tp_size,
|
||||
cp_size,
|
||||
):
|
||||
if _should_skip_cp_test(cp_size):
|
||||
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
|
||||
if _should_skip_tp_test(tp_size):
|
||||
pytest.skip(
|
||||
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
|
||||
)
|
||||
|
||||
new_env = {
|
||||
"PARALLELISM_CONFIG_DP_REPLICATE_SIZE": dp_replicate_size,
|
||||
"PARALLELISM_CONFIG_DP_SHARD_SIZE": dp_shard_size,
|
||||
"PARALLELISM_CONFIG_TP_SIZE": tp_size,
|
||||
"PARALLELISM_CONFIG_CP_SIZE": cp_size,
|
||||
}
|
||||
|
||||
with patch_environment(**new_env):
|
||||
config = ParallelismConfig()
|
||||
for key, value in new_env.items():
|
||||
assert getattr(config, key.split("PARALLELISM_CONFIG_")[-1].lower()) == value
|
||||
|
||||
def test_cp_handler(self):
|
||||
"""Test CP handler with various configurations."""
|
||||
|
||||
# Any cp_size > 1 requires torch >= BETA_CP_AVAILABLE_PYTORCH_VERSION, we use placeholder for this check as this test doesn't depend on a specific size
|
||||
if _should_skip_cp_test(2):
|
||||
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
|
||||
|
||||
from accelerate.utils import TorchContextParallelConfig
|
||||
|
||||
for setting in ("allgather", "alltoall"):
|
||||
cp_handler = TorchContextParallelConfig(cp_comm_strategy=setting)
|
||||
pc = ParallelismConfig(cp_size=2, cp_handler=cp_handler)
|
||||
|
||||
assert pc.cp_handler is not None, "CP handler should be set"
|
||||
assert pc.cp_handler.cp_comm_strategy == setting, (
|
||||
f"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}"
|
||||
)
|
||||
|
||||
for setting in ("allgather", "alltoall"):
|
||||
with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):
|
||||
pc = ParallelismConfig(cp_size=2)
|
||||
assert pc.cp_handler is not None, "CP handler should be set from environment"
|
||||
assert pc.cp_handler.cp_comm_strategy == setting, (
|
||||
f"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}"
|
||||
)
|
||||
|
||||
for setting in ("invalid", "unsupported"):
|
||||
with pytest.raises(ValueError, match=f"Invalid cp_comm_strategy: {setting}"):
|
||||
TorchContextParallelConfig(cp_comm_strategy=setting)
|
||||
|
||||
with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):
|
||||
with pytest.raises(ValueError, match=f"Invalid cp_comm_strategy: {setting}"):
|
||||
pc = ParallelismConfig(cp_size=2)
|
||||
|
||||
def test_tp_handler(self):
|
||||
assert True, "Tensor parallelism handler doesn't hold any logic yet"
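All of the parametrizations above funnel through the same entry point; a hedged sketch of using it directly, assuming a torch/transformers stack new enough for TP (the real mesh is only built once a matching process group exists):

```python
from accelerate.parallelism_config import ParallelismConfig

# The HSDP + TP layout from the table above: 2-way replicate, 2-way shard, 2-way TP.
config = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)

dim_names, shape = config._get_mesh()
print(dim_names, shape)  # ('dp_replicate', 'dp_shard', 'tp') (2, 2, 2)

# Building the real mesh needs torch.distributed initialized with world_size == 8:
# device_mesh = config.build_device_mesh("cuda")
```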
@ -19,7 +19,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
from unittest import mock, skip
|
||||
|
||||
import torch
|
||||
|
||||
@ -239,7 +239,10 @@ class FeatureExamplesTests(TempDirTestCase):
|
||||
run_command(self.launch_args + testargs)
|
||||
|
||||
@require_trackers
|
||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"WANDB_MODE": "offline", "DVCLIVE_TEST": "true", "SWANLAB_MODE": "offline"},
|
||||
)
|
||||
def test_tracking(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
testargs = f"""
|
||||
@ -294,12 +297,14 @@ class FeatureExamplesTests(TempDirTestCase):
|
||||
|
||||
@require_pippy
|
||||
@require_multi_device
|
||||
@skip("Will soon deprecate pippy")
|
||||
def test_pippy_examples_bert(self):
|
||||
testargs = ["examples/inference/pippy/bert.py"]
|
||||
run_command(self.launch_args + testargs)
|
||||
|
||||
@require_pippy
|
||||
@require_multi_device
|
||||
@skip("Will soon deprecate pippy")
|
||||
def test_pippy_examples_gpt2(self):
|
||||
testargs = ["examples/inference/pippy/gpt2.py"]
|
||||
run_command(self.launch_args + testargs)
|
||||
|
||||
@ -12,9 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import textwrap
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
@ -27,24 +31,29 @@ from accelerate.test_utils import (
|
||||
require_multi_device,
|
||||
require_torchao,
|
||||
require_transformer_engine,
|
||||
require_transformer_engine_mxfp8,
|
||||
run_first,
|
||||
)
|
||||
from accelerate.test_utils.testing import require_deepspeed, run_command
|
||||
from accelerate.utils import (
|
||||
AORecipeKwargs,
|
||||
FP8RecipeKwargs,
|
||||
TERecipeKwargs,
|
||||
has_ao_layers,
|
||||
has_transformer_engine_layers,
|
||||
is_torchao_available,
|
||||
is_transformer_engine_available,
|
||||
)
|
||||
|
||||
|
||||
def can_convert_te_model():
|
||||
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
|
||||
def can_convert_te_model(from_config=False):
|
||||
if not from_config:
|
||||
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [TERecipeKwargs()]}
|
||||
else:
|
||||
accelerator_kwargs = {}
|
||||
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
assert accelerator.fp8_enabled, "FP8 is not enabled"
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
|
||||
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
|
||||
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.LayerNorm(32, bias=False), torch.nn.Linear(32, 16))
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
|
||||
|
||||
@ -58,10 +67,14 @@ def maintain_proper_deepspeed_config(expected_version):
|
||||
)
|
||||
|
||||
|
||||
def can_convert_ao_model():
|
||||
def can_convert_ao_model(from_config=False):
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
|
||||
if not from_config:
|
||||
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
|
||||
else:
|
||||
accelerator_kwargs = {}
|
||||
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
|
||||
@ -78,13 +91,51 @@ def can_convert_ao_model():
|
||||
class TestTransformerEngine(unittest.TestCase):
|
||||
def test_can_prepare_model_single_gpu(self):
|
||||
command = get_launch_command(num_processes=1, monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8"]
|
||||
command += ["-m", "tests.test_fp8", "--test_te"]
|
||||
run_command(command)
|
||||
|
||||
def test_can_prepare_model_single_gpu_from_config(self):
|
||||
with tempfile.TemporaryDirectory() as dir_name:
|
||||
config_file = Path(dir_name) / "config.yaml"
|
||||
config_file.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
distributed_type: "NO"
|
||||
num_processes: 1
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
backend: TE
|
||||
"""
|
||||
)
|
||||
)
|
||||
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
|
||||
run_command(command)
|
||||
|
||||
@require_transformer_engine_mxfp8
|
||||
def test_can_prepare_model_with_mxfp8_block_scaling(self):
|
||||
with tempfile.TemporaryDirectory() as dir_name:
|
||||
config_file = Path(dir_name) / "config.yaml"
|
||||
config_file.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
distributed_type: "NO"
|
||||
num_processes: 1
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
backend: TE
|
||||
use_mxfp8_block_scaling: true
|
||||
"""
|
||||
)
|
||||
)
|
||||
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
|
||||
run_command(command)
|
||||
|
||||
@require_multi_device
|
||||
def test_can_prepare_model_multi_gpu(self):
|
||||
command = get_launch_command(num_processes=2, monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8"]
|
||||
command += ["-m", "tests.test_fp8", "--test_te"]
|
||||
run_command(command)
|
||||
|
||||
@require_deepspeed
|
||||
@ -116,7 +167,36 @@ class TestTransformerEngine(unittest.TestCase):
|
||||
command = get_launch_command(
|
||||
num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
|
||||
)
|
||||
command += ["-m", "tests.test_fp8"]
|
||||
command += ["-m", "tests.test_fp8", "--test_te"]
|
||||
run_command(command)
|
||||
|
||||
@require_deepspeed
|
||||
@require_multi_device
|
||||
def test_can_prepare_model_multigpu_deepspeed_from_config(self):
|
||||
os.environ["ZERO_STAGE"] = str(1)
|
||||
with tempfile.TemporaryDirectory() as dir_name:
|
||||
config_file = Path(dir_name) / "config.yaml"
|
||||
config_file.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
distributed_type: "DEEPSPEED"
|
||||
deepspeed_config:
|
||||
gradient_clipping: 1.0
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 1
|
||||
deepspeed_multinode_launcher: standard
|
||||
num_processes: 2
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
backend: TE
|
||||
"""
|
||||
)
|
||||
)
|
||||
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
|
||||
run_command(command)
|
||||
|
||||
|
||||
@ -125,13 +205,31 @@ class TestTransformerEngine(unittest.TestCase):
|
||||
class TestTorchAO(unittest.TestCase):
|
||||
def test_can_prepare_model_single_accelerator(self):
|
||||
command = get_launch_command(num_processes=1, monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8"]
|
||||
command += ["-m", "tests.test_fp8", "--test_ao"]
|
||||
run_command(command)
|
||||
|
||||
def test_can_prepare_model_single_gpu_from_config(self):
|
||||
with tempfile.TemporaryDirectory() as dir_name:
|
||||
config_file = Path(dir_name) / "config.yaml"
|
||||
config_file.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
distributed_type: "NO"
|
||||
num_processes: 1
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
backend: AO
|
||||
"""
|
||||
)
|
||||
)
|
||||
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8", "--test_ao", "--from_config"]
|
||||
run_command(command)
|
||||
|
||||
@require_multi_device
|
||||
def test_can_prepare_model_multi_accelerator(self):
|
||||
command = get_launch_command(num_processes=2, monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8"]
|
||||
command += ["-m", "tests.test_fp8", "--test_ao"]
|
||||
run_command(command)
|
||||
|
||||
@require_deepspeed
|
||||
@ -163,16 +261,26 @@ class TestTorchAO(unittest.TestCase):
|
||||
command = get_launch_command(
|
||||
num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
|
||||
)
|
||||
command += ["-m", "tests.test_fp8"]
|
||||
command += ["-m", "tests.test_fp8", "--test_ao"]
|
||||
run_command(command)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TE suite
|
||||
if is_transformer_engine_available():
|
||||
can_convert_te_model()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--test_te", action="store_true", default=False)
|
||||
parser.add_argument("--test_ao", action="store_true", default=False)
|
||||
parser.add_argument("--from_config", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.test_te and not args.test_ao:
|
||||
raise ValueError("Must specify at least one of --test_te or --test_ao")
|
||||
|
||||
if args.test_te:
|
||||
can_convert_te_model(args.from_config)
|
||||
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
|
||||
maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
|
||||
|
||||
# AO suite
|
||||
if is_torchao_available():
|
||||
can_convert_ao_model()
|
||||
if args.test_ao:
|
||||
can_convert_ao_model(args.from_config)
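For a single-process run, the TE path exercised by `--test_te` boils down to the following; it assumes `transformer_engine` and FP8-capable hardware are present:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs

# Mirrors the non-config branch of `can_convert_te_model` above.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[TERecipeKwargs()])
assert accelerator.fp8_enabled, "FP8 is not enabled"

model = torch.nn.Sequential(
    torch.nn.Linear(32, 32), torch.nn.LayerNorm(32, bias=False), torch.nn.Linear(32, 16)
)
model = accelerator.prepare(model)  # linear layers are swapped for TE equivalents
```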
@ -24,8 +24,10 @@ from torch.fx import symbolic_trace
|
||||
from accelerate.big_modeling import attach_layerwise_casting_hooks
|
||||
from accelerate.hooks import (
|
||||
AlignDevicesHook,
|
||||
CpuOffload,
|
||||
ModelHook,
|
||||
SequentialHook,
|
||||
UserCpuOffloadHook,
|
||||
add_hook_to_module,
|
||||
attach_align_device_hook,
|
||||
remove_hook_from_module,
|
||||
@ -457,3 +459,58 @@ class HooksModelTester(unittest.TestCase):
|
||||
|
||||
with torch.no_grad():
|
||||
_ = test_model(inputs)
|
||||
|
||||
def test_cpu_offload_hook_moves_model(self):
|
||||
if not torch.cuda.is_available():
|
||||
self.skipTest("CUDA not available for offload test.")
|
||||
|
||||
model = ModelForTest()
|
||||
gpu_device = torch.device("cuda:0")
|
||||
hook = CpuOffload(execution_device=gpu_device)
|
||||
add_hook_to_module(model, hook)
|
||||
|
||||
x = torch.randn(2, 3).to(gpu_device)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, gpu_device)
|
||||
|
||||
remove_hook_from_module(model)
|
||||
output2 = model(x)
|
||||
self.assertEqual(output2.device, gpu_device)
|
||||
|
||||
# should be on the gpu
|
||||
assert model.linear1.weight.device == gpu_device
|
||||
assert model.batchnorm.weight.device == gpu_device
|
||||
assert model.linear2.weight.device == gpu_device
|
||||
|
||||
def test_cpu_offload_hook_with_prev_module(self):
|
||||
if not torch.cuda.is_available():
|
||||
self.skipTest("CUDA not available for offload test.")
|
||||
|
||||
model1 = ModelForTest()
|
||||
model2 = ModelForTest()
|
||||
gpu_device = torch.device("cuda:0")
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
hook1 = CpuOffload(execution_device=gpu_device)
|
||||
add_hook_to_module(model1, hook1)
|
||||
user_hook1 = UserCpuOffloadHook(model1, hook1)
|
||||
|
||||
hook2 = CpuOffload(execution_device=gpu_device, prev_module_hook=user_hook1)
|
||||
add_hook_to_module(model2, hook2)
|
||||
|
||||
x = torch.randn(2, 3).to(gpu_device)
|
||||
output1 = model1(x)
|
||||
self.assertEqual(output1.device, gpu_device)
|
||||
|
||||
output2 = model2(x)
|
||||
self.assertEqual(output2.device, gpu_device)
|
||||
|
||||
# should be on the cpu
|
||||
assert model1.linear1.weight.device == cpu_device
|
||||
assert model1.batchnorm.weight.device == cpu_device
|
||||
assert model1.linear2.weight.device == cpu_device
|
||||
|
||||
# should be on the gpu still
|
||||
assert model2.linear1.weight.device == gpu_device
|
||||
assert model2.batchnorm.weight.device == gpu_device
|
||||
assert model2.linear2.weight.device == gpu_device
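Chaining `prev_module_hook` is what lets a pipeline of offloaded blocks trade places on the accelerator; a condensed sketch of the same wiring, assuming a CUDA device:

```python
import torch
from accelerate.hooks import CpuOffload, UserCpuOffloadHook, add_hook_to_module

block1, block2 = torch.nn.Linear(3, 3), torch.nn.Linear(3, 3)
gpu = torch.device("cuda:0")

hook1 = CpuOffload(execution_device=gpu)
add_hook_to_module(block1, hook1)

# Before block2 runs, its hook asks block1 (through the user-facing wrapper)
# to offload itself back to the CPU, then moves block2 onto the GPU.
hook2 = CpuOffload(execution_device=gpu, prev_module_hook=UserCpuOffloadHook(block1, hook1))
add_hook_to_module(block2, hook2)

x = torch.randn(2, 3, device=gpu)
y = block2(block1(x))
print(block1.weight.device, block2.weight.device)  # cpu cuda:0
```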
@ -64,6 +64,7 @@ class TestPrepareMultiGpuEnv(unittest.TestCase):
|
||||
num_cpu_threads_per_process=1,
|
||||
enable_cpu_affinity=False,
|
||||
same_network=False,
|
||||
use_parallelism_config=False,
|
||||
)
|
||||
|
||||
prepare_multi_gpu_env(args)
|
||||
|
||||
@ -61,7 +61,31 @@ class MemoryTest(unittest.TestCase):
|
||||
raise_fake_out_of_memory()
|
||||
|
||||
mock_training_loop_function()
|
||||
assert batch_sizes == [128, 64, 32, 16, 8]
|
||||
assert batch_sizes == [
|
||||
128,
|
||||
115,
|
||||
103,
|
||||
92,
|
||||
82,
|
||||
73,
|
||||
65,
|
||||
58,
|
||||
52,
|
||||
46,
|
||||
41,
|
||||
36,
|
||||
32,
|
||||
28,
|
||||
25,
|
||||
22,
|
||||
19,
|
||||
17,
|
||||
15,
|
||||
13,
|
||||
11,
|
||||
9,
|
||||
8,
|
||||
]
|
||||
|
||||
def test_memory_explicit(self):
|
||||
batch_sizes = []
|
||||
@ -75,7 +99,31 @@ class MemoryTest(unittest.TestCase):
|
||||
return batch_size, arg1
|
||||
|
||||
bs, arg1 = mock_training_loop_function("hello")
|
||||
assert batch_sizes == [128, 64, 32, 16, 8]
|
||||
assert batch_sizes == [
|
||||
128,
|
||||
115,
|
||||
103,
|
||||
92,
|
||||
82,
|
||||
73,
|
||||
65,
|
||||
58,
|
||||
52,
|
||||
46,
|
||||
41,
|
||||
36,
|
||||
32,
|
||||
28,
|
||||
25,
|
||||
22,
|
||||
19,
|
||||
17,
|
||||
15,
|
||||
13,
|
||||
11,
|
||||
9,
|
||||
8,
|
||||
]
|
||||
assert [bs, arg1] == [8, "hello"]
|
||||
|
||||
def test_start_zero(self):
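The new expected sequences are consistent with `find_executable_batch_size` shrinking the batch size by roughly 10% per retry (with truncation) instead of halving it; a quick check of that assumption reproduces the list above:

```python
def decayed_batch_sizes(start: int = 128, floor: int = 8) -> list:
    # Assumed retry schedule: multiply by 0.9 and truncate, stopping at the floor.
    sizes, bs = [start], start
    while bs > floor:
        bs = int(bs * 0.9)
        sizes.append(bs)
    return sizes


print(decayed_batch_sizes())  # [128, 115, 103, 92, ..., 11, 9, 8]
```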
@ -349,6 +349,26 @@ class ModelingUtilsTester(unittest.TestCase):
|
||||
|
||||
check_device_map(model, {"linear1": 0, "linear2": 1, "batchnorm": 1})
|
||||
|
||||
def test_check_device_map_invalid_keys(self):
|
||||
model = ModelForTest()
|
||||
|
||||
device_map = {
|
||||
"linear1": "cpu", # Valid module
|
||||
"batchnorm": "cpu", # Valid module
|
||||
"linear2": "cpu", # Valid module
|
||||
"invalid_module": 0, # Invalid - should trigger warning
|
||||
"another_invalid": 1, # Invalid - should trigger warning
|
||||
}
|
||||
|
||||
# Test for the warning about invalid keys
|
||||
with self.assertWarns(UserWarning) as cm:
|
||||
check_device_map(model, device_map)
|
||||
|
||||
warning_msg = str(cm.warning)
|
||||
self.assertIn("device_map keys do not match any submodules", warning_msg)
|
||||
self.assertIn("invalid_module", warning_msg)
|
||||
self.assertIn("another_invalid", warning_msg)
|
||||
|
||||
def shard_test_model(self, model, tmp_dir):
|
||||
module_index = {
|
||||
"linear1": "checkpoint_part1.bin",
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
from unittest import skip
|
||||
|
||||
import torch
|
||||
|
||||
@ -28,7 +29,6 @@ from accelerate.test_utils import (
|
||||
path_in_accelerate_package,
|
||||
require_huggingface_suite,
|
||||
require_multi_device,
|
||||
require_non_hpu,
|
||||
require_non_torch_xla,
|
||||
require_pippy,
|
||||
require_torchvision,
|
||||
@ -70,7 +70,6 @@ class MultiDeviceTester(unittest.TestCase):
|
||||
execute_subprocess_async(cmd)
|
||||
|
||||
@run_first
|
||||
@require_non_hpu # Synapse detected a device critical error that requires a restart
|
||||
@require_multi_device
|
||||
def test_multi_device_merge_fsdp_weights(self):
|
||||
print(f"Found {device_count} {torch_device} devices.")
|
||||
@ -111,6 +110,7 @@ class MultiDeviceTester(unittest.TestCase):
|
||||
@require_torchvision
|
||||
@require_multi_device
|
||||
@require_huggingface_suite
|
||||
@skip("Will soon deprecate pippy")
|
||||
def test_pippy(self):
|
||||
"""
|
||||
Checks the integration with the pippy framework
|
||||
|
||||
@ -16,6 +16,7 @@ import csv
import json
import logging
import os
import random
import re
import subprocess
import tempfile
@ -42,7 +43,9 @@ from accelerate.test_utils.testing import (
    require_matplotlib,
    require_mlflow,
    require_pandas,
    require_swanlab,
    require_tensorboard,
    require_trackio,
    require_wandb,
    skip,
)
@ -53,7 +56,9 @@ from accelerate.tracking import (
    DVCLiveTracker,
    GeneralTracker,
    MLflowTracker,
    SwanLabTracker,
    TensorBoardTracker,
    TrackioTracker,
    WandBTracker,
)
from accelerate.utils import (
@ -520,6 +525,123 @@ class ClearMLTest(TempDirTestCase, MockingTestCase):
        self.assertCountEqual(plot["data"][0]["cells"]["values"], [[1, 2], [3, 4], [5, 6]])


@require_swanlab
@mock.patch.dict(os.environ, {"SWANLAB_MODE": "offline"})
class SwanLabTrackingTest(TempDirTestCase, MockingTestCase):
    def setUp(self):
        super().setUp()
        # Setting Path where SwanLab parsed log files are saved via the SWANLAB_LOG_DIR env var
        self.add_mocks(mock.patch.dict(os.environ, {"SWANLAB_LOG_DIR": self.tmpdir}))

    @skip
    def test_swanlab(self):
        # Disable hardware monitoring to prevent errors in test mode.
        import swanlab
        from swanlab.log.backup import BackupHandler
        from swanlab.log.backup.datastore import DataStore
        from swanlab.log.backup.models import ModelsParser

        swanlab.merge_settings(swanlab.Settings(hardware_monitor=False))
        # Start a fake training session.
        accelerator = Accelerator(log_with="swanlab")
        project_name = "test_project_with_config"
        experiment_name = "test"
        description = "test project for swanlab"
        tags = ["my_tag"]
        config = {
            "epochs": 10,
            "learning_rate": 0.01,
            "offset": 0.1,
        }
        kwargs = {
            "swanlab": {
                "experiment_name": experiment_name,
                "description": description,
                "tags": tags,
            }
        }
        accelerator.init_trackers(project_name, config, kwargs)
        record_metrics = []
        record_scalars = []
        record_images_count = 0
        record_logs = []
        for epoch in range(1, swanlab.config.epochs):
            acc = 1 - 2**-epoch - random.random() / epoch - 0.1
            loss = 2**-epoch + random.random() / epoch + 0.1
            ll = swanlab.log(
                {
                    "accuracy": acc,
                    "loss": loss,
                    "image": swanlab.Image(np.random.random((3, 3, 3))),
                },
                step=epoch,
            )
            log = f"epoch={epoch}, accuracy={acc}, loss={loss}"
            print(log)
            record_scalars.extend([acc, loss])
            record_images_count += 1
            record_logs.append(log)
            record_metrics.extend([x for _, x in ll.items()])
        accelerator.end_training()

        # Load latest offline log
        run_dir = swanlab.get_run().public.run_dir
        assert os.path.exists(run_dir) is True
        ds = DataStore()
        ds.open_for_scan(os.path.join(run_dir.__str__(), BackupHandler.BACKUP_FILE).__str__())
        with ModelsParser() as models_parser:
            for record in ds:
                if record is None:
                    continue
                models_parser.parse_record(record)
        header, project, experiment, logs, runtime, columns, scalars, medias, footer = models_parser.get_parsed()

        # test file header
        assert header.backup_type == "DEFAULT"

        # test project info
        assert project.name == project_name
        assert project.workspace is None
        assert project.public is None

        # test experiment info
        assert experiment.name is not None
        assert experiment.description == description
        assert experiment.tags == tags

        # test log record
        backup_logs = [log.message for log in logs]
        for record_log in record_logs:
            assert record_log in backup_logs, "Log not found in backup logs: " + record_log

        # test runtime info
        runtime_info = runtime.to_file_model(os.path.join(run_dir.__str__(), "files"))
        assert runtime_info.conda is None, "Not using conda, should be None"
        assert isinstance(runtime_info.requirements, str), "Requirements should be a string"
        assert isinstance(runtime_info.metadata, dict), "Metadata should be a dictionary"
        assert isinstance(runtime_info.config, dict), "Config should be a dictionary"
        for key in runtime_info.config:
            assert key in config, f"Config key {key} not found in original config"
            assert runtime_info.config[key]["value"] == config[key], (
                f"Config value for {key} does not match original value"
            )

        # test scalar
        assert len(scalars) + len(medias) == len(record_metrics), "Total metrics count does not match"
        backup_scalars = [
            metric.metric["data"]
            for metric in record_metrics
            if metric.column_info.chart_type.value.column_type == "FLOAT"
        ]
        assert len(backup_scalars) == len(scalars), "Total scalars count does not match"
        for scalar in backup_scalars:
            assert scalar in record_scalars, f"Scalar {scalar} not found in original scalars"
        backup_images = [
            metric for metric in record_metrics if metric.column_info.chart_type.value.column_type == "IMAGE"
        ]
        assert len(backup_images) == record_images_count, "Total images count does not match"


class MyCustomTracker(GeneralTracker):
    "Basic tracker that writes to a csv for testing"

@ -681,6 +803,15 @@ class TrackerDeferredInitializationTest(unittest.TestCase):
        _ = Accelerator(log_with=tracker)
        self.assertNotEqual(PartialState._shared_state, {})

    @require_trackio
    def test_trackio_deferred_init(self):
        """Test that trackio tracker initialization doesn't initialize distributed"""
        PartialState._reset_state()
        tracker = TrackioTracker(run_name="test_trackio")
        self.assertEqual(PartialState._shared_state, {})
        _ = Accelerator(log_with=tracker)
        self.assertNotEqual(PartialState._shared_state, {})

    @require_comet_ml
    def test_comet_ml_deferred_init(self):
        """Test that CometML tracker initialization doesn't initialize distributed"""
@ -728,3 +859,12 @@ class TrackerDeferredInitializationTest(unittest.TestCase):
        self.assertEqual(PartialState._shared_state, {})
        _ = Accelerator(log_with=tracker)
        self.assertNotEqual(PartialState._shared_state, {})

    @require_swanlab
    def test_swanlab_deferred_init(self):
        """Test that SwanLab tracker initialization doesn't initialize distributed"""
        PartialState._reset_state()
        tracker = SwanLabTracker(run_name="test_swanlab")
        self.assertEqual(PartialState._shared_state, {})
        _ = Accelerator(log_with=tracker)
        self.assertNotEqual(PartialState._shared_state, {})