Compare commits

..

3 Commits

Author SHA1 Message Date
a3f8d23402 Merge updated examples 2025-08-23 15:42:00 +00:00
8ecadce10a Feat: cleanup 2025-08-23 15:34:55 +00:00
91985ab9d7 Feat: first version 2025-08-23 15:03:28 +00:00
39 changed files with 188 additions and 488 deletions

View File

@ -19,7 +19,7 @@ This guide will cover basics of using context parallelism in 🤗`accelerate`, f
## Why context parallelism?
With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
With sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given vanilla attention implementation. Granted, with usage of `flash attention` or `SDPA` which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.
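To make the arithmetic above concrete, here is a quick back-of-the-envelope check of the formula from this section; the head count is a placeholder, everything else is taken from the text:
```python
# Memory of the materialized attention matrix in bf16, per the formula above.
seq_len = 128 * 1024   # 128k tokens
elem_bytes = 2         # bf16
num_heads = 32         # placeholder; depends on the model

per_head_gib = seq_len * seq_len * elem_bytes / 1024**3
print(f"~{per_head_gib:.0f} GiB per head, ~{per_head_gib * num_heads:.0f} GiB for {num_heads} heads")
# ~32 GiB per head, times the number of heads for the full attention matrix
```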
@ -44,7 +44,7 @@ accelerator = Accelerator(
)
```
As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.
As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.
In this case, it's no different:
```bash
@ -65,7 +65,7 @@ accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-st
> [!Warning]
> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.
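As a quick illustration of the supported call pattern, the snippet below uses plain `SDPA` with a causal mask and no explicit attention mask; the tensor shapes are arbitrary placeholders:
```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) -- placeholder shapes
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Context parallelism supports SDPA only with no mask or a causal mask:
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```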
After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
You can use it as follows:
```python
@ -82,7 +82,7 @@ for batch in dataloader:
> [!Warning]
> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.
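For reference, a minimal sketch of such a loop is shown below; the batch keys (`input_ids`, `shift_labels`) and their sequence dimensions are assumptions, and `accelerator`, `model`, `optimizer`, and `dataloader` are assumed to have been prepared beforehand:
```python
for batch in dataloader:
    # Re-create the context manager on every step, as the warning above requires.
    with accelerator.maybe_context_parallel(
        buffers=[batch["input_ids"], batch["shift_labels"]],  # assumed batch keys
        buffer_seq_dims=[1, 1],                               # sequence dim of each buffer
        no_restore_buffers={batch["input_ids"]},
    ):
        outputs = model(**batch)  # assumes a transformers-style model returning .loss
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```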
This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentially enabling endless context length scaling.
This can potentially scale your context size to 1M+ sequence length. Below, we showcase the speed and memory usage of context parallelism for up to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentially enabling endless context length scaling.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
@ -120,7 +120,7 @@ The context manager takes a few arguments, that are used to configure the contex
## Configurable options
Accelerate provides only a single option to configure context parallelism (except for `cp_size`)
Accelerate provides only a single option to configure context parallelism (except for `cp_size`)
- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it's very likely it will outperform `"alltoall"` in most cases.
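For instance, a programmatic configuration might look like the sketch below; the exact keyword names (`cp_size`, `cp_handler`, `cp_comm_strategy`) are assumptions mirroring the launch flags shown earlier, so check them against your `accelerate` version:
```python
from accelerate import Accelerator
from accelerate.utils import ParallelismConfig, TorchContextParallelConfig

# Mirrors `--parallelism-config-cp-size` and `--parallelism-config-cp-comm-strategy`.
pc = ParallelismConfig(
    cp_size=8,
    cp_handler=TorchContextParallelConfig(cp_comm_strategy="allgather"),
)
accelerator = Accelerator(parallelism_config=pc)
```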
@ -142,7 +142,7 @@ We're going to be using word `shard` extensively in the following sections, so l
Context parallelism works by sharding the `Q, K and V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays only on this rank during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shards `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `Ks` and `Vs` from all other ranks. Which communication primitive is used depends on the `context_parallel_shard_rotation` option.
This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. As each rank holds `Q, K and V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.
We can formalize this in the following pseudocode:
We can formalize this in the following pseudocode:
```python
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]
Qi, Ki, Vi = shard(Q, K, V, seq_dim)
@ -164,7 +164,7 @@ In ideal scenario, all-gather finishes in the exact moment as the calculation of
All-to-all, sometimes called `ring-rotation`, utilizes a ring-like communication pattern. After concluding the `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `context_parallel_size-1` times, so that each rank sees all the shards of `K` and `V` from all other ranks once. In an ideal scenario, we prefetch shards `K_i+1` and `V_i+1` from the neighbouring rank, and this communication is exactly overlapped with computation of our current `attn_i`. Realistically, this perfect overlap never happens. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is far larger than with all-gather.
## How to choose the right rotation method?
In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
You can directly see this issue in the profiler output in the image below:
<p align="center">

View File

@ -8,7 +8,7 @@ deepspeed_config:
# `transformers` uses the right `init` function
zero3_init_flag: false # true
# Finally we need to specify the number of accelerators to use
# Finally we need to specify the number of GPUs to use
num_processes: 2
# Optionally we can set the mixed precision now instead of in the deepspeed config file,
# however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file

View File

@ -1,8 +1,8 @@
# Since we are doing FSDP (even though it's multi-accelerator), we need to specify the distributed type as FSDP
# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
distributed_type: FSDP
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)
mixed_precision: 'bf16'
# Specify the number of accelerators to use
# Specify the number of GPUs to use
num_processes: 2
# Then we can specify the FSDP config
fsdp_config:

View File

@ -1,6 +0,0 @@
# Specify distributed_type as `MULTI_XPU` for DDP
distributed_type: "MULTI_XPU"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
# Specify the number of XPUs to use
num_processes: 2

View File

@ -1,4 +1,4 @@
# Since this is single GPU/XPU, we don't need distributed training
# Since this is single GPU, we don't need distributed training
distributed_type: "NO"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
mixed_precision: "bf16"

View File

@ -177,7 +177,6 @@ def training_function(config, args):
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
print(f"===== {predictions}")
metric.add_batch(
predictions=predictions,
references=references,

View File

@ -1,6 +1,6 @@
## Torch Native Parallelism
With recent versions of Torch, there have been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`
With recent versions of Torch, there have been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`
This folder contains various examples of such use cases, such as composing multiple parallelism strategies, low-bit training, etc.
### ND Parallelism
@ -51,7 +51,7 @@ gaining even more speed and memory savings, as `ao` doesn't ship with any kernel
Replacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. This highly depends on the model dimensions and sequence length used for training.
You can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance).
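If you want to experiment with this replacement outside of these example scripts, a minimal `torchao` sketch might look like the following; the toy model, the CUDA requirement, and the layer filter are assumptions, not necessarily what the examples here do:
```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Toy stand-in for a transformer block; dims are multiples of 16, which FP8 kernels expect.
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# Swap eligible nn.Linear layers for Float8Linear; a filter lets you skip layers
# (e.g. an LM head) whose shapes don't benefit, as the performance doc above suggests.
convert_to_float8_training(model, module_filter_fn=lambda module, fqn: "lm_head" not in fqn)
```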
In our example, we use a 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while loss function stays roughly the same. We can also see that the FLOPS rise by using FP8.
In our example, we use an 8B Llama3.1 model, which has a hidden dimension of 4096, and we train on a sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while the loss stays roughly the same. We can also see that the FLOPS rise by using FP8.
<div style="display: flex; gap: 25px;">
<div style="text-align: center; width: 49%;">

View File

@ -31,7 +31,6 @@ from utils import (
PerformanceTracker,
create_collate_fn,
get_dataset,
get_model_flops_per_token,
setup_tokenizer,
)
@ -94,7 +93,7 @@ def train(args):
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["Qwen3DecoderLayer"],
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
state_dict_type="SHARDED_STATE_DICT",
)
@ -123,7 +122,6 @@ def train(args):
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
total_num_steps = min(args.num_steps, len(dataloader))
model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
performance_tracker = PerformanceTracker(warmup_steps=5)
accelerator.print("Starting training...")
@ -134,10 +132,7 @@ def train(args):
loss = forward(model, batch, optimizer, accelerator)
# We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
metrics = performance_tracker.step(
batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size,
model_flops_per_token=model_flops_per_token,
)
metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
if "warmup_completed" in metrics:

View File

@ -13,7 +13,6 @@
# limitations under the License.
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from accelerate.utils import ParallelismConfig
@ -29,7 +28,7 @@ def parse_args():
parser.add_argument("--checkpoint-frequency", type=int, default=100)
parser.add_argument("--model-name", type=str, default=MODEL_ID)
parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
parser.add_argument("--device-type", type=str, default="auto")
parser.add_argument("--device-type", type=str, default="cuda")
return parser.parse_args()
@ -39,9 +38,6 @@ def main():
pc = ParallelismConfig()
args = parse_args()
if args.device_type == "auto":
args.device_type = torch.accelerator.current_accelerator().type
model_kwargs = {}
if pc.tp_enabled:
model_kwargs["tp_plan"] = "auto"
@ -70,7 +66,7 @@ def main():
trainer = Trainer(
model=model,
args=training_args,
processing_class=tokenizer,
tokenizer=tokenizer,
train_dataset=packed_dataset,
)

View File

@ -183,11 +183,7 @@ class PerformanceTracker:
return {}
def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}"
if "tflops_per_device" in metrics:
print_msg += f" | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
else:
print_msg += "\n"
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f} | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
if with_memory:
print_msg += (
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
@ -206,18 +202,16 @@ def setup_tokenizer(model_id: str) -> AutoTokenizer:
def gpu_memory_usage_all(device=0):
device_type = torch.accelerator.current_accelerator().type
device = torch.device(f"{device_type}:{device}")
torch_device_module = getattr(torch, device_type, torch.cuda)
device = torch.device(f"cuda:{device}")
_BYTES_IN_GIB = 1024**3
peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
memory_stats = {
"peak_memory_active": peak_memory_active,
"peak_memory_alloc": peak_memory_alloc,
"peak_memory_reserved": peak_memory_reserved,
}
torch_device_module.reset_peak_memory_stats(device)
torch.cuda.reset_peak_memory_stats(device)
return memory_stats

View File

@ -166,8 +166,6 @@ if is_torch_xla_available():
if is_npu_available(check_device=False):
import torch_npu # noqa: F401
if is_torch_version(">=", "2.6.0"):
from .dist_checkpointing import save_model_and_optimizer
try:
from torch.optim.lr_scheduler import LRScheduler
@ -570,18 +568,22 @@ class Accelerator:
and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
):
self.native_amp = True
supported_device = ("xpu", "cuda", "npu", "xla", "mlu", "musa", "hpu", "sdaa", "mps")
if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):
raise ValueError(
f"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r})."
)
if self.device.type == "mps" and not is_torch_version(">=", "2.5.0"):
raise ValueError("fp16 mixed precision with MPS device requires a Pytorch >= 2.5.0")
if self.device.type not in (
"xpu",
"cuda",
"npu",
"xla",
"mlu",
"musa",
"hpu",
"sdaa",
) or is_torch_xla_available(check_is_tpu=True):
raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
# FSDP2 doesn't use ShardedGradScaler, don't want to modify `get_grad_scaler`, rather create a simple utility
if self.is_fsdp2:
self.scaler = get_fsdp2_grad_scaler(device=self.device.type, **kwargs)
self.scaler = get_fsdp2_grad_scaler(**kwargs)
else:
self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
@ -593,10 +595,8 @@ class Accelerator:
self.native_amp = True
else:
self.native_amp = is_bf16_available(True)
if not self.native_amp and not is_torch_xla_available():
if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")
if self.native_amp and self.device.type == "mps" and not is_torch_version(">=", "2.6.0"):
raise ValueError("bf16 mixed precision with MPS device requires a Pytorch >= 2.6.0")
# for DeepSpeed, self.state.mixed_precision is always "bf16",
# see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
@ -1164,20 +1164,13 @@ class Accelerator:
>>> optimizer.zero_grad()
```
"""
if self.is_fsdp2:
model.set_requires_gradient_sync(False)
try:
yield
finally:
model.set_requires_gradient_sync(True)
else:
context = contextlib.nullcontext
if self.use_distributed:
if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
context = getattr(model, "no_sync", context)
context = contextlib.nullcontext
if self.use_distributed:
if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
context = getattr(model, "no_sync", context)
with context():
yield
with context():
yield
@staticmethod
@contextmanager
@ -1322,7 +1315,7 @@ class Accelerator:
<Tip warning={true}>
Overriding `even_batches` will not affect iterable-style data loaders.
Overriding `even_batches` will not affect iterable-style data loaders.
</Tip>
@ -1358,7 +1351,7 @@ class Accelerator:
if iterable_dl_seen:
warnings.warn(
"Overriding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable"
"Overridding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable"
)
else:
even_batches = self.even_batches
@ -1537,7 +1530,7 @@ class Accelerator:
and self.state.use_ipex
):
logger.warning(
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we encourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
)
args = self._prepare_ipex(*args)
if self.parallelism_config and self.parallelism_config.tp_enabled:
@ -1627,6 +1620,19 @@ class Accelerator:
self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])
try:
from torch.distributed.tensor.experimental._attention import (
create_cp_block_mask,
)
self._create_block_mask_fn = functools.partial(
create_cp_block_mask, device_mesh=self.torch_device_mesh["cp"]
)
except ImportError:
from torch.nn.attention.flex_attention import create_block_mask
self._create_block_mask_fn = create_block_mask
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
@ -1667,7 +1673,7 @@ class Accelerator:
else:
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
# Get old params and canonicalize - we canonicalize to have the mapping easy
# Get old params and canonicalize - we canonicalize to have the mapping easy
old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
# Swap the optimizer parameters with empty, so `fully_shard` after will not allocate too much memory
@ -2883,7 +2889,7 @@ class Accelerator:
while isinstance(opt, AcceleratedOptimizer):
opt = opt.optimizer
gradients = xm._fetch_gradients(opt)
# Use xm.all_reduce to perform an in-place all-reduce. Recursive all-reduce each tensor
# Use xm.all_reduce to perform an in-place all-reduce. Recursive all-reduce each tensor
# one by one in self.reduce is non-inplace.
xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
# Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
@ -2948,7 +2954,7 @@ class Accelerator:
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> process_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
>>> process_tensor = torch.tensor([accelerator.process_index])
>>> gathered_tensor = accelerator.gather(process_tensor)
>>> gathered_tensor
tensor([0, 1, 2, 3])
@ -3042,7 +3048,7 @@ class Accelerator:
reduction (`str`, *optional*, defaults to "sum"):
A reduction type, can be one of 'sum', 'mean', or 'none'. If 'none', will not perform any operation.
scale (`float`, *optional*, defaults to 1.0):
A default scaling value to be applied after the reduce, only valid on XLA.
A default scaling value to be applied after the reduce, only valid on XLA.
Returns:
`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:
@ -3334,7 +3340,7 @@ class Accelerator:
Arguments:
model: (`torch.nn.Module`):
Model to be saved. The model can be wrapped or unwrapped.
Model to be saved. The model can be wrapped or unwrapped.
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
@ -3445,7 +3451,7 @@ class Accelerator:
`hook(models: list[torch.nn.Module], weights: list[dict[str, torch.Tensor]], input_dir: str) -> None`
The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weights`
The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weights`
argument are the state dicts of the `models`, and the `input_dir` argument is the `input_dir` argument passed
to [`Accelerator.load_state`].
@ -3536,18 +3542,10 @@ class Accelerator:
# Finish running the previous step before checkpointing
xm.mark_step()
# TODO: Siro - how to properly decide when to do this
_dist_save = self.parallelism_config is not None and self.parallelism_config.total_size > 1
if _dist_save:
save_model_and_optimizer(self, self._models[0], self._optimizers[0], output_dir, True)
self.print("Finished saving asynchronous checkpoint.")
# Save the models taking care of FSDP and DeepSpeed nuances
weights = []
for i, model in enumerate(self._models):
if _dist_save:
pass
elif self.distributed_type == DistributedType.FSDP:
if self.distributed_type == DistributedType.FSDP:
logger.info("Saving FSDP model")
save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
logger.info(f"FSDP Model saved to output dir {output_dir}")
@ -3565,9 +3563,7 @@ class Accelerator:
# Save the optimizers taking care of FSDP and DeepSpeed nuances
optimizers = []
if _dist_save:
pass
elif self.distributed_type == DistributedType.FSDP:
if self.distributed_type == DistributedType.FSDP:
for i, opt in enumerate(self._optimizers):
logger.info("Saving FSDP Optimizer")
save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
@ -4035,13 +4031,14 @@ class Accelerator:
Example:
```python
>>> buffers = [batch["input_ids"], batch["position_ids"], batch["shift_labels"]]
>>> for batch in dataloader:
... with accelerator.maybe_context_parallel(
... buffers=[batch["input_ids"], batch["attention_mask"]],
... buffer_seq_dims=[1, 1],
... no_restore_buffers={batch["input_ids"]},
... buffers=buffers,
... buffer_seq_dims=[1, 1, 1],
... no_restore_buffers=set(buffers),
... ):
... outputs = model(batch)
... outputs = model(**batch)
... ...
```
"""
@ -4059,6 +4056,61 @@ class Accelerator:
)
yield
def create_block_mask(
self,
mask_mod,
B,
H,
Q_LEN,
KV_LEN,
):
"""
Creates a flex attention mask to use. If `parallelism_config.cp_size > 1`, the mask will
be sharded to use with context parallelism. If not, this falls back to default `create_block_mask`.
Arguments mimic the signature of `torch.nn.attention.flex_attention.create_block_mask`.
Args:
mask_mod: Mask modifier function to use.
B: Batch size.
H: Number of query heads.
Q_LEN: Query length.
KV_LEN: Key/Value length.
Example:
```python
>>> def causal(_b, _h, q_idx, kv_idx):
... return q_idx >= kv_idx
>>> block_mask = accelerator.create_block_mask(causal, None, None, seq_len, seq_len)
>>> buffers = [batch["input_ids"], batch["position_ids"], batch["shift_labels"]]
>>> for batch in dataloader:
... with accelerator.maybe_context_parallel(
... buffers=buffers,
... buffer_seq_dims=[1, 1, 1],
... no_restore_buffers=set(buffers),
... ):
... batch["attention_mask"] = block_mask
... outputs = model(**batch)
... ...
"""
if self.parallelism_config.cp_enabled:
logger.warning_once(
"Using flex-attention together with context parallel is highly experimental. You might encounter numerical issues, or crashes."
)
mask = self._create_block_mask_fn(
mask_mod=mask_mod,
B=B,
H=H,
Q_LEN=Q_LEN,
KV_LEN=KV_LEN,
)
# We flag that this mask is created by us, so we don't remove it in pre-forward hook
mask._accelerate_created = True
return mask
@contextmanager
def autocast(self, autocast_handler: AutocastKwargs = None):
"""

View File

@ -767,7 +767,9 @@ def _attach_context_parallel_hooks(
"""
def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
if "attention_mask" in module_kwargs:
if "attention_mask" in module_kwargs and not getattr(
module_kwargs["attention_mask"], "_accelerate_created", False
):
module_kwargs["attention_mask"] = None
module_kwargs["is_causal"] = True

View File

@ -60,4 +60,4 @@ def update_command_parser(parser, parents):
def update_config_command(args):
config_file = update_config(args)
print(f"Successfully updated the configuration file at {config_file}.")
print(f"Sucessfully updated the configuration file at {config_file}.")

View File

@ -493,13 +493,13 @@ def launch_command_parser(subparsers=None):
"--deepspeed_exclusion_filter",
default=None,
type=str,
help="DeepSpeed exclusion filter string when using multi-node setup.",
help="DeepSpeed exclusion filter string when using mutli-node setup.",
)
deepspeed_args.add_argument(
"--deepspeed_inclusion_filter",
default=None,
type=str,
help="DeepSpeed inclusion filter string when using multi-node setup.",
help="DeepSpeed inclusion filter string when using mutli-node setup.",
)
deepspeed_args.add_argument(
"--deepspeed_multinode_launcher",
@ -585,7 +585,7 @@ def launch_command_parser(subparsers=None):
"--fsdp_use_orig_params",
default="true",
type=str,
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
" (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(

View File

@ -89,7 +89,7 @@ def convert_config_to_fsdp2(config: dict) -> dict:
new_fsdp_config = {}
if fsdp_config.get("fsdp_version", 1) == 2:
logger.warning("Config already specifies FSDP2, skipping conversion...")
logger.warning("Config already specfies FSDP2, skipping conversion...")
logger.warning(
"If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
)

View File

@ -75,7 +75,7 @@ class SeedableRandomSampler(RandomSampler):
Same as a random sampler, except that in `__iter__` a seed can be used.
Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
and be fully reproducible on multiple iterations.
and be fully reproducible on multiple iterations.
If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
(stored in `self.epoch`).
@ -408,7 +408,7 @@ class DataLoaderStateMixin:
class DataLoaderAdapter:
"""
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
"""
def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
@ -451,8 +451,8 @@ class DataLoaderAdapter:
@property
def __class__(self):
"""
In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
object.
"""
return self.base_dataloader.__class__
@ -763,12 +763,12 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
# if a device mesh is provided extract each dimension (dp, fsdp, tp)
# device mesh may hold any number of dimensions, however,
# below code is for targeted support for dp, fsdp and tp
# below code is for targeted support for dp, fsdp and tp
# device mesh will be used only if there is tp involved
# or any multi-dimensional parallelism involving tp
# (dp, tp) (fsdp, tp) (dp, fsdp, tp)
# otherwise the default behaviour not using device mesh should be sufficient
# otherwise the default behaviour not using device mesh should be sufficient
# since multi dimensional parallelism devoid of tp would anyway need
# different batches for each process irrespective of dp or fsdp
self.submesh_tp = None
@ -1063,7 +1063,7 @@ def prepare_data_loader(
ignored otherwise.
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
reproducibility. Comes at a cost of potentially different performances due to different shuffling
reproducibility. Comes at a cost of potentially different performances due to different shuffling
algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
`self.set_epoch`
data_seed (`int`, *optional*, defaults to `None`):

View File

@ -1,189 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import os
import pickle
import queue
from io import UnsupportedOperation
from pathlib import Path
from typing import TYPE_CHECKING, cast
import torch
import torch.distributed.checkpoint as dcp
import torch.distributed.checkpoint.state_dict as dcs
from torch.distributed.checkpoint.filesystem import (
FileSystemWriter,
SavePlan,
SavePlanner,
_generate_uuid,
_split_by_size_and_type,
)
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
from torch.distributed.checkpoint.storage import WriteResult
if TYPE_CHECKING:
from accelerate import Accelerator
class AccelerateStorageWriter(FileSystemWriter):
_DEFAULT_SUFFIX = ".distcp"
_OPTIM_FILE_PATH = "optimizer_0"
_MODEL_FILE_PATH = "pytorch_model_fsdp_0"
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
self.optim_path = self.fs.concat_path(self.path, self._OPTIM_FILE_PATH)
self.model_path = self.fs.concat_path(self.path, self._MODEL_FILE_PATH)
self.fs.mkdir(self.optim_path)
self.fs.mkdir(self.model_path)
return super().prepare_local_plan(plan)
def write_data(
self,
plan: SavePlan,
planner: SavePlanner,
):
storage_plan = plan.storage_data
optim_file_count = 0
model_file_count = 0
def gen_file(is_optimizer: bool = False) -> str:
nonlocal optim_file_count, model_file_count
if is_optimizer:
optim_file_count += 1
return f"{storage_plan.prefix}{optim_file_count}{self._DEFAULT_SUFFIX}"
else:
model_file_count += 1
return f"{storage_plan.prefix}{model_file_count}{self._DEFAULT_SUFFIX}"
file_queue: queue.Queue = queue.Queue()
for bucket in _split_by_size_and_type(1, plan.items):
optim_states = [wi for wi in bucket if "optim" in wi.index.fqn]
model_states = [wi for wi in bucket if "model" in wi.index.fqn]
for state, path in zip([optim_states, model_states], [self.optim_path, self.model_path]):
file_name = gen_file()
path = self.fs.concat_path(path, file_name)
file_queue.put((path, file_name, state))
return self._write_data(planner, file_queue)
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
try:
metadata = dataclasses.replace(metadata, version="1.0.0")
except TypeError:
pass
def _split_metadata(
metadata: Metadata,
) -> tuple[Metadata, Metadata]:
result = []
for to_get in ["model", "optim"]:
result.append(
Metadata(
state_dict_metadata={
k.removeprefix("state."): v for k, v in metadata.state_dict_metadata.items() if to_get in k
},
planner_data={
k.removeprefix("state."): tuple([x for x in v if x != "state"])
for k, v in metadata.planner_data.items()
if to_get in k
},
)
)
return tuple(result)
model_metadata, optim_metadata = _split_metadata(metadata)
model_storage_md, optim_storage_md = {}, {}
for wr_list in results:
for wr in wr_list:
new_index = dataclasses.asdict(wr.index)
new_index["fqn"] = new_index["fqn"].removeprefix("state.")
wr = WriteResult(
index=MetadataIndex(**new_index),
size_in_bytes=wr.size_in_bytes,
storage_data=wr.storage_data,
)
if "optim" in wr.index.fqn:
optim_storage_md.update({wr.index: wr.storage_data})
else:
model_storage_md.update({wr.index: wr.storage_data})
model_metadata.storage_data = model_storage_md
optim_metadata.storage_data = optim_storage_md
model_metadata.storage_meta = StorageMeta(self.model_path, save_id=_generate_uuid())
optim_metadata.storage_meta = StorageMeta(self.optim_path, save_id=_generate_uuid())
tmp_optim_path = cast(Path, self.fs.concat_path(self.optim_path, ".metadata.tmp"))
tmp_model_path = cast(Path, self.fs.concat_path(self.model_path, ".metadata.tmp"))
for meta, tmp_path, final_path in zip(
[model_metadata, optim_metadata],
[tmp_model_path, tmp_optim_path],
[self.model_path, self.optim_path],
):
with self.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(meta, metadata_file)
if self.sync_files:
try:
os.fsync(metadata_file.fileno())
except (AttributeError, UnsupportedOperation):
os.sync()
metadata_path = self.fs.concat_path(final_path, ".metadata")
if self.fs.exists(metadata_path):
self.fs.rm_file(metadata_path)
self.fs.rename(tmp_path, metadata_path)
def save_model_and_optimizer(
accelerator: "Accelerator",
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
save_path: str,
async_save: bool = False,
) -> None:
# async_save = False
if getattr(accelerator, "_async_save_handle", None) is not None:
accelerator._async_save_handle.result()
options = dcs.StateDictOptions()
import time
accelerator.print(f"{time.asctime()} - Preparing state dict...")
model_sd, optimizer_sd = dcs.get_state_dict(model, optimizer, options=options)
accelerator.print(f"{time.asctime()} - Prepared state dict...")
accelerator.print(f"{time.asctime()} - Saving state dict...")
stateful = {
"model": model_sd,
"optimizer": optimizer_sd,
}
save_fn = dcp.save if not async_save else dcp.async_save
potential_handle = save_fn(
state_dict=stateful,
storage_writer=AccelerateStorageWriter(save_path),
)
accelerator.print(f"{time.asctime()} - Finished saving state dict...")
if async_save:
accelerator._async_save_handle = potential_handle

View File

@ -17,8 +17,9 @@ import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from torch.distributed.device_mesh import init_device_mesh
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
from accelerate.utils.versions import is_torch_version
if TYPE_CHECKING:
@ -79,19 +80,6 @@ class ParallelismConfig:
f"\tcp_handler={self.cp_handler})\n"
)
def to_json(self):
import copy
_non_serializable_fields = ["device_mesh"]
copy.deepcopy(
{
k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
for k, v in self.__dict__.items()
if k not in _non_serializable_fields
}
)
@property
def dp_dim_names(self):
"""Names of enabled dimensions across which data parallelism is applied."""
@ -190,11 +178,6 @@ class ParallelismConfig:
Args:
device_type (`str`): The type of device for which to build the mesh, e
"""
if is_torch_version(">=", "2.2.0"):
from torch.distributed.device_mesh import init_device_mesh
else:
raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
mesh = self._get_mesh()
if len(mesh) == 0:
return None

View File

@ -195,7 +195,7 @@ class PartialState:
original_backend = kwargs.pop("backend", None)
backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
if original_backend is not None and backend != original_backend:
raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
raise ValueError(f"Your assigned backend {original_backend} is not avaliable, please use {backend}")
self.backend = backend
self.distributed_type = distributed_type
use_deepspeed = False
@ -230,7 +230,6 @@ class PartialState:
and (
os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
or True
)
):
self.backend = "cuda:nccl,cpu:gloo"
@ -401,7 +400,7 @@ class PartialState:
DistributedType.DEEPSPEED,
DistributedType.FSDP,
):
torch.distributed.barrier(device_ids=[self.local_process_index])
torch.distributed.barrier()
elif self.distributed_type == DistributedType.XLA:
xm.rendezvous("accelerate.utils.wait_for_everyone")
@ -950,13 +949,8 @@ class AcceleratorState:
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
"before using any functionality from the `accelerate` library."
)
# deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
# if we're using fp8.
if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
self._mixed_precision = "no"
else:
self._mixed_precision = mixed_precision
# deepspeed handles mixed_precision using deepspeed_config
self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
if mixed_precision == "bf16":
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
@ -1062,7 +1056,7 @@ class AcceleratorState:
@property
def mixed_precision(self):
if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
if self.distributed_type == DistributedType.DEEPSPEED:
config = self.deepspeed_plugin.deepspeed_config
if config.get("fp16", {}).get("enabled", False):
mixed_precision = "fp16"
@ -1085,7 +1079,7 @@ class AcceleratorState:
"""
Destroys the process group. If one is not specified, the default process group is destroyed.
If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
"""
PartialState().destroy_process_group(group)

View File

@ -53,7 +53,6 @@ from .testing import (
require_torchvision,
require_tpu,
require_transformer_engine,
require_transformer_engine_mxfp8,
require_xpu,
run_first,
skip,

View File

@ -69,7 +69,7 @@ class TorchTracemalloc:
self.begin = torch.npu.memory_allocated()
elif is_xpu_available():
torch.xpu.empty_cache()
torch.xpu.reset_peak_memory_stats() # reset the peak gauge to zero
torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = torch.xpu.memory_allocated()
elif is_hpu_available():
# torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process

View File

@ -50,7 +50,7 @@ def test_gather_object(state):
assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
def test_gather_non_contiguous(state):
def test_gather_non_contigous(state):
# Skip this test because the 'is_contiguous' function of XLA tensor always returns True.
if state.distributed_type == DistributedType.XLA:
return
@ -160,8 +160,8 @@ def main():
test_gather(state)
state.print("testing gather_object")
test_gather_object(state)
state.print("testing gather non-contiguous")
test_gather_non_contiguous(state)
state.print("testing gather non-contigous")
test_gather_non_contigous(state)
state.print("testing broadcast")
test_broadcast(state)
state.print("testing pad_across_processes")

View File

@ -35,12 +35,10 @@ from accelerate.utils import (
gather,
gather_object,
is_bf16_available,
is_cuda_available,
is_datasets_available,
is_fp16_available,
is_hpu_available,
is_ipex_available,
is_mps_available,
is_pytest_available,
is_xpu_available,
set_seed,
@ -536,7 +534,7 @@ def training_check(use_seedable_sampler=False):
accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")
# FP32 wrapper check
if is_cuda_available() or is_mps_available():
if torch.cuda.is_available():
# Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
print("Keep fp32 wrapper check.")
AcceleratorState._reset_state()

View File

@ -71,7 +71,6 @@ from ..utils import (
is_torchvision_available,
is_trackio_available,
is_transformer_engine_available,
is_transformer_engine_mxfp8_available,
is_transformers_available,
is_triton_available,
is_wandb_available,
@ -541,16 +540,6 @@ def require_transformer_engine(test_case):
return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)
def require_transformer_engine_mxfp8(test_case):
"""
Decorator marking a test that requires transformers engine MXFP8 block scaling available. These tests are skipped
when transformers engine MXFP8 block scaling isn't available
"""
return unittest.skipUnless(
is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling"
)(test_case)
def require_torchao(test_case):
"""
Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed
@ -598,7 +587,7 @@ def require_torchdata_stateful_dataloader(test_case):
def run_first(test_case):
"""
Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator are
guaranteed to run first.
guaranteed to run first.
This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
@ -617,7 +606,7 @@ def run_first(test_case):
class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
data at the start of a test, and then destroys it at the end of the TestCase.
data at the start of a test, and then destroys it at the end of the TestCase.
Useful for when a class or API requires a single constant folder throughout its use, such as Weights and Biases

View File

@ -111,7 +111,7 @@ class GeneralTracker:
(`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal
tracking mechanism used by a tracker class (such as the `run` for wandb)
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and
other functions should occur on the main process or across all processes (by default will use `True`)
"""

View File

@ -134,7 +134,6 @@ from .imports import (
is_torchvision_available,
is_trackio_available,
is_transformer_engine_available,
is_transformer_engine_mxfp8_available,
is_transformers_available,
is_triton_available,
is_wandb_available,

View File

@ -314,7 +314,7 @@ def _replace_with_bnb_layers(
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
import bitsandbytes as bnb

View File

@ -371,7 +371,6 @@ class TERecipeKwargs(KwargsHandler):
amax_history_len: int = None
amax_compute_algo: AmaxComputeAlgorithm = None
override_linear_precision: tuple[bool, bool, bool] = None
use_mxfp8_block_scaling: bool = None
def __post_init__(self):
env_prefix = "ACCELERATE_FP8_"
@ -400,8 +399,6 @@ class TERecipeKwargs(KwargsHandler):
dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
self.override_linear_precision = (fprop, dgrad, wgrad)
if self.use_mxfp8_block_scaling is None:
self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")
@dataclass
@ -683,7 +680,7 @@ class DynamoBackend(str, BaseEnum):
more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
- **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read
more](https://github.com/intel/intel-extension-for-pytorch).
- **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
- **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
- **HPU_BACKEND** -- Uses HPU backend for inference optimizations.
"""
@ -804,9 +801,9 @@ class DataLoaderConfiguration:
all workers.
use_seedable_sampler (`bool`, defaults to `False`):
Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare.
Should also be ran with [`~utils.set_seed`] for the best results.
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
data_seed (`int`, defaults to `None`):
The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
will use the current default seed from torch.
@ -849,8 +846,8 @@ class DataLoaderConfiguration:
default=False,
metadata={
"help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])."
"Ensures training results are fully reproducible using a different sampling technique. "
"While seed-to-seed results may differ, on average the differences are negligible when using"
"Ensures training results are fully reproducable using a different sampling technique. "
"While seed-to-seed results may differ, on average the differences are neglible when using"
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)
@ -956,7 +953,7 @@ class GradientAccumulationPlugin(KwargsHandler):
sync_with_dataloader (`bool`, *optional*, defaults to `True`):
Whether to synchronize setting the gradients when at the end of the dataloader.
sync_each_batch (`bool`, *optional*):
Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
requirements when using gradient accumulation with distributed training, at expense of speed.
Example:
@ -1553,12 +1550,10 @@ class FullyShardedDataParallelPlugin:
backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
Backward prefetch strategy to use. Should be either a `str` or an instance of
`torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
`reduce_dtype`, and `buffer_dtype`.
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
@ -1637,7 +1632,6 @@ class FullyShardedDataParallelPlugin:
mixed_precision_policy: Optional[
Union[
dict,
str,
"torch.distributed.fsdp.MixedPrecision",
"torch.distributed.fsdp.MixedPrecisionPolicy",
]
@ -1929,11 +1923,7 @@ class FullyShardedDataParallelPlugin:
)
os.environ[env_var] = str(self.cpu_ram_efficient_loading)
if isinstance(self.mixed_precision_policy, str):
# override is True since self.mixed_precision_policy is not None
# has to be overwritten with the correct mixed precision object
self.set_mixed_precision(self.mixed_precision_policy, override=True)
elif isinstance(self.mixed_precision_policy, dict):
if isinstance(self.mixed_precision_policy, dict):
self.set_mixed_precision(self.mixed_precision_policy)
if self.mixed_precision_policy is not None:
self.validate_mixed_precision_policy()
@ -2015,7 +2005,7 @@ class FullyShardedDataParallelPlugin:
def set_auto_wrap_policy(self, model):
"""
Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the
Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the
`transformer_cls_to_wrap`
"""
from torch.distributed.fsdp.wrap import (
@ -2263,7 +2253,7 @@ class MegatronLMPlugin:
lr_warmup_fraction (`float`, defaults to `None`):
Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.
min_lr (`float`, defaults to `0`):
Minimum value for learning rate. The scheduler clip values below this threshold.
Minimum value for learning rate. The scheduler clips values below this threshold.
consumed_samples (`List`, defaults to `None`):
Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.
no_wd_decay_cond (`Optional`, defaults to `None`):
@ -2390,7 +2380,7 @@ class MegatronLMPlugin:
)
min_lr: float = field(
default=0,
metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."},
metadata={"help": "Minumum value for learning rate. The scheduler clip values below this threshold."},
)
consumed_samples: list[int] = field(
default=None,

View File

@ -149,7 +149,7 @@ def check_cuda_p2p_ib_support():
Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
the 3090.
Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
"""
try:
device_names, device_count = get_gpu_info()

View File

@ -305,15 +305,13 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
optim_state = torch.load(input_optimizer_file, weights_only=True)
logger.info(f"Optimizer state loaded from {input_optimizer_file}")
else:
from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
ckpt_dir = (
os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
if f"{OPTIMIZER_NAME}" not in input_dir
else input_dir
)
logger.info(f"Loading Optimizer from {ckpt_dir}")
optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer)}
optim_state = {"optimizer": optimizer.state_dict()}
dist_cp.load(
optim_state,
checkpoint_id=ckpt_dir,
@ -650,7 +648,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device

View File

@ -114,14 +114,6 @@ def is_transformer_engine_available():
return _is_package_available("transformer_engine", "transformer-engine")
def is_transformer_engine_mxfp8_available():
if _is_package_available("transformer_engine", "transformer-engine"):
import transformer_engine.pytorch as te
return te.fp8.check_mxfp8_support()[0]
return False
def is_lomo_available():
return _is_package_available("lomo_optim")
@ -182,7 +174,7 @@ def is_bf16_available(ignore_tpu=False):
if is_xpu_available():
return torch.xpu.is_bf16_supported()
if is_mps_available():
return torch.backends.mps.is_macos_or_newer(14, 0)
return False
return True
@ -414,12 +406,7 @@ def is_npu_available(check_device=False):
if importlib.util.find_spec("torch_npu") is None:
return False
# NOTE: importing torch_npu may raise an error in some envs
# e.g. inside cpu-only container with torch_npu installed
try:
import torch_npu # noqa: F401
except Exception:
return False
import torch_npu # noqa: F401
if check_device:
try:

View File

@ -873,7 +873,7 @@ def finish_mpu_init():
_set_random_seed(args.seed, args.data_parallel_random_init)
# initialize megatron setup
def initialize(accelerator, extra_args_provider=None, args_defaults={}):
accelerator.print("Initializing Megatron-LM")
assert torch.cuda.is_available(), "Megatron requires CUDA."
@ -1344,7 +1344,7 @@ class MegatronEngine(torch.nn.Module):
padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)
# We need the sizes of these tensors for the broadcast
sizes_list = [
prompts_tokens_tensor.size(0), # Batch size
prompts_tokens_tensor.size(1),
@ -1353,7 +1353,7 @@ class MegatronEngine(torch.nn.Module):
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)
# Now that we have the sizes, we can broadcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
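
The Megatron helpers above follow a common two-step pattern: broadcast the sizes first so non-source ranks can allocate buffers, then broadcast the payload. An illustrative plain `torch.distributed` version of the same idea (not the Megatron-LM API):

```python
from typing import Optional

import torch
import torch.distributed as dist

def broadcast_tokens(tokens: Optional[torch.Tensor], src: int = 0) -> torch.Tensor:
    device = torch.device("cuda", torch.cuda.current_device())
    # 1) Broadcast the sizes so every rank can allocate a matching buffer.
    sizes = torch.zeros(2, dtype=torch.int64, device=device)
    if dist.get_rank() == src:
        sizes = torch.tensor(tokens.shape, dtype=torch.int64, device=device)
    dist.broadcast(sizes, src=src)
    # 2) Broadcast the tokens themselves into the pre-allocated buffer.
    if dist.get_rank() != src:
        tokens = torch.empty(tuple(sizes.tolist()), dtype=torch.int64, device=device)
    dist.broadcast(tokens, src=src)
    return tokens
```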

View File

@ -2136,10 +2136,6 @@ def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
return torch.amp.GradScaler("hpu", **kwargs)
elif is_xpu_available():
return torch.amp.GradScaler("xpu", **kwargs)
elif is_mps_available():
if not is_torch_version(">=", "2.8.0"):
raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
return torch.amp.GradScaler("mps", **kwargs)
else:
if is_torch_version(">=", "2.3"):
return torch.amp.GradScaler("cuda", **kwargs)

View File

@ -32,7 +32,6 @@ from .imports import (
is_torch_distributed_available,
is_torch_xla_available,
)
from .versions import is_torch_version
if is_torch_xla_available():
@ -317,8 +316,8 @@ def _gpu_gather(tensor):
state = PartialState()
gather_op = torch.distributed.all_gather_into_tensor
# NOTE: need to manually synchronize to work around an INT64 collectives bug in oneCCL before torch 2.9.0
if state.device.type == "xpu" and is_torch_version("<=", "2.8"):
# FIXME: the below 2 lines are added to work around a bug related to INT64 collectives in oneCCL. Remove them once pytorch-2.9 is released.
if state.device.type == "xpu":
torch.xpu.synchronize()
def _gpu_gather_one(tensor):
@ -520,7 +519,7 @@ def gather_tensor_shape(tensor):
def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
"""
Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
each worker doesn't need to know its shape when used (and tensor can be `None`)
Args:
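
A hedged usage sketch of the behavior described here: only the source rank builds the tensor, the other ranks pass `None` and receive a copy without knowing its shape (the import path is assumed to be this same operations module):

```python
import torch

from accelerate import PartialState
from accelerate.utils.operations import copy_tensor_to_devices  # assumed import path

state = PartialState()
payload = torch.arange(4, device=state.device) if state.is_main_process else None
payload = copy_tensor_to_devices(payload)  # non-source ranks get the broadcast copy
```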
@ -732,7 +731,7 @@ def reduce(tensor, reduction="mean", scale=1.0):
reduction (`str`, *optional*, defaults to `"mean"`):
A reduction method. Can be one of "mean", "sum", or "none".
scale (`float`, *optional*):
A default scaling value to be applied after the reduce, only valid on XLA.
Returns:
The same data structure as `data` with all the tensors reduced.
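
A short usage sketch of the reduce semantics documented here, via the `Accelerator.reduce` wrapper (values are illustrative):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
local_metric = torch.tensor([float(accelerator.process_index)], device=accelerator.device)
mean_metric = accelerator.reduce(local_metric, reduction="mean")  # averaged across processes
sum_metric = accelerator.reduce(local_metric, reduction="sum")    # or summed instead
```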
@ -788,7 +787,7 @@ def convert_to_fp32(tensor):
class ConvertOutputsToFp32:
"""
Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
precision will be converted back to FP32.
Args:

View File

@ -148,34 +148,14 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
if is_hpu_available():
import intel_transformer_engine.recipe as te_recipe
is_fp8_block_scaling_available = False
message = "MXFP8 block scaling is not available on HPU."
else:
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
raise ValueError(f"MXFP8 block scaling is not available: {message}")
if use_mxfp8_block_scaling:
if "amax_compute_algo" in kwargs:
raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
if "amax_history_len" in kwargs:
raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
else:
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
if hasattr(model.forward, "__func__"):

View File

@ -316,9 +316,6 @@ class FSDPPluginIntegration(AccelerateTestCase):
AcceleratorState._reset_state(True)
env = self.fsdp_envs[fsdp_version].copy()
with patch_environment(**env):
plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)
assert plugin.mixed_precision_policy == mp_policy
with patch_environment(**env):
plugin = FullyShardedDataParallelPlugin(
mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, **{extra_arg: dtype}}

View File

@ -625,7 +625,7 @@ class ToFSDP2Tester(unittest.TestCase):
with self.assertLogs(level="WARNING") as cm:
to_fsdp2_command(args)
assert "Config already specifies FSDP2, skipping conversion..." in cm.output[0]
assert "Config already specfies FSDP2, skipping conversion..." in cm.output[0]
# Has to be the last test because it overwrites the config file
def test_fsdp2_overwrite(self):

View File

@ -76,7 +76,7 @@ class TestParallelismConfig:
return mesh
with patch("torch.distributed.device_mesh.init_device_mesh", side_effect=mock_init_mesh):
with patch("accelerate.parallelism_config.init_device_mesh", side_effect=mock_init_mesh):
yield mock_init_mesh
@pytest.mark.parametrize(

View File

@ -31,7 +31,6 @@ from accelerate.test_utils import (
require_multi_device,
require_torchao,
require_transformer_engine,
require_transformer_engine_mxfp8,
run_first,
)
from accelerate.test_utils.testing import require_deepspeed, run_command
@ -50,8 +49,6 @@ def can_convert_te_model(from_config=False):
accelerator_kwargs = {}
accelerator = Accelerator(**accelerator_kwargs)
assert accelerator.fp8_enabled, "FP8 is not enabled"
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.LayerNorm(32, bias=False), torch.nn.Linear(32, 16))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
@ -112,26 +109,6 @@ class TestTransformerEngine(unittest.TestCase):
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
run_command(command)
@require_transformer_engine_mxfp8
def test_can_prepare_model_with_mxfp8_block_scaling(self):
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "NO"
num_processes: 1
mixed_precision: fp8
fp8_config:
backend: TE
use_mxfp8_block_scaling: true
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
run_command(command)
@require_multi_device
def test_can_prepare_model_multi_gpu(self):
command = get_launch_command(num_processes=2, monitor_interval=0.1)
@ -170,35 +147,6 @@ class TestTransformerEngine(unittest.TestCase):
command += ["-m", "tests.test_fp8", "--test_te"]
run_command(command)
@require_deepspeed
@require_multi_device
def test_can_prepare_model_multigpu_deepspeed_from_config(self):
os.environ["ZERO_STAGE"] = str(1)
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "DEEPSPEED"
deepspeed_config:
gradient_clipping: 1.0
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 1
deepspeed_multinode_launcher: standard
num_processes: 2
mixed_precision: fp8
fp8_config:
backend: TE
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
run_command(command)
@require_torchao
@require_huggingface_suite