Feat: context parallel v2.0 (#3700)

* Cleanup: context parallel

* Feat: cleanup

* Feat: concept guide

* Fix: rename + version check

* Style

* Fix: add to namespace in a test

* Fix: add skip_if on dataclass tests

* Fix: proper version for version check

* Feat: add tests and cleanup

* Fix: properly version check added tests

* Feat: address comments

* Fix: add both shift_labels and labels to make the model.forward calculate loss

* Fix: remove import, improve comment

* Fix: final checks

* Fix: style

* Fix: style
This commit is contained in:
Matej Sirovatka
2025-08-05 16:17:13 +02:00
committed by GitHub
parent 24e48f3d20
commit 6891c57072
18 changed files with 683 additions and 218 deletions

View File

@ -92,6 +92,8 @@
title: FSDP vs DeepSpeed
- local: concept_guides/fsdp1_vs_fsdp2
title: FSDP1 vs FSDP2
- local: concept_guides/context_parallelism
title: Context parallelism
- local: concept_guides/low_precision_training
title: Low precision training methods
- local: concept_guides/training_tpu

View File

@ -0,0 +1,204 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Context Parallel in 🤗`accelerate`
This guide covers the basics of using context parallelism in 🤗`accelerate`. For the more curious readers, we also cover some of the technical details in the later sections.
## Why context parallelism?
With the advent of large language models, and more recently reasoning models, sequence lengths have been growing rapidly. This, combined with the quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
With a sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` in `bf16` precision, given a vanilla attention implementation. Granted, with `flash attention` or `SDPA`, which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
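You can verify the back-of-the-envelope number above with a few lines of Python (a sanity check only, not part of accelerate):
```python
# size of a single 128k x 128k attention matrix in bf16, per attention head
seq_len = 128 * 1024          # 128k tokens
bytes_per_element = 2         # bf16
attn_matrix_bytes = seq_len * seq_len * bytes_per_element
print(f"{attn_matrix_bytes / 2**30:.0f} GiB per head")  # ~32 GiB
```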
Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.
## How to use context parallelism?
```diff
from accelerate.utils import ParallelismConfig, TorchContextParallelConfig
+ cp_config = TorchContextParallelConfig(
+ cp_comm_strategy="alltoall", # no need to use cp_config at all, if you want to use the default "allgather"
+ )
+ parallelism_config = ParallelismConfig(
+ cp_size=8,
+ cp_handler=cp_config, # or just cp_size=8, if you want to use the default "allgather"
+ )
accelerator = Accelerator(
...,
parallelism_config=parallelism_config,
)
```
As with any other feature in 🤗`accelerate`, you can also enable context parallelism by passing the corresponding flags to `accelerate launch`.
In this case, it's no different:
```bash
accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...
```
> [!Tip]
> You can also set the `cp_size` and `cp_comm_strategy` in the `accelerate config` command, which will save them in your `accelerate` configuration file, so you don't have to pass them every time you launch your script.
> [!Tip]
> Context parallelism is compatible with other parallelism strategies, such as data parallelism, tensor parallelism and FSDP2.
> You can simply combine them by setting your parallelism sizes to the desired values, e.g. `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`. Or you can use the `ParallelismConfig` class to set them programmatically, as shown in the sketch below.
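For example, a minimal programmatic configuration could look like this (the sizes below are illustrative; their product has to match the number of processes you launch with):
```python
from accelerate import Accelerator
from accelerate.utils import ParallelismConfig

# illustrative 32-GPU layout: 2-way FSDP sharding x 2-way TP x 8-way CP
parallelism_config = ParallelismConfig(
    dp_shard_size=2,
    tp_size=2,
    cp_size=8,
)

accelerator = Accelerator(parallelism_config=parallelism_config)
```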
> [!Warning]
> Context parallelism is tightly coupled with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). This means context parallelism only works if you pass a `FullyShardedDataParallelPlugin` with `fsdp_version=2` to your
> program, or launch with `--use-fsdp` and the FSDP version set to 2. If FSDP2 is not used, an error will be raised. A minimal sketch of this pairing is shown below.
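A minimal sketch of pairing context parallelism with FSDP2 (the plugin arguments below are taken from the example script in this PR and may need adjusting for your model):
```python
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, ParallelismConfig

fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],  # adjust to your model's decoder layer class
)

accelerator = Accelerator(
    parallelism_config=ParallelismConfig(cp_size=8),
    fsdp_plugin=fsdp2_plugin,
)
```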
> [!Warning]
> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or a causal mask. We can't reliably detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or a causal mask. If you use any other attention implementation, it will raise an error. The snippet below shows one way to request SDPA explicitly when loading a 🤗 transformers model.
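If your model comes from 🤗 transformers, you can ask for the SDPA implementation at load time (a hedged sketch; `attn_implementation` is a transformers argument, not something added by this PR):
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-3.2-1B",  # model id used in the PR's example scripts
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",   # context parallelism requires SDPA with no mask or a causal mask
)
```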
After enabling context parallelism with the methods mentioned above, you can apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop and that abstracts away some of the complexity of using it (more on this later). To minimize the changes you have to make to your training loop, we provide a context manager that is a no-op if context parallelism is not enabled and applies it if it is. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
You can use it as follows:
```python
for batch in dataloader:
    with accelerator.maybe_context_parallel(
        buffers=[batch["input_ids"], batch["attention_mask"]],
        buffer_seq_dims=[1, 1],
        no_restore_buffers={batch["input_ids"], batch["attention_mask"]},
    ):
        outputs = model(**batch)
        ...
```
> [!Warning]
> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.
This can potentially scale your context size to 1M+ tokens. Below, we showcase the speed and memory usage of context parallelism for up to a 256k context size. We can see that when we double the context size and the number of GPUs, memory usage stays roughly constant, potentially enabling endless context length scaling.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
<br>
<em>Figure 1: Memory usage and speed of context parallelism for up to 256k context size.</em>
</p>
> [!Tip]
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:
> ```bash
> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000
> ```
## Accelerate's interface
The context manager takes a few arguments that are used to configure context parallelism:
- `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask.
- `buffer_seq_dims`: This is a list of integers that specifies the sequence dimension of each buffer, in the order of the `buffers` list. For example, if you pass `buffers=[input_ids, shift_labels]`, both with shape `[batch_size, sequence_length]`, you would pass `buffer_seq_dims=[1, 1]`, as the sequence dimension is the second dimension of the tensors. This is required for the correct computation of the model outputs.
- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.DTensor`s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually an all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument here to avoid unnecessary communication, unless you are sure that you need to use some of the buffers after the context manager exits.
> [!Warning]
> Context parallelism is not compatible with `labels` that are a copy of `input_ids`, which models from 🤗 transformers can shift to enable causal language modeling themselves.
> Imagine this case:
> labels = [l1, l2, l3, l4, ... li]
> if we apply context parallelism, each rank would end up with a part of labels, such as this:
> labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], ...
> after transformers modelling code shifts the labels, it would end up with:
> labels_rank_0 = [l2, PAD], labels_rank_1 = [l4, PAD], ...
> where `PAD` is a padding token. This would result in incorrect loss computation, as the labels are not aligned with the inputs anymore.
> Because of this, you need to manually shift the labels before passing them to the model, as shown in the sketch below.
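A minimal sketch of pre-shifting the labels yourself, e.g. in your collate function (hedged: whether the loss uses `shift_labels` depends on your model and transformers version; the example script in this PR passes both `labels` and `shift_labels` for this reason):
```python
import torch.nn.functional as F


def add_shift_labels(batch, ignore_index=-100):
    # next-token prediction: the label at position t is the token at position t + 1;
    # pad the final position so the shape still matches input_ids
    shift_labels = F.pad(batch["input_ids"][:, 1:], (0, 1), value=ignore_index)
    batch["shift_labels"] = shift_labels
    # keep "labels" too: the model only computes a loss when labels is not None,
    # but prefers shift_labels when provided
    batch["labels"] = shift_labels
    return batch
```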
## Configurable options
Apart from `cp_size`, accelerate provides only a single option to configure context parallelism:
- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it's very likely to outperform `"alltoall"` in most cases.
Context parallel size is rather self-explanatory: it's the number of ranks across which the inputs are to be sharded.
Context parallel shard rotation defines how the shards of the inputs are rotated across ranks. We'll cover the 2 options in more detail in the next section.
You can see an end-to-end example in the [ND parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py) file, where you can train an 8B model with up to 128k context length on a single 8xH100 node. Using multi-node training, you can scale this to a 1M+ sequence length on multiple GPUs. You can also seamlessly combine it with other parallelism strategies to fit your needs.
## Technical details
> [!Tip]
> This section is fairly technical, so if you don't need to learn the internals of context parallelism, you can skip it and start building 🚀
We're going to use the word `shard` extensively in the following sections, so let's define it first. If we say a tensor is `sharded` along dimension `D` across `N` ranks, we mean that the tensor is split into `N` parts, where each part keeps the full size in every other dimension and `1/N` of the original size in dimension `D`.
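As a toy illustration of sharding along the sequence dimension (not accelerate's internal implementation):
```python
import torch

N = 4                               # number of ranks
x = torch.randn(2, 1024, 64)        # [batch, seq_len, hidden]
shards = torch.chunk(x, N, dim=1)   # shard along the sequence dimension
print(shards[0].shape)              # torch.Size([2, 256, 64]) -> seq_len // N
```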
## So how does it work?
Context parallelism works by sharding the `Q`, `K` and `V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays on this rank during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shards `Q_i`, `K_i` and `V_i`, let's call the result `attn_i`. During this computation, a communication kernel is launched to gather the `K`s and `V`s from all other ranks. Which communication primitive is used depends on the `cp_comm_strategy` option (the shard rotation method).
This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. As each rank holds `Q, K and V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.
We can formalize this in the following pseudocode:
```python
# pseudocode -- each rank i holds only its own shards Qi, Ki, Vi
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[cp_comm_strategy]
Qi, Ki, Vi = shard(Q, K, V, seq_dim)

attn_chunks = {}
attn_chunks[i] = attn(Qi, Ki, Vi)      # local attention first
for j in other_ranks:                  # the remaining cp_size - 1 ranks
    Kj, Vj = comm_kernel()             # fetch the K/V shard originating from rank j
    attn_chunks[j] = attn(Qi, Kj, Vj)  # [batch, num_heads, seq_len // cp_size, head_dim]
final_attn = combine(attn_chunks)
```
## all-to-all vs all-gather
### all-gather
So what's the difference between all-to-all and all-gather? With all-gather, the communication is very simple. Around the time we compute the local attention `attn_i` (the all-gather is usually launched beforehand, as it takes longer), we launch an all-gather to collect the `K`s and `V`s from all other ranks. Once this communication is done, each rank has all the `K`s and `V`s from all other ranks and can compute attention with them sequentially.
In an ideal scenario, the all-gather finishes at the exact moment the calculation of `attn_i` is done. However, this never happens in practice, so the best realistic overlap is achieved when the full `attn_i` computation is overlapped with a part of the communication; then, to start the computation with `K_j` and `V_j`, we wait for the all-gather to finish.
### all-to-all
All-to-all, sometimes called `ring rotation`, utilizes a ring-like communication pattern. After concluding the `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `cp_size - 1` times, so that each rank sees each shard of `K` and `V` exactly once. In the ideal scenario, we prefetch the shards `K_i+1` and `V_i+1` from the neighbouring rank, and this communication is exactly overlapped with the computation of our current `attn_i`. Realistically, this perfect overlap never quite happens. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is much larger than with all-gather.
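A toy sketch of the rotation order (illustrative only; it assumes shards rotate from rank `r` to rank `r + 1` each step, which may differ from PyTorch's internal implementation):
```python
cp_size = 4

for step in range(cp_size - 1):
    for rank in range(cp_size):
        # the K/V shard this rank works on at this step originally belonged to `src`
        src = (rank - 1 - step) % cp_size
        print(f"step {step}: rank {rank} computes attn(Q_{rank}, K_{src}, V_{src})")
```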
## How to choose the right rotation method?
In theory, all-to-all should be the better choice, though in practice it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. We still provide both options, though, as you might find one better for your use case.
You can directly see this issue in the profiler output in the image below:
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_all_to_all.png" alt="all-to-all profiler output" />
<br>
<em>Figure 2: In red you can see the idle time while we wait for the all-to-all kernel to finish. Highlighted by the first blue bar, you can see that it takes ~250us to finish, and this is repeated N-1 times for each attention call, where N is the context parallel size.</em>
</p>
## Why only FSDP2?
We only support context parallelism with `FSDP2`, as we create a joint mesh of `cp_size` and `dp_shard_size` to
utilize its full potential.
It works like this: we shard the model across the joint mesh of size `cp_size * dp_shard_size`, which maximizes the memory savings.
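A rough sketch of the kind of joint mesh this refers to, using PyTorch's public `init_device_mesh` (accelerate builds this for you via `ParallelismConfig.build_device_mesh`, so this is only for illustration and must run inside a distributed launch):
```python
from torch.distributed.device_mesh import init_device_mesh

# assumed 8-GPU layout: 4-way FSDP sharding x 2-way context parallelism
dp_shard_size, cp_size = 4, 2
mesh = init_device_mesh("cuda", (dp_shard_size, cp_size), mesh_dim_names=("dp_shard", "cp"))
# the model is sharded across all dp_shard_size * cp_size ranks,
# while attention inputs are sharded only along the "cp" dimension
```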
This is a "free lunch" of sorts, as `FSDP` communication is fully overlapped with the computation of attention, as shown in the images below.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png" alt="why FSDP2+CP" />
<br>
<em>Figure 3: In the blue rectangles (Stream 23), you can see that the pre-fetch of the `FSDP` shard is fully overlapped with the computation of attention (Stream 7), while in the red rectangles (Stream 24), you can see that the all-gather kernel results in a bubble of idle time, during which our compute stream (7) is idle.</em>
</p>
In the figures above, you can also note the difference between all-to-all and all-gather. While in all-to-all (Figure 2) we launch a communication kernel N-1 times for each attention call, in all-gather (Figure 3) we launch a communication kernel only once. This results in a bigger bubble, but it only happens once per attention call, whereas in all-to-all it happens N-1 times.
## Data dispatching in joint mesh
We make sure to dispatch the same batch of data to each rank in the `cp` subgroup, so that the results are correct, while dispatching a different batch to each rank of the `dp_shard` group.
Imagine it like this:
```
# 8 GPUS, --dp_shard_size 4, --cp_size 2
# mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]
# model is sharded across the whole mesh (each GPU holds 1/8 of the model)
# GPUs 0,1 = batch 0
# GPUs 2,3 = batch 1
... and so on.
```
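A toy sketch of this dispatching rule (not accelerate's internal dispatch logic, just the mapping illustrated above):
```python
# which batch index each global rank sees with --dp_shard_size 4 --cp_size 2
dp_shard_size, cp_size = 4, 2

for rank in range(dp_shard_size * cp_size):
    dp_index = rank // cp_size  # ranks in the same cp subgroup share a batch
    print(f"rank {rank} -> batch {dp_index}")
# rank 0 -> batch 0, rank 1 -> batch 0, rank 2 -> batch 1, rank 3 -> batch 1, ...
```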
```

View File

@ -43,6 +43,7 @@ def parse_args():
parser.add_argument("--dp-replicate-size", type=int, default=1)
parser.add_argument("--dp-shard-size", type=int, default=1)
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--cp-size", type=int, default=1)
parser.add_argument("--sequence-length", type=int, default=1024)
parser.add_argument("--num-steps", type=int, default=1000)
parser.add_argument("--save-dir", type=str, default="./outputs")
@ -52,17 +53,28 @@ def parse_args():
return parser.parse_args()
def forward(model, batch, optimizer, accelerator):
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
loss_reduce_grp = (
accelerator.torch_device_mesh["dp_cp"].get_group() if accelerator.parallelism_config.dp_cp_dim_names else None
)
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
def forward(model, batch, optimizer, accelerator: Accelerator):
# We need both labels and shift_labels, as the loss computation in the model is hidden behind `if labels is not None`, but the loss computation
# itself prioritizes shift_labels (if provided), which are the correct ones (as `labels` are wrong when CP is enabled)
buffers = [batch["input_ids"], batch["shift_labels"], batch["labels"]]
with accelerator.maybe_context_parallel(
buffers=buffers, buffer_seq_dims=[1, 1, 1], no_restore_buffers=set(buffers)
):
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
# As for DP we have a different batch on each device and for CP we essentially have a different part of sequences on each device
# I.e. with causal modelling and seq_len 1024, this dimension becomes another batch dimension of sorts
loss_reduce_grp = (
accelerator.torch_device_mesh["dp_cp"].get_group()
if accelerator.parallelism_config.dp_cp_dim_names
else None
)
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
return loss
@ -71,21 +83,21 @@ def train(args):
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
cp_size=args.cp_size,
)
# FSDP needs extra configuration, so we properly shard the model
if parallelism_config.dp_shard_enabled:
fsdp2_plugin = None
if parallelism_config.dp_shard_enabled or parallelism_config.cp_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
state_dict_type="SHARDED_STATE_DICT",
)
accelerator = Accelerator(
log_with=["wandb"],
mixed_precision="bf16",
parallelism_config=parallelism_config,
fsdp_plugin=fsdp2_plugin if parallelism_config.dp_shard_enabled else None,
log_with=["wandb"], mixed_precision="bf16", parallelism_config=parallelism_config, fsdp_plugin=fsdp2_plugin
)
accelerator.init_trackers("nd_parallel_training")
@ -146,7 +158,7 @@ def train(args):
if __name__ == "__main__":
set_seed(42)
args = parse_args()
if args.dp_shard_size == 1:
if args.dp_shard_size == 1 and args.tp_size > 1:
# We currently don't support saving with `save_state` when using only
# tensor parallelism, fsdp must be enabled
warnings.warn(

View File

@ -1,187 +0,0 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""
import argparse
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.state import PartialState
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from accelerate.utils.fsdp_utils import save_fsdp_optimizer
from utils import PerformanceTracker, create_collate_fn, get_dataset, gpu_memory_usage_all, setup_tokenizer
MODEL_ID = "NousResearch/Llama-3.2-1B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)
parser.add_argument("--fsdp2-cls-name-to-wrap", type=str, default="LlamaDecoderLayer")
parser.add_argument("--dp-replicate-size", type=int, default=1)
parser.add_argument("--dp-shard-size", type=int, default=1)
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--sequence-length", type=int, default=128)
parser.add_argument("--model-save-dir", type=str, default="./outputs")
parser.add_argument(
"--save-model", action="store_true", default=False, help="Whether to save the model after training."
)
parser.add_argument(
"--save-optimizer",
action="store_true",
default=False,
help="Whether to save the optimizer state after training.",
)
return parser.parse_args()
def main():
"""
Main function to train the model.
"""
args = parse_args()
set_seed(42)
if args.model:
model_id = args.model
else:
model_id = MODEL_ID
model_kwargs = {}
accelerator_kwargs = {}
parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
)
device_mesh = parallelism_config.build_device_mesh("cuda")
if args.tp_size > 1:
model_kwargs["tp_size"] = args.tp_size
model_kwargs["tp_plan"] = "auto"
model_kwargs["device_mesh"] = device_mesh
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_cache=False,
**model_kwargs,
)
PartialState(device_mesh=device_mesh, parallelism_config=parallelism_config)
if parallelism_config.dp_shard_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
cpu_ram_efficient_loading=False,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=[args.fsdp2_cls_name_to_wrap],
reshard_after_forward=True,
activation_checkpointing=True,
state_dict_type="FULL_STATE_DICT",
)
accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin
accelerator = Accelerator(
mixed_precision="no",
**accelerator_kwargs,
)
accelerator.print("Memory usage after model load")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)
tokenizer = setup_tokenizer(model_id)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)
accelerator.print("Memory usage after model prepare")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)
model.train()
total_num_steps = min(100, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=10)
accelerator.print("Starting training...")
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG)
batch_tokens = batch["input_ids"].shape[1]
metrics = performance_tracker.step(batch_tokens)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
log_metrics = {"loss": loss.item()}
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
print_msg += (
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
f"alloc={metrics['peak_memory_alloc']:.1f}, "
f"reserved={metrics['peak_memory_reserved']:.1f}"
)
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
accelerator.log(log_metrics)
accelerator.wait_for_everyone()
accelerator.end_training()
accelerator.print("Training completed!")
if parallelism_config.dp_shard_enabled and args.save_optimizer:
accelerator.print("Saving optimizer state...")
save_fsdp_optimizer(
fsdp2_plugin,
accelerator,
optimizer,
model,
args.model_save_dir + "/opt",
)
accelerator.print("Optimizer state saved.")
accelerator.print("Saving model state...")
if args.save_model:
model.save_pretrained(args.model_save_dir)
accelerator.print(f"Model saved to {args.model_save_dir}")
if __name__ == "__main__":
main()

View File

@ -69,7 +69,7 @@ def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int
packed_input_ids.append(full_sequence[:-1])
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
return {"input_ids": packed_input_ids, "shift_labels": packed_labels}
with accelerator.main_process_first():
packed_dataset = tokenized_dataset.map(
@ -111,8 +111,8 @@ def create_collate_fn():
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
shift_labels = torch.tensor([item["shift_labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "shift_labels": shift_labels, "labels": shift_labels}
return collate_fn

View File

@ -35,6 +35,7 @@ from huggingface_hub import split_torch_state_dict_into_shards
from accelerate.utils.dataclasses import FP8BackendType
from .big_modeling import _attach_context_parallel_hooks
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .logging import get_logger
@ -780,6 +781,10 @@ class Accelerator:
):
if parallelism_config is None:
if PartialState._shared_state != {} and PartialState().parallelism_config is not None:
if os.environ.get("ACCELERATE_USE_PARALLELISM_CONFIG", "false") == "true":
raise ValueError(
"Partial state contains a `parallelism_config` which is not None, but you configured `parallelism_config` from the `accelerate launch` CLI. We don't know which to use, please remove one of those configuration methods."
)
parallelism_config = PartialState().parallelism_config
else:
# TODO: Remove after deprecating tp_plugin
@ -1551,6 +1556,9 @@ class Accelerator:
if self.parallelism_config and self.parallelism_config.tp_enabled:
args = self._prepare_tp(*args)
if self.parallelism_config and self.parallelism_config.cp_enabled:
args = self._prepare_cp(*args)
if self.fp8_backend == FP8BackendType.TE:
args = self._prepare_te(*args)
elif self.fp8_backend == FP8BackendType.AO:
@ -1623,6 +1631,21 @@ class Accelerator:
return args
def _prepare_cp(self, *args):
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
return args
def _prepare_fsdp2(self, *args):
# First pass: prepare everything except schedulers (and model, which is prepared separately below)
result = [
@ -3976,6 +3999,69 @@ class Accelerator:
raise ValueError(err)
self._custom_objects.extend(objects)
@contextmanager
def maybe_context_parallel(
self,
buffers: list[torch.Tensor] | None = None,
buffer_seq_dims: list[int] | None = None,
no_restore_buffers: set[torch.Tensor] | None = None,
):
"""
A context manager that enables context parallel training.
Args:
buffers (`list[torch.Tensor]`, `optional`):
Buffers, which are going to be sharded along the sequence dimension. Common examples are inputs, labels
or positional embedding buffers. This context manager will modify these buffers in-place, and after
exiting the context, the buffers will be restored to their original state. To avoid unnecessary
restores, you can use `no_restore_buffers` to specify which buffers don't need to be restored.
buffer_seq_dims (`list[int]`, `optional`):
Sequence dimensions of `buffers`.
no_restore_buffers (`set[torch.Tensor]`, `optional`):
This set must be a subset of `buffers`. Specifies which buffers from `buffers` argument won't be
restored after the context exits. These buffers will be then kept in sharded state.
<Tip warning={true}>
`context_parallel` is currently only supported together with FSDP2, and requires `parallelism_config.cp_size` >
1. If either of these conditions is not met, this context manager will have no effect, though to enable fewer
code changes it will not raise an Exception.
</Tip>
<Tip warning={true}>
This context manager has to be recreated with each training step, as shown in the example below.
</Tip>
Example:
```python
>>> for batch in dataloader:
... with accelerator.maybe_context_parallel(
... buffers=[batch["input_ids"], batch["attention_mask"]],
... buffer_seq_dims=[1, 1],
... no_restore_buffers={batch["input_ids"]},
... ):
... outputs = model(**batch)
... ...
```
"""
# We don't need to check FSDP2 as parallelism_config does that for us
# Invariant: in this branch self._cp_context is set, as it was set by `self._prepare_cp`
if self.parallelism_config and self.parallelism_config.cp_enabled:
with self._cp_context(
buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers
):
yield
else:
logger.warning_once(
"Context parallel training is not enabled. This context manager will have no effect. "
"To enable it, set `parallelism_config.cp_size` > 1 in the `Accelerator` constructor."
)
yield
@contextmanager
def autocast(self, autocast_handler: AutocastKwargs = None):
"""

View File

@ -747,3 +747,43 @@ def _attach_layerwise_casting_hooks(
non_blocking,
_prefix=layer_name,
)
def _attach_context_parallel_hooks(
model: nn.Module,
):
"""
Monkeypatch huggingface's `transformers` model to fix attention mask issues when using context parallelism.
This function attaches forward pre-hooks to each `self_attn` module of the model. Each hook checks the module kwargs
for an attention mask; if one is present, the hook removes it and sets `is_causal=True` instead, as context
parallelism does not support attention masks. This function modifies the model in place.
Args:
model (`nn.Module`):
The model to attach the hooks to.
"""
def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
if "attention_mask" in module_kwargs:
module_kwargs["attention_mask"] = None
module_kwargs["is_causal"] = True
return module_args, module_kwargs
for name, module in model.named_modules():
# We hope (assume) that if a user uses their own model (without the structure that transformers uses), they have read the docs saying they can't pass in attention masks
# Then these cases can happen:
# 1) some modules end with a `self_attn` module, in which case we attach the hook, but
#    there's no attention mask kwarg -> hook is a no-op
# 2) some modules end with a `self-attn` module, in which case we attach the hook, and the
# attention mask kwarg is passed -> hook will remove the attention mask and add
# `is_causal=True` kwarg, which either crashes the training or fixes it
# (training would crash anyway as attention mask isn't supported)
# 3) no modules end with a `self-attn` module, in which case we don't attach the hook, this is
# a no-op as well
if name.endswith("self_attn"):
# we want the hook to be executed first, to avoid any other hooks doing work on the attention mask
module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True)

View File

@ -505,6 +505,53 @@ def get_cluster_input():
error_message="Please enter yes or no.",
)
parallelism_config = {}
if fsdp_config.get("fsdp_version", 1) == 2:
use_parallelism_config = _ask_field(
"Do you want to use the parallelism config? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if use_parallelism_config:
prefix = "parallelism_config_"
parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
"What is the data parallelism replicate size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
parallelism_config[prefix + "dp_shard_size"] = _ask_field(
"What is the FSDP shard size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
parallelism_config[prefix + "tp_size"] = _ask_field(
"What is the tensor parallelism size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
parallelism_config[prefix + "cp_size"] = _ask_field(
"What is the context parallelism size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
if parallelism_config[prefix + "cp_size"] > 1:
parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
"What is the compute parallelism communication strategy?",
["allgather", "alltoall"],
lambda x: ["allgather", "alltoall"][int(x)],
default=0,
)
megatron_lm_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
use_megatron_lm = _ask_field(
@ -849,6 +896,7 @@ def get_cluster_input():
fp8_config=fp8_config,
deepspeed_config=deepspeed_config,
fsdp_config=fsdp_config,
parallelism_config=parallelism_config,
megatron_lm_config=megatron_lm_config,
ipex_config=ipex_config,
mpirun_config=mpirun_config,

View File

@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
deepspeed_config: dict = None
# args for fsdp
fsdp_config: dict = None
# args for parallelism config
parallelism_config: dict = None
# args for megatron_lm
megatron_lm_config: dict = None
# args for ipex
@ -229,6 +231,8 @@ class ClusterConfig(BaseConfig):
self.mpirun_config = {}
if self.fp8_config is None:
self.fp8_config = {}
if self.parallelism_config is None:
self.parallelism_config = {}
return super().__post_init__()

View File

@ -269,6 +269,12 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_parallelism_config",
default=False,
action="store_true",
help="Whether to use the parallelism config to configure the N-d distributed training.",
)
paradigm_args.add_argument(
"--use_megatron_lm",
default=False,
@ -767,6 +773,45 @@ def launch_command_parser(subparsers=None):
help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.",
)
# ParallelismConfig arguments
parallelism_config_args = parser.add_argument_group(
"ParallelismConfig Arguments",
"Arguments related to the ParallelismConfig used for distributed training.",
)
parallelism_config_args.add_argument(
"--parallelism_config_dp_replicate_size",
type=int,
default=1,
help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
)
parallelism_config_args.add_argument(
"--parallelism_config_dp_shard_size",
type=int,
default=1,
help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
)
parallelism_config_args.add_argument(
"--parallelism_config_tp_size",
type=int,
default=1,
help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
)
parallelism_config_args.add_argument(
"--parallelism_config_cp_size",
type=int,
default=1,
help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
)
parallelism_config_args.add_argument(
"--parallelism_config_cp_comm_strategy",
type=str,
default="allgather",
help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
)
# Other arguments of the training scripts
parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
@ -994,6 +1039,9 @@ def _validate_launch_command(args):
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")
defaults = None
warned = []
mp_from_config_flag = False
@ -1027,6 +1075,7 @@ def _validate_launch_command(args):
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
args.use_parallelism_config = defaults.parallelism_config != {}
if args.gpu_ids is None:
if defaults.gpu_ids is not None:
args.gpu_ids = defaults.gpu_ids

View File

@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
from torch.distributed.device_mesh import init_device_mesh
from accelerate.utils.dataclasses import TorchTensorParallelConfig
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
if TYPE_CHECKING:
@ -56,13 +57,14 @@ class ParallelismConfig:
"""
dp_replicate_size: int = 1
dp_shard_size: int = 1
tp_size: int = 1
cp_size: int = 1
dp_replicate_size: int = None
dp_shard_size: int = None
tp_size: int = None
cp_size: int = None
# we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
tp_handler: Union[None, TorchTensorParallelConfig] = None
cp_handler: Union[None, TorchContextParallelConfig] = None
def __repr__(self):
return (
@ -71,7 +73,9 @@ class ParallelismConfig:
f"\tdp_shard_size={self.dp_shard_size},\n"
f"\ttp_size={self.tp_size},\n"
f"\tcp_size={self.cp_size},\n"
f"\ttotal_size={self.total_size}\n)"
f"\ttotal_size={self.total_size}\n"
f"\ttp_handler={self.tp_handler},\n"
f"\tcp_handler={self.cp_handler})\n"
)
@property
@ -206,6 +210,23 @@ class ParallelismConfig:
def __post_init__(self):
# Basic size validation
if self.dp_replicate_size is None:
self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
if self.dp_shard_size is None:
self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
if self.tp_size is None:
self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
if self.cp_size is None:
self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
if self.tp_size > 1:
if self.tp_handler is None:
self.tp_handler = TorchTensorParallelConfig()
if self.cp_size > 1:
if self.cp_handler is None:
self.cp_handler = TorchContextParallelConfig()
if self.dp_replicate_size < 1:
raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
if self.dp_shard_size < 1:

View File

@ -981,11 +981,27 @@ class AcceleratorState:
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
]:
if os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None:
self.distributed_type = DistributedType.FSDP
if self._mixed_precision != "no":
fsdp_plugin.set_mixed_precision(self._mixed_precision)
self.fsdp_plugin = fsdp_plugin
# TODO: Siro - remove when axolotl fixes their side
if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true":
if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
raise ValueError(
"`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism, as we also shard the model across the device mesh to save more memory"
)
if (
self.parallelism_config is not None
and self.parallelism_config.cp_enabled
and fsdp_plugin.fsdp_version == 1
):
raise ValueError(
"Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
)
if (
os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None
) or (self.parallelism_config is not None and self.parallelism_config.cp_enabled):
self.distributed_type = DistributedType.FSDP
if self._mixed_precision != "no":
fsdp_plugin.set_mixed_precision(self._mixed_precision)
self.fsdp_plugin = fsdp_plugin
if os.environ.get(
"ACCELERATE_USE_MEGATRON_LM", "false"
).lower() == "true" and self.distributed_type not in [

View File

@ -61,6 +61,7 @@ from .dataclasses import (
SageMakerDistributedType,
TensorInformation,
TERecipeKwargs,
TorchContextParallelConfig,
TorchDynamoPlugin,
TorchTensorParallelConfig,
TorchTensorParallelPlugin,

View File

@ -51,7 +51,9 @@ ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

View File

@ -32,6 +32,7 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_a
import torch
from .constants import (
BETA_CP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
FSDP2_PYTORCH_VERSION,
@ -2146,6 +2147,33 @@ class TorchTensorParallelPlugin:
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
@dataclass
class TorchContextParallelConfig:
"""
This class holds the configuration for context parallelism in PyTorch.
"""
cp_comm_strategy: Optional[str] = field(
default=None,
metadata={
"help": "Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'."
},
)
def __post_init__(self):
if not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(
f"Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. "
"Please upgrade your PyTorch version."
)
if self.cp_comm_strategy is None:
self.cp_comm_strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather")
if self.cp_comm_strategy not in ["allgather", "alltoall"]:
raise ValueError(
f"Invalid cp_comm_strategy: {self.cp_comm_strategy}. Must be one of 'allgather' or 'alltoall'."
)
@dataclass
class TorchTensorParallelConfig:
"""

View File

@ -349,6 +349,20 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
if args.enable_cpu_affinity:
current_env["ACCELERATE_CPU_AFFINITY"] = "1"
if not args.use_parallelism_config:
return current_env
prefix = "PARALLELISM_CONFIG_"
if args.use_parallelism_config:
current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size)
current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size)
current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size)
current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size)
if args.parallelism_config_cp_size > 1:
current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy)
return current_env

View File

@ -17,6 +17,36 @@ from unittest.mock import Mock, patch
import pytest
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import patch_environment
from accelerate.utils.constants import (
BETA_CP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
)
from accelerate.utils.imports import is_transformers_available
from accelerate.utils.versions import compare_versions, is_torch_version
def _should_skip_cp_test(cp_size):
"""Check if CP test should be skipped based on cp_size and torch version."""
return cp_size > 1 and not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION)
def _should_skip_tp_test(tp_size):
"""Check if TP test should be skipped based on tp_size, torch version, and transformers availability."""
if tp_size <= 1:
return False
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
return True
if not is_transformers_available():
return True
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
return True
return False
class TestParallelismConfig:
@ -73,6 +103,14 @@ class TestParallelismConfig:
expected_shape,
expected_dim_names,
):
# Skip tests based on version requirements
if _should_skip_cp_test(cp_size):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
if _should_skip_tp_test(tp_size):
pytest.skip(
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
)
config = ParallelismConfig(
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
)
@ -105,6 +143,14 @@ class TestParallelismConfig:
expected_dim_names,
):
"""Test build_device_mesh creates correct mesh and applies flattening."""
# Skip tests based on version requirements
if _should_skip_cp_test(cp_size):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
if _should_skip_tp_test(tp_size):
pytest.skip(
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
)
config = ParallelismConfig(
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
)
@ -124,3 +170,81 @@ class TestParallelismConfig:
expected_flattened.append((config.dp_cp_dim_names, "dp_cp"))
assert device_mesh.flattened_dims == expected_flattened
@pytest.mark.parametrize(
"dp_replicate_size, dp_shard_size, tp_size, cp_size",
[
(8, 1, 1, 1),
(1, 8, 1, 1),
(2, 4, 1, 1),
(1, 4, 2, 1),
(2, 2, 2, 1),
(1, 1, 8, 1),
(1, 1, 1, 4),
(1, 4, 1, 2),
(1, 2, 2, 2),
(2, 2, 2, 2),
],
)
def test_from_env(
self,
dp_replicate_size,
dp_shard_size,
tp_size,
cp_size,
):
if _should_skip_cp_test(cp_size):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
if _should_skip_tp_test(tp_size):
pytest.skip(
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
)
new_env = {
"PARALLELISM_CONFIG_DP_REPLICATE_SIZE": dp_replicate_size,
"PARALLELISM_CONFIG_DP_SHARD_SIZE": dp_shard_size,
"PARALLELISM_CONFIG_TP_SIZE": tp_size,
"PARALLELISM_CONFIG_CP_SIZE": cp_size,
}
with patch_environment(**new_env):
config = ParallelismConfig()
for key, value in new_env.items():
assert getattr(config, key.split("PARALLELISM_CONFIG_")[-1].lower()) == value
def test_cp_handler(self):
"""Test CP handler with various configurations."""
# Any cp_size > 1 requires torch >= BETA_CP_AVAILABLE_PYTORCH_VERSION, we use placeholder for this check as this test doesn't depend on a specific size
if _should_skip_cp_test(2):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
from accelerate.utils import TorchContextParallelConfig
for setting in ("allgather", "alltoall"):
cp_handler = TorchContextParallelConfig(cp_comm_strategy=setting)
pc = ParallelismConfig(cp_size=2, cp_handler=cp_handler)
assert pc.cp_handler is not None, "CP handler should be set"
assert pc.cp_handler.cp_comm_strategy == setting, (
f"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}"
)
for setting in ("allgather", "alltoall"):
with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):
pc = ParallelismConfig(cp_size=2)
assert pc.cp_handler is not None, "CP handler should be set from environment"
assert pc.cp_handler.cp_comm_strategy == setting, (
f"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}"
)
for setting in ("invalid", "unsupported"):
with pytest.raises(ValueError, match=f"Invalid cp_comm_strategy: {setting}"):
TorchContextParallelConfig(cp_comm_strategy=setting)
with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):
with pytest.raises(ValueError, match=f"Invalid cp_comm_strategy: {setting}"):
pc = ParallelismConfig(cp_size=2)
def test_tp_handler(self):
assert True, "Tensor parallelism handler doesn't hold any logic yet"

View File

@ -64,6 +64,7 @@ class TestPrepareMultiGpuEnv(unittest.TestCase):
num_cpu_threads_per_process=1,
enable_cpu_affinity=False,
same_network=False,
use_parallelism_config=False,
)
prepare_multi_gpu_env(args)