mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-18 16:44:39 +08:00
Compare commits
3 Commits
feat/async
...
context-pa
| Author | SHA1 | Date | |
|---|---|---|---|
| a3f8d23402 | |||
| 8ecadce10a | |||
| 91985ab9d7 |
@ -19,7 +19,7 @@ This guide will cover basics of using context parallelism in 🤗`accelerate`, f
|
||||
|
||||
## Why context parallelism?
|
||||
|
||||
With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
|
||||
With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has lead to a need for more efficient ways to train models with long sequences.
|
||||
With sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given vanilla attention implementation. Granted, with usage of `flash attention` or `SDPA` which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
|
||||
|
||||
Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.
|
||||
@ -44,7 +44,7 @@ accelerator = Accelerator(
|
||||
)
|
||||
```
|
||||
|
||||
As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.
|
||||
As with any other feature in 🤗`accelerate`, you can enabled context parallelism also by passing the corresponding flags to `accelerate launch`.
|
||||
In this case, it's no different:
|
||||
|
||||
```bash
|
||||
@ -65,7 +65,7 @@ accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-st
|
||||
> [!Warning]
|
||||
> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.
|
||||
|
||||
After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
|
||||
After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do your training loop, we provide a context manager than is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
|
||||
You can use it as follows:
|
||||
|
||||
```python
|
||||
@ -82,7 +82,7 @@ for batch in dataloader:
|
||||
> [!Warning]
|
||||
> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.
|
||||
|
||||
This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentially enabling endless context length scaling.
|
||||
This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentiall enabling endless context length scaling.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
|
||||
@ -120,7 +120,7 @@ The context manager takes a few arguments, that are used to configure the contex
|
||||
|
||||
|
||||
## Configurable options
|
||||
Accelerate provides only a single option to configure context parallelism (except for `cp_size`)
|
||||
Accelerate provides only a single option to configure context parallelism (except of `cp_size`)
|
||||
|
||||
- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it's very likely it will outperform `"alltoall"` in most cases.
|
||||
|
||||
@ -142,7 +142,7 @@ We're going to be using word `shard` extensively in the following sections, so l
|
||||
Context parallelism works on sharding the `Q, K and V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays only on this rank, during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shard of `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `Ks` and `Vs` from all other ranks. What communication primitive is used, depends on the `context_parallel_shard_rotation` option.
|
||||
This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. As each rank holds `Q, K and V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.
|
||||
|
||||
We can formalize this in the following pseudocode:
|
||||
We can formalize this in a following pseudocode:
|
||||
```python
|
||||
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]
|
||||
Qi, Ki, Vi = shard(Q, K, V, seq_dim)
|
||||
@ -164,7 +164,7 @@ In ideal scenario, all-gather finishes in the exact moment as the calculation of
|
||||
All-to-all, or sometimes called `ring-rotation` utilizes a ring-like communication pattern. After concluding `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `context_parallel_size-1` times, so that each rank sees all the shards of `K` and `V` from all other ranks once. In ideal scenario, we prefetch shards `K_i+1` and `V_i+1` from the neighbouring rank and this communication is exactly overlapped with computation of our current `attn_i`. Again, realistically, this perfect overlap doesn't ever happen. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is way larger than with all-gather.
|
||||
|
||||
## How to choose the right rotation method?
|
||||
In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
|
||||
In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also shows that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
|
||||
|
||||
You can directly see this issue in the profiler output in the image below:
|
||||
<p align="center">
|
||||
|
||||
@ -8,7 +8,7 @@ deepspeed_config:
|
||||
# `transformers` uses the right `init` function
|
||||
zero3_init_flag: false # true
|
||||
|
||||
# Finally we need to specify the number of accelerators to use
|
||||
# Finally we need to specify the number of GPUs to use
|
||||
num_processes: 2
|
||||
# Optionally we can set the mixed precision now instead of in the deepspeed config file,
|
||||
# however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# Since we are doing FSDP (even though it's multi-accelerator), we need to specify the distributed type as FSDP
|
||||
# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
|
||||
distributed_type: FSDP
|
||||
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)
|
||||
mixed_precision: 'bf16'
|
||||
# Specify the number of accelerators to use
|
||||
# Specify the number of GPUs to use
|
||||
num_processes: 2
|
||||
# Then we can specify the FSDP config
|
||||
fsdp_config:
|
||||
|
||||
@ -1,6 +0,0 @@
|
||||
# Specify distributed_type as `MULTI_XPU` for DDP
|
||||
distributed_type: "MULTI_XPU"
|
||||
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
|
||||
mixed_precision: "bf16"
|
||||
# Specify the number of XPUs to use
|
||||
num_processes: 2
|
||||
@ -1,4 +1,4 @@
|
||||
# Since this is single GPU/XPU, we don't need distributed training
|
||||
# Since this is single GPU, we don't need distributed training
|
||||
distributed_type: "NO"
|
||||
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
|
||||
mixed_precision: "bf16"
|
||||
mixed_precision: "bf16"
|
||||
@ -177,7 +177,6 @@ def training_function(config, args):
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
|
||||
print(f"===== {predictions}")
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
## Torch Native Parallelism
|
||||
|
||||
With recent versions of Torch, there have been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`
|
||||
With recent versions of Torch, there has been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these with our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`
|
||||
This folder contains various examples of such use-cases: such as composing multiple parallelism strategies, low-bit training etc.
|
||||
|
||||
### ND Parallelism
|
||||
@ -51,7 +51,7 @@ gaining even more speed and memory savings, as `ao` doesn't ship with any kernel
|
||||
Replacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. This highly depends on the model dimensions and sequence length used for training.
|
||||
You can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance).
|
||||
|
||||
In our example, we use a 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while loss function stays roughly the same. We can also see that the FLOPS rise by using FP8.
|
||||
In our example, we use a 8B Llama3.1 model, which has a hidden dimension of 4096 and we train on sequence length of 8192. In the below images, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second, per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while loss function stays roughly the same. We can also see that the FLOPS raise by using FP8.
|
||||
|
||||
<div style="display: flex; gap: 25px;">
|
||||
<div style="text-align: center; width: 49%;">
|
||||
|
||||
@ -31,7 +31,6 @@ from utils import (
|
||||
PerformanceTracker,
|
||||
create_collate_fn,
|
||||
get_dataset,
|
||||
get_model_flops_per_token,
|
||||
setup_tokenizer,
|
||||
)
|
||||
|
||||
@ -94,7 +93,7 @@ def train(args):
|
||||
fsdp2_plugin = FullyShardedDataParallelPlugin(
|
||||
fsdp_version=2,
|
||||
auto_wrap_policy="transformer_based_wrap",
|
||||
transformer_cls_names_to_wrap=["Qwen3DecoderLayer"],
|
||||
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
|
||||
state_dict_type="SHARDED_STATE_DICT",
|
||||
)
|
||||
|
||||
@ -123,7 +122,6 @@ def train(args):
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
total_num_steps = min(args.num_steps, len(dataloader))
|
||||
model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
|
||||
performance_tracker = PerformanceTracker(warmup_steps=5)
|
||||
|
||||
accelerator.print("Starting training...")
|
||||
@ -134,10 +132,7 @@ def train(args):
|
||||
loss = forward(model, batch, optimizer, accelerator)
|
||||
|
||||
# We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
|
||||
metrics = performance_tracker.step(
|
||||
batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size,
|
||||
model_flops_per_token=model_flops_per_token,
|
||||
)
|
||||
metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
|
||||
|
||||
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
|
||||
if "warmup_completed" in metrics:
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
||||
|
||||
from accelerate.utils import ParallelismConfig
|
||||
@ -29,7 +28,7 @@ def parse_args():
|
||||
parser.add_argument("--checkpoint-frequency", type=int, default=100)
|
||||
parser.add_argument("--model-name", type=str, default=MODEL_ID)
|
||||
parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
|
||||
parser.add_argument("--device-type", type=str, default="auto")
|
||||
parser.add_argument("--device-type", type=str, default="cuda")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -39,9 +38,6 @@ def main():
|
||||
pc = ParallelismConfig()
|
||||
args = parse_args()
|
||||
|
||||
if args.device_type == "auto":
|
||||
args.device_type = torch.accelerator.current_accelerator().type
|
||||
|
||||
model_kwargs = {}
|
||||
if pc.tp_enabled:
|
||||
model_kwargs["tp_plan"] = "auto"
|
||||
@ -70,7 +66,7 @@ def main():
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=packed_dataset,
|
||||
)
|
||||
|
||||
|
||||
@ -183,11 +183,7 @@ class PerformanceTracker:
|
||||
return {}
|
||||
|
||||
def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
|
||||
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}"
|
||||
if "tflops_per_device" in metrics:
|
||||
print_msg += f" | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
|
||||
else:
|
||||
print_msg += "\n"
|
||||
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f} | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
|
||||
if with_memory:
|
||||
print_msg += (
|
||||
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
|
||||
@ -206,18 +202,16 @@ def setup_tokenizer(model_id: str) -> AutoTokenizer:
|
||||
|
||||
|
||||
def gpu_memory_usage_all(device=0):
|
||||
device_type = torch.accelerator.current_accelerator().type
|
||||
device = torch.device(f"{device_type}:{device}")
|
||||
torch_device_module = getattr(torch, device_type, torch.cuda)
|
||||
device = torch.device(f"cuda:{device}")
|
||||
_BYTES_IN_GIB = 1024**3
|
||||
peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
|
||||
peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
|
||||
peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
|
||||
peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
|
||||
peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
|
||||
peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
|
||||
memory_stats = {
|
||||
"peak_memory_active": peak_memory_active,
|
||||
"peak_memory_alloc": peak_memory_alloc,
|
||||
"peak_memory_reserved": peak_memory_reserved,
|
||||
}
|
||||
torch_device_module.reset_peak_memory_stats(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
|
||||
return memory_stats
|
||||
|
||||
@ -166,8 +166,6 @@ if is_torch_xla_available():
|
||||
if is_npu_available(check_device=False):
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
if is_torch_version(">=", "2.6.0"):
|
||||
from .dist_checkpointing import save_model_and_optimizer
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
@ -570,18 +568,22 @@ class Accelerator:
|
||||
and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
|
||||
):
|
||||
self.native_amp = True
|
||||
supported_device = ("xpu", "cuda", "npu", "xla", "mlu", "musa", "hpu", "sdaa", "mps")
|
||||
if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):
|
||||
raise ValueError(
|
||||
f"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r})."
|
||||
)
|
||||
if self.device.type == "mps" and not is_torch_version(">=", "2.5.0"):
|
||||
raise ValueError("fp16 mixed precision with MPS device requires a Pytorch >= 2.5.0")
|
||||
if self.device.type not in (
|
||||
"xpu",
|
||||
"cuda",
|
||||
"npu",
|
||||
"xla",
|
||||
"mlu",
|
||||
"musa",
|
||||
"hpu",
|
||||
"sdaa",
|
||||
) or is_torch_xla_available(check_is_tpu=True):
|
||||
raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
|
||||
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
|
||||
|
||||
# FSDP2 doesn't use ShardedGradScaler, don't want to modify `get_grad_scaler`, rather create a simple utility
|
||||
if self.is_fsdp2:
|
||||
self.scaler = get_fsdp2_grad_scaler(device=self.device.type, **kwargs)
|
||||
self.scaler = get_fsdp2_grad_scaler(**kwargs)
|
||||
else:
|
||||
self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
|
||||
|
||||
@ -593,10 +595,8 @@ class Accelerator:
|
||||
self.native_amp = True
|
||||
else:
|
||||
self.native_amp = is_bf16_available(True)
|
||||
if not self.native_amp and not is_torch_xla_available():
|
||||
if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
|
||||
raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")
|
||||
if self.native_amp and self.device.type == "mps" and not is_torch_version(">=", "2.6.0"):
|
||||
raise ValueError("bf16 mixed precision with MPS device requires a Pytorch >= 2.6.0")
|
||||
|
||||
# for DeepSpeed, self.state.mixed_precision is always "bf16",
|
||||
# see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
|
||||
@ -1164,20 +1164,13 @@ class Accelerator:
|
||||
>>> optimizer.zero_grad()
|
||||
```
|
||||
"""
|
||||
if self.is_fsdp2:
|
||||
model.set_requires_gradient_sync(False)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model.set_requires_gradient_sync(True)
|
||||
else:
|
||||
context = contextlib.nullcontext
|
||||
if self.use_distributed:
|
||||
if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
|
||||
context = getattr(model, "no_sync", context)
|
||||
context = contextlib.nullcontext
|
||||
if self.use_distributed:
|
||||
if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
|
||||
context = getattr(model, "no_sync", context)
|
||||
|
||||
with context():
|
||||
yield
|
||||
with context():
|
||||
yield
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
@ -1322,7 +1315,7 @@ class Accelerator:
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Overriding `even_batches` will not affect iterable-style data loaders.
|
||||
Overidding `even_batches` will not affect iterable-style data loaders.
|
||||
|
||||
</Tip>
|
||||
|
||||
@ -1358,7 +1351,7 @@ class Accelerator:
|
||||
|
||||
if iterable_dl_seen:
|
||||
warnings.warn(
|
||||
"Overriding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable"
|
||||
"Overridding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable"
|
||||
)
|
||||
else:
|
||||
even_batches = self.even_batches
|
||||
@ -1537,7 +1530,7 @@ class Accelerator:
|
||||
and self.state.use_ipex
|
||||
):
|
||||
logger.warning(
|
||||
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we encourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
|
||||
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
|
||||
)
|
||||
args = self._prepare_ipex(*args)
|
||||
if self.parallelism_config and self.parallelism_config.tp_enabled:
|
||||
@ -1627,6 +1620,19 @@ class Accelerator:
|
||||
|
||||
self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])
|
||||
|
||||
try:
|
||||
from torch.distributed.tensor.experimental._attention import (
|
||||
create_cp_block_mask,
|
||||
)
|
||||
|
||||
self._create_block_mask_fn = functools.partial(
|
||||
create_cp_block_mask, device_mesh=self.torch_device_mesh["cp"]
|
||||
)
|
||||
except ImportError:
|
||||
from torch.nn.attention.flex_attention import create_block_mask
|
||||
|
||||
self._create_block_mask_fn = create_block_mask
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
_attach_context_parallel_hooks(arg)
|
||||
@ -1667,7 +1673,7 @@ class Accelerator:
|
||||
else:
|
||||
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
|
||||
|
||||
# Get old params and canonicalize - we canonicalize to have the mapping easy
|
||||
# Get old params and canonicalize - we cannonicalize to have the mapping easy
|
||||
old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
|
||||
|
||||
# Swap the optimizer parameters with empty, so `fully_shard` after will not allocate too much memory
|
||||
@ -2883,7 +2889,7 @@ class Accelerator:
|
||||
while isinstance(opt, AcceleratedOptimizer):
|
||||
opt = opt.optimizer
|
||||
gradients = xm._fetch_gradients(opt)
|
||||
# Use xm.all_reduce to perform an in-place all-reduce. Recursive all-reduce each tensor
|
||||
# Use xm.all_reduce to perform an in-place all-reduce. Recusrsive all-reduce each tensor
|
||||
# one by one in self.reduce is non-inplace.
|
||||
xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
|
||||
# Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
|
||||
@ -2948,7 +2954,7 @@ class Accelerator:
|
||||
>>> from accelerate import Accelerator
|
||||
|
||||
>>> accelerator = Accelerator()
|
||||
>>> process_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
|
||||
>>> process_tensor = torch.tensor([accelerator.process_index])
|
||||
>>> gathered_tensor = accelerator.gather(process_tensor)
|
||||
>>> gathered_tensor
|
||||
tensor([0, 1, 2, 3])
|
||||
@ -3042,7 +3048,7 @@ class Accelerator:
|
||||
reduction (`str`, *optional*, defaults to "sum"):
|
||||
A reduction type, can be one of 'sum', 'mean', or 'none'. If 'none', will not perform any operation.
|
||||
scale (`float`, *optional*, defaults to 1.0):
|
||||
A default scaling value to be applied after the reduce, only valid on XLA.
|
||||
A default scaling value to be applied after the reduce, only valied on XLA.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:
|
||||
@ -3334,7 +3340,7 @@ class Accelerator:
|
||||
|
||||
Arguments:
|
||||
model: (`torch.nn.Module`):
|
||||
Model to be saved. The model can be wrapped or unwrapped.
|
||||
Model to be saved. The model can be wrapped or unwraped.
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
@ -3445,7 +3451,7 @@ class Accelerator:
|
||||
|
||||
`hook(models: list[torch.nn.Module], weights: list[dict[str, torch.Tensor]], input_dir: str) -> None`
|
||||
|
||||
The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weights`
|
||||
The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weigths`
|
||||
argument are the state dicts of the `models`, and the `input_dir` argument is the `input_dir` argument passed
|
||||
to [`Accelerator.load_state`].
|
||||
|
||||
@ -3536,18 +3542,10 @@ class Accelerator:
|
||||
# Finish running the previous step before checkpointing
|
||||
xm.mark_step()
|
||||
|
||||
# TODO: Siro - how to properly decide when to do this
|
||||
_dist_save = self.parallelism_config is not None and self.parallelism_config.total_size > 1
|
||||
if _dist_save:
|
||||
save_model_and_optimizer(self, self._models[0], self._optimizers[0], output_dir, True)
|
||||
self.print("Finished saving asynchronous checkpoint.")
|
||||
|
||||
# Save the models taking care of FSDP and DeepSpeed nuances
|
||||
weights = []
|
||||
for i, model in enumerate(self._models):
|
||||
if _dist_save:
|
||||
pass
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
if self.distributed_type == DistributedType.FSDP:
|
||||
logger.info("Saving FSDP model")
|
||||
save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
|
||||
logger.info(f"FSDP Model saved to output dir {output_dir}")
|
||||
@ -3565,9 +3563,7 @@ class Accelerator:
|
||||
|
||||
# Save the optimizers taking care of FSDP and DeepSpeed nuances
|
||||
optimizers = []
|
||||
if _dist_save:
|
||||
pass
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
if self.distributed_type == DistributedType.FSDP:
|
||||
for i, opt in enumerate(self._optimizers):
|
||||
logger.info("Saving FSDP Optimizer")
|
||||
save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
|
||||
@ -4035,13 +4031,14 @@ class Accelerator:
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> buffers = [batch["input_ids"], batch["position_ids"], batch["shift_labels"]]
|
||||
>>> for batch in dataloader:
|
||||
... with accelerator.maybe_context_parallel(
|
||||
... buffers=[batch["input_ids"], batch["attention_mask"]],
|
||||
... buffer_seq_dims=[1, 1],
|
||||
... no_restore_buffers={batch["input_ids"]},
|
||||
... buffers=buffers,
|
||||
... buffer_seq_dims=[1, 1, 1],
|
||||
... no_restore_buffers=set(buffers),
|
||||
... ):
|
||||
... outputs = model(batch)
|
||||
... outputs = model(**batch)
|
||||
... ...
|
||||
```
|
||||
"""
|
||||
@ -4059,6 +4056,61 @@ class Accelerator:
|
||||
)
|
||||
yield
|
||||
|
||||
def create_block_mask(
|
||||
self,
|
||||
mask_mod,
|
||||
B,
|
||||
H,
|
||||
Q_LEN,
|
||||
KV_LEN,
|
||||
):
|
||||
"""
|
||||
Creates a flex attention mask to use. If `parallelism_config.cp_size > 1`, the mask will
|
||||
be sharded to use with context parallelism. If not, this falls back to default `create_block_mask`.
|
||||
Arguments mimic the signature of `torch.nn.flex_attention.create_block_mask`.
|
||||
|
||||
Args:
|
||||
mask_mod: Mask modifier function to use.
|
||||
B: Batch size.
|
||||
H: Number of query heads.
|
||||
Q_LEN: Query length.
|
||||
KV_LEN: Key/Value length.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> def causal(_b, _h, q_idx, kv_idx):
|
||||
... return q_idx >= kv_idx
|
||||
|
||||
>>> block_mask = accelerator.create_block_mask(causal, None, None, seq_len, seq_len)
|
||||
|
||||
>>> buffers = [batch["input_ids"], batch["position_ids"], batch["shift_labels"]]
|
||||
>>> for batch in dataloader:
|
||||
... with accelerator.maybe_context_parallel(
|
||||
... buffers=buffers,
|
||||
... buffer_seq_dims=[1, 1, 1],
|
||||
... no_restore_buffers=set(buffers),
|
||||
... ):
|
||||
... batch["attention_mask"] = block_mask
|
||||
... outputs = model(**batch)
|
||||
... ...
|
||||
"""
|
||||
if self.parallelism_config.cp_enabled:
|
||||
logger.warning_once(
|
||||
"Using flex-attention together with context parallel is highly experimental. You might encounter numerical issues, or crashes."
|
||||
)
|
||||
|
||||
mask = self._create_block_mask_fn(
|
||||
mask_mod=mask_mod,
|
||||
B=B,
|
||||
H=H,
|
||||
Q_LEN=Q_LEN,
|
||||
KV_LEN=KV_LEN,
|
||||
)
|
||||
# We flag that this mask is created by us, so we don't remove it in pre-forward hook
|
||||
mask._accelerate_created = True
|
||||
return mask
|
||||
|
||||
@contextmanager
|
||||
def autocast(self, autocast_handler: AutocastKwargs = None):
|
||||
"""
|
||||
|
||||
@ -767,7 +767,9 @@ def _attach_context_parallel_hooks(
|
||||
"""
|
||||
|
||||
def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
|
||||
if "attention_mask" in module_kwargs:
|
||||
if "attention_mask" in module_kwargs and not getattr(
|
||||
module_kwargs["attention_mask"], "_accelerate_created", False
|
||||
):
|
||||
module_kwargs["attention_mask"] = None
|
||||
module_kwargs["is_causal"] = True
|
||||
|
||||
|
||||
@ -60,4 +60,4 @@ def update_command_parser(parser, parents):
|
||||
|
||||
def update_config_command(args):
|
||||
config_file = update_config(args)
|
||||
print(f"Successfully updated the configuration file at {config_file}.")
|
||||
print(f"Sucessfully updated the configuration file at {config_file}.")
|
||||
|
||||
@ -493,13 +493,13 @@ def launch_command_parser(subparsers=None):
|
||||
"--deepspeed_exclusion_filter",
|
||||
default=None,
|
||||
type=str,
|
||||
help="DeepSpeed exclusion filter string when using multi-node setup.",
|
||||
help="DeepSpeed exclusion filter string when using mutli-node setup.",
|
||||
)
|
||||
deepspeed_args.add_argument(
|
||||
"--deepspeed_inclusion_filter",
|
||||
default=None,
|
||||
type=str,
|
||||
help="DeepSpeed inclusion filter string when using multi-node setup.",
|
||||
help="DeepSpeed inclusion filter string when using mutli-node setup.",
|
||||
)
|
||||
deepspeed_args.add_argument(
|
||||
"--deepspeed_multinode_launcher",
|
||||
@ -585,7 +585,7 @@ def launch_command_parser(subparsers=None):
|
||||
"--fsdp_use_orig_params",
|
||||
default="true",
|
||||
type=str,
|
||||
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
|
||||
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
|
||||
" (useful only when `use_fsdp` flag is passed).",
|
||||
)
|
||||
fsdp_args.add_argument(
|
||||
|
||||
@ -89,7 +89,7 @@ def convert_config_to_fsdp2(config: dict) -> dict:
|
||||
new_fsdp_config = {}
|
||||
|
||||
if fsdp_config.get("fsdp_version", 1) == 2:
|
||||
logger.warning("Config already specifies FSDP2, skipping conversion...")
|
||||
logger.warning("Config already specfies FSDP2, skipping conversion...")
|
||||
logger.warning(
|
||||
"If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
|
||||
)
|
||||
|
||||
@ -75,7 +75,7 @@ class SeedableRandomSampler(RandomSampler):
|
||||
Same as a random sampler, except that in `__iter__` a seed can be used.
|
||||
|
||||
Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
|
||||
and be fully reproducible on multiple iterations.
|
||||
and be fully reproducable on multiple iterations.
|
||||
|
||||
If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
|
||||
(stored in `self.epoch`).
|
||||
@ -408,7 +408,7 @@ class DataLoaderStateMixin:
|
||||
class DataLoaderAdapter:
|
||||
"""
|
||||
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
|
||||
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
|
||||
compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
|
||||
@ -451,8 +451,8 @@ class DataLoaderAdapter:
|
||||
@property
|
||||
def __class__(self):
|
||||
"""
|
||||
In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
|
||||
returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
|
||||
In order to maintain backwards compatability with other code, we need to ensure `isinstance(obj, DataLoader)`
|
||||
returs true. This is because some downstream code assumes that the `DataLoader` is the base class of the
|
||||
object.
|
||||
"""
|
||||
return self.base_dataloader.__class__
|
||||
@ -763,12 +763,12 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
|
||||
# if a device mesh is provided extract each dimension (dp, fsdp, tp)
|
||||
# device mesh may hold any number of dimensions, however,
|
||||
# below code is for targeted support for dp, fsdp and tp
|
||||
# below code is for targetted support for dp, fsdp and tp
|
||||
|
||||
# device mesh will be used only if there is tp involved
|
||||
# or any multi-dimensional parallelism involving tp
|
||||
# (dp, tp) (fsdp, tp) (dp, fsdp, tp)
|
||||
# otherwise the default behaviour not using device mesh should be sufficient
|
||||
# otherwise the default behavour not using device mesh should be sufficient
|
||||
# since multi dimensional parallelism devoid of tp would anyway need
|
||||
# different batches for each process irrespective of dp or fsdp
|
||||
self.submesh_tp = None
|
||||
@ -1063,7 +1063,7 @@ def prepare_data_loader(
|
||||
ignored otherwise.
|
||||
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
|
||||
reproducibility. Comes at a cost of potentially different performances due to different shuffling
|
||||
reproducability. Comes at a cost of potentially different performances due to different shuffling
|
||||
algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
|
||||
`self.set_epoch`
|
||||
data_seed (`int`, *optional*, defaults to `None`):
|
||||
|
||||
@ -1,189 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.import queue
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
from io import UnsupportedOperation
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.distributed.checkpoint.state_dict as dcs
|
||||
from torch.distributed.checkpoint.filesystem import (
|
||||
FileSystemWriter,
|
||||
SavePlan,
|
||||
SavePlanner,
|
||||
_generate_uuid,
|
||||
_split_by_size_and_type,
|
||||
)
|
||||
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
|
||||
from torch.distributed.checkpoint.storage import WriteResult
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
class AccelerateStorageWriter(FileSystemWriter):
|
||||
_DEFAULT_SUFFIX = ".distcp"
|
||||
_OPTIM_FILE_PATH = "optimizer_0"
|
||||
_MODEL_FILE_PATH = "pytorch_model_fsdp_0"
|
||||
|
||||
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
|
||||
self.optim_path = self.fs.concat_path(self.path, self._OPTIM_FILE_PATH)
|
||||
self.model_path = self.fs.concat_path(self.path, self._MODEL_FILE_PATH)
|
||||
self.fs.mkdir(self.optim_path)
|
||||
self.fs.mkdir(self.model_path)
|
||||
return super().prepare_local_plan(plan)
|
||||
|
||||
def write_data(
|
||||
self,
|
||||
plan: SavePlan,
|
||||
planner: SavePlanner,
|
||||
):
|
||||
storage_plan = plan.storage_data
|
||||
optim_file_count = 0
|
||||
model_file_count = 0
|
||||
|
||||
def gen_file(is_optimizer: bool = False) -> str:
|
||||
nonlocal optim_file_count, model_file_count
|
||||
if is_optimizer:
|
||||
optim_file_count += 1
|
||||
return f"{storage_plan.prefix}{optim_file_count}{self._DEFAULT_SUFFIX}"
|
||||
else:
|
||||
model_file_count += 1
|
||||
return f"{storage_plan.prefix}{model_file_count}{self._DEFAULT_SUFFIX}"
|
||||
|
||||
file_queue: queue.Queue = queue.Queue()
|
||||
|
||||
for bucket in _split_by_size_and_type(1, plan.items):
|
||||
optim_states = [wi for wi in bucket if "optim" in wi.index.fqn]
|
||||
model_states = [wi for wi in bucket if "model" in wi.index.fqn]
|
||||
|
||||
for state, path in zip([optim_states, model_states], [self.optim_path, self.model_path]):
|
||||
file_name = gen_file()
|
||||
path = self.fs.concat_path(path, file_name)
|
||||
file_queue.put((path, file_name, state))
|
||||
|
||||
return self._write_data(planner, file_queue)
|
||||
|
||||
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
|
||||
try:
|
||||
metadata = dataclasses.replace(metadata, version="1.0.0")
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
def _split_metadata(
|
||||
metadata: Metadata,
|
||||
) -> tuple[Metadata, Metadata]:
|
||||
result = []
|
||||
for to_get in ["model", "optim"]:
|
||||
result.append(
|
||||
Metadata(
|
||||
state_dict_metadata={
|
||||
k.removeprefix("state."): v for k, v in metadata.state_dict_metadata.items() if to_get in k
|
||||
},
|
||||
planner_data={
|
||||
k.removeprefix("state."): tuple([x for x in v if x != "state"])
|
||||
for k, v in metadata.planner_data.items()
|
||||
if to_get in k
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(result)
|
||||
|
||||
model_metadata, optim_metadata = _split_metadata(metadata)
|
||||
model_storage_md, optim_storage_md = {}, {}
|
||||
for wr_list in results:
|
||||
for wr in wr_list:
|
||||
new_index = dataclasses.asdict(wr.index)
|
||||
new_index["fqn"] = new_index["fqn"].removeprefix("state.")
|
||||
wr = WriteResult(
|
||||
index=MetadataIndex(**new_index),
|
||||
size_in_bytes=wr.size_in_bytes,
|
||||
storage_data=wr.storage_data,
|
||||
)
|
||||
if "optim" in wr.index.fqn:
|
||||
optim_storage_md.update({wr.index: wr.storage_data})
|
||||
else:
|
||||
model_storage_md.update({wr.index: wr.storage_data})
|
||||
|
||||
model_metadata.storage_data = model_storage_md
|
||||
optim_metadata.storage_data = optim_storage_md
|
||||
|
||||
model_metadata.storage_meta = StorageMeta(self.model_path, save_id=_generate_uuid())
|
||||
optim_metadata.storage_meta = StorageMeta(self.optim_path, save_id=_generate_uuid())
|
||||
|
||||
tmp_optim_path = cast(Path, self.fs.concat_path(self.optim_path, ".metadata.tmp"))
|
||||
tmp_model_path = cast(Path, self.fs.concat_path(self.model_path, ".metadata.tmp"))
|
||||
|
||||
for meta, tmp_path, final_path in zip(
|
||||
[model_metadata, optim_metadata],
|
||||
[tmp_model_path, tmp_optim_path],
|
||||
[self.model_path, self.optim_path],
|
||||
):
|
||||
with self.fs.create_stream(tmp_path, "wb") as metadata_file:
|
||||
pickle.dump(meta, metadata_file)
|
||||
if self.sync_files:
|
||||
try:
|
||||
os.fsync(metadata_file.fileno())
|
||||
except (AttributeError, UnsupportedOperation):
|
||||
os.sync()
|
||||
|
||||
metadata_path = self.fs.concat_path(final_path, ".metadata")
|
||||
if self.fs.exists(metadata_path):
|
||||
self.fs.rm_file(metadata_path)
|
||||
|
||||
self.fs.rename(tmp_path, metadata_path)
|
||||
|
||||
|
||||
def save_model_and_optimizer(
|
||||
accelerator: "Accelerator",
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
save_path: str,
|
||||
async_save: bool = False,
|
||||
) -> None:
|
||||
# async_save = False
|
||||
if getattr(accelerator, "_async_save_handle", None) is not None:
|
||||
accelerator._async_save_handle.result()
|
||||
|
||||
options = dcs.StateDictOptions()
|
||||
|
||||
import time
|
||||
|
||||
accelerator.print(f"{time.asctime()} - Preparing state dict...")
|
||||
model_sd, optimizer_sd = dcs.get_state_dict(model, optimizer, options=options)
|
||||
accelerator.print(f"{time.asctime()} - Prepared state dict...")
|
||||
|
||||
accelerator.print(f"{time.asctime()} - Saving state dict...")
|
||||
stateful = {
|
||||
"model": model_sd,
|
||||
"optimizer": optimizer_sd,
|
||||
}
|
||||
|
||||
save_fn = dcp.save if not async_save else dcp.async_save
|
||||
|
||||
potential_handle = dcp.async_save(
|
||||
state_dict=stateful,
|
||||
storage_writer=AccelerateStorageWriter(save_path),
|
||||
)
|
||||
accelerator.print(f"{time.asctime()} - Finished saving state dict...")
|
||||
|
||||
if async_save:
|
||||
accelerator._async_save_handle = potential_handle
|
||||
@ -17,8 +17,9 @@ import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -79,19 +80,6 @@ class ParallelismConfig:
|
||||
f"\tcp_handler={self.cp_handler})\n"
|
||||
)
|
||||
|
||||
def to_json(self):
|
||||
import copy
|
||||
|
||||
_non_serializable_fields = ["device_mesh"]
|
||||
|
||||
copy.deepcopy(
|
||||
{
|
||||
k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
|
||||
for k, v in self.__dict__.items()
|
||||
if k not in _non_serializable_fields
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def dp_dim_names(self):
|
||||
"""Names of enabled dimensions across which data parallelism is applied."""
|
||||
@ -190,11 +178,6 @@ class ParallelismConfig:
|
||||
Args:
|
||||
device_type (`str`): The type of device for which to build the mesh, e
|
||||
"""
|
||||
if is_torch_version(">=", "2.2.0"):
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
else:
|
||||
raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
|
||||
|
||||
mesh = self._get_mesh()
|
||||
if len(mesh) == 0:
|
||||
return None
|
||||
|
||||
@ -195,7 +195,7 @@ class PartialState:
|
||||
original_backend = kwargs.pop("backend", None)
|
||||
backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
|
||||
if original_backend is not None and backend != original_backend:
|
||||
raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
|
||||
raise ValueError(f"Your assigned backend {original_backend} is not avaliable, please use {backend}")
|
||||
self.backend = backend
|
||||
self.distributed_type = distributed_type
|
||||
use_deepspeed = False
|
||||
@ -230,7 +230,6 @@ class PartialState:
|
||||
and (
|
||||
os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
|
||||
or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
|
||||
or True
|
||||
)
|
||||
):
|
||||
self.backend = "cuda:nccl,cpu:gloo"
|
||||
@ -401,7 +400,7 @@ class PartialState:
|
||||
DistributedType.DEEPSPEED,
|
||||
DistributedType.FSDP,
|
||||
):
|
||||
torch.distributed.barrier(device_ids=[self.local_process_index])
|
||||
torch.distributed.barrier()
|
||||
elif self.distributed_type == DistributedType.XLA:
|
||||
xm.rendezvous("accelerate.utils.wait_for_everyone")
|
||||
|
||||
@ -950,13 +949,8 @@ class AcceleratorState:
|
||||
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
|
||||
"before using any functionality from the `accelerate` library."
|
||||
)
|
||||
# deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
|
||||
# if we're using fp8.
|
||||
if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
|
||||
self._mixed_precision = "no"
|
||||
else:
|
||||
self._mixed_precision = mixed_precision
|
||||
|
||||
# deepspeed handles mixed_precision using deepspeed_config
|
||||
self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
|
||||
if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
|
||||
if mixed_precision == "bf16":
|
||||
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
|
||||
@ -1062,7 +1056,7 @@ class AcceleratorState:
|
||||
|
||||
@property
|
||||
def mixed_precision(self):
|
||||
if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
config = self.deepspeed_plugin.deepspeed_config
|
||||
if config.get("fp16", {}).get("enabled", False):
|
||||
mixed_precision = "fp16"
|
||||
@ -1085,7 +1079,7 @@ class AcceleratorState:
|
||||
"""
|
||||
Destroys the process group. If one is not specified, the default process group is destroyed.
|
||||
|
||||
If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
|
||||
If `self.fork_lauched` is `True` and `group` is `None`, nothing happens.
|
||||
"""
|
||||
PartialState().destroy_process_group(group)
|
||||
|
||||
|
||||
@ -53,7 +53,6 @@ from .testing import (
|
||||
require_torchvision,
|
||||
require_tpu,
|
||||
require_transformer_engine,
|
||||
require_transformer_engine_mxfp8,
|
||||
require_xpu,
|
||||
run_first,
|
||||
skip,
|
||||
|
||||
@ -69,7 +69,7 @@ class TorchTracemalloc:
|
||||
self.begin = torch.npu.memory_allocated()
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
torch.xpu.reset_peak_memory_stats() # reset the peak gauge to zero
|
||||
torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
|
||||
self.begin = torch.xpu.memory_allocated()
|
||||
elif is_hpu_available():
|
||||
# torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process
|
||||
|
||||
@ -50,7 +50,7 @@ def test_gather_object(state):
|
||||
assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
|
||||
|
||||
|
||||
def test_gather_non_contiguous(state):
|
||||
def test_gather_non_contigous(state):
|
||||
# Skip this test because the 'is_contiguous' function of XLA tensor always returns True.
|
||||
if state.distributed_type == DistributedType.XLA:
|
||||
return
|
||||
@ -160,8 +160,8 @@ def main():
|
||||
test_gather(state)
|
||||
state.print("testing gather_object")
|
||||
test_gather_object(state)
|
||||
state.print("testing gather non-contiguous")
|
||||
test_gather_non_contiguous(state)
|
||||
state.print("testing gather non-contigous")
|
||||
test_gather_non_contigous(state)
|
||||
state.print("testing broadcast")
|
||||
test_broadcast(state)
|
||||
state.print("testing pad_across_processes")
|
||||
|
||||
@ -35,12 +35,10 @@ from accelerate.utils import (
|
||||
gather,
|
||||
gather_object,
|
||||
is_bf16_available,
|
||||
is_cuda_available,
|
||||
is_datasets_available,
|
||||
is_fp16_available,
|
||||
is_hpu_available,
|
||||
is_ipex_available,
|
||||
is_mps_available,
|
||||
is_pytest_available,
|
||||
is_xpu_available,
|
||||
set_seed,
|
||||
@ -536,7 +534,7 @@ def training_check(use_seedable_sampler=False):
|
||||
accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")
|
||||
|
||||
# FP32 wrapper check
|
||||
if is_cuda_available() or is_mps_available():
|
||||
if torch.cuda.is_available():
|
||||
# Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
|
||||
print("Keep fp32 wrapper check.")
|
||||
AcceleratorState._reset_state()
|
||||
|
||||
@ -71,7 +71,6 @@ from ..utils import (
|
||||
is_torchvision_available,
|
||||
is_trackio_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformer_engine_mxfp8_available,
|
||||
is_transformers_available,
|
||||
is_triton_available,
|
||||
is_wandb_available,
|
||||
@ -541,16 +540,6 @@ def require_transformer_engine(test_case):
|
||||
return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)
|
||||
|
||||
|
||||
def require_transformer_engine_mxfp8(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires transformers engine MXFP8 block scaling available. These tests are skipped
|
||||
when transformers engine MXFP8 block scaling isn't available
|
||||
"""
|
||||
return unittest.skipUnless(
|
||||
is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torchao(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed
|
||||
@ -598,7 +587,7 @@ def require_torchdata_stateful_dataloader(test_case):
|
||||
def run_first(test_case):
|
||||
"""
|
||||
Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator are
|
||||
guaranteed to run first.
|
||||
garanteed to run first.
|
||||
|
||||
This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
|
||||
single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
|
||||
@ -617,7 +606,7 @@ def run_first(test_case):
|
||||
class TempDirTestCase(unittest.TestCase):
|
||||
"""
|
||||
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
|
||||
data at the start of a test, and then destroys it at the end of the TestCase.
|
||||
data at the start of a test, and then destroyes it at the end of the TestCase.
|
||||
|
||||
Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ class GeneralTracker:
|
||||
(`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal
|
||||
tracking mechanism used by a tracker class (such as the `run` for wandb)
|
||||
|
||||
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and
|
||||
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevent logging, init, and
|
||||
other functions should occur on the main process or across all processes (by default will use `True`)
|
||||
"""
|
||||
|
||||
|
||||
@ -134,7 +134,6 @@ from .imports import (
|
||||
is_torchvision_available,
|
||||
is_trackio_available,
|
||||
is_transformer_engine_available,
|
||||
is_transformer_engine_mxfp8_available,
|
||||
is_transformers_available,
|
||||
is_triton_available,
|
||||
is_wandb_available,
|
||||
|
||||
@ -314,7 +314,7 @@ def _replace_with_bnb_layers(
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
|
||||
"""
|
||||
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
|
||||
import bitsandbytes as bnb
|
||||
|
||||
@ -371,7 +371,6 @@ class TERecipeKwargs(KwargsHandler):
|
||||
amax_history_len: int = None
|
||||
amax_compute_algo: AmaxComputeAlgorithm = None
|
||||
override_linear_precision: tuple[bool, bool, bool] = None
|
||||
use_mxfp8_block_scaling: bool = None
|
||||
|
||||
def __post_init__(self):
|
||||
env_prefix = "ACCELERATE_FP8_"
|
||||
@ -400,8 +399,6 @@ class TERecipeKwargs(KwargsHandler):
|
||||
dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
|
||||
wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
|
||||
self.override_linear_precision = (fprop, dgrad, wgrad)
|
||||
if self.use_mxfp8_block_scaling is None:
|
||||
self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -683,7 +680,7 @@ class DynamoBackend(str, BaseEnum):
|
||||
more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
|
||||
- **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read
|
||||
more](https://github.com/intel/intel-extension-for-pytorch).
|
||||
- **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
|
||||
- **TVM** -- Uses Apach TVM for inference optimizations. [Read more](https://tvm.apache.org/)
|
||||
- **HPU_BACKEND** -- Uses HPU backend for inference optimizations.
|
||||
|
||||
"""
|
||||
@ -804,9 +801,9 @@ class DataLoaderConfiguration:
|
||||
all workers.
|
||||
use_seedable_sampler (`bool`, defaults to `False`):
|
||||
Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures
|
||||
training results are fully reproducible using a different sampling technique. While seed-to-seed results
|
||||
may differ, on average the differences are negligible when using multiple different seeds to compare.
|
||||
Should also be ran with [`~utils.set_seed`] for the best results.
|
||||
training results are fully reproducable using a different sampling technique. While seed-to-seed results
|
||||
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
|
||||
also be ran with [`~utils.set_seed`] for the best results.
|
||||
data_seed (`int`, defaults to `None`):
|
||||
The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
|
||||
will use the current default seed from torch.
|
||||
@ -849,8 +846,8 @@ class DataLoaderConfiguration:
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])."
|
||||
"Ensures training results are fully reproducible using a different sampling technique. "
|
||||
"While seed-to-seed results may differ, on average the differences are negligible when using"
|
||||
"Ensures training results are fully reproducable using a different sampling technique. "
|
||||
"While seed-to-seed results may differ, on average the differences are neglible when using"
|
||||
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
|
||||
},
|
||||
)
|
||||
@ -956,7 +953,7 @@ class GradientAccumulationPlugin(KwargsHandler):
|
||||
sync_with_dataloader (`bool`, *optional*, defaults to `True`):
|
||||
Whether to synchronize setting the gradients when at the end of the dataloader.
|
||||
sync_each_batch (`bool`, *optional*):
|
||||
Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
|
||||
Whether to synchronize setting the gradients at each data batch. Seting to `True` may reduce memory
|
||||
requirements when using gradient accumulation with distributed training, at expense of speed.
|
||||
|
||||
Example:
|
||||
@ -1553,12 +1550,10 @@ class FullyShardedDataParallelPlugin:
|
||||
backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
|
||||
Backward prefetch strategy to use. Should be either a `str` or an instance of
|
||||
`torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
|
||||
mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
|
||||
mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
|
||||
A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
|
||||
should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
|
||||
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
|
||||
should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
|
||||
`reduce_dtype`, and `buffer_dtype`.
|
||||
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
|
||||
auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
|
||||
A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
|
||||
of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
|
||||
@ -1637,7 +1632,6 @@ class FullyShardedDataParallelPlugin:
|
||||
mixed_precision_policy: Optional[
|
||||
Union[
|
||||
dict,
|
||||
str,
|
||||
"torch.distributed.fsdp.MixedPrecision",
|
||||
"torch.distributed.fsdp.MixedPrecisionPolicy",
|
||||
]
|
||||
@ -1929,11 +1923,7 @@ class FullyShardedDataParallelPlugin:
|
||||
)
|
||||
os.environ[env_var] = str(self.cpu_ram_efficient_loading)
|
||||
|
||||
if isinstance(self.mixed_precision_policy, str):
|
||||
# override is True since self.mixed_precision_policy is not None
|
||||
# has to be overwritten with the correct mixed precision object
|
||||
self.set_mixed_precision(self.mixed_precision_policy, override=True)
|
||||
elif isinstance(self.mixed_precision_policy, dict):
|
||||
if isinstance(self.mixed_precision_policy, dict):
|
||||
self.set_mixed_precision(self.mixed_precision_policy)
|
||||
if self.mixed_precision_policy is not None:
|
||||
self.validate_mixed_precision_policy()
|
||||
@ -2015,7 +2005,7 @@ class FullyShardedDataParallelPlugin:
|
||||
|
||||
def set_auto_wrap_policy(self, model):
|
||||
"""
|
||||
Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the
|
||||
Given `model`, creates an `auto_wrap_policy` baesd on the passed in policy and if we can use the
|
||||
`transformer_cls_to_wrap`
|
||||
"""
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
@ -2263,7 +2253,7 @@ class MegatronLMPlugin:
|
||||
lr_warmup_fraction (`float`, defaults to `None`):
|
||||
Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.
|
||||
min_lr (`float`, defaults to `0`):
|
||||
Minimum value for learning rate. The scheduler clip values below this threshold.
|
||||
Minumum value for learning rate. The scheduler clip values below this threshold.
|
||||
consumed_samples (`List`, defaults to `None`):
|
||||
Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.
|
||||
no_wd_decay_cond (`Optional`, defaults to `None`):
|
||||
@ -2390,7 +2380,7 @@ class MegatronLMPlugin:
|
||||
)
|
||||
min_lr: float = field(
|
||||
default=0,
|
||||
metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."},
|
||||
metadata={"help": "Minumum value for learning rate. The scheduler clip values below this threshold."},
|
||||
)
|
||||
consumed_samples: list[int] = field(
|
||||
default=None,
|
||||
|
||||
@ -149,7 +149,7 @@ def check_cuda_p2p_ib_support():
|
||||
Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
|
||||
the 3090.
|
||||
|
||||
Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
|
||||
Noteably uses `nvidia-smi` instead of torch to not initialize CUDA.
|
||||
"""
|
||||
try:
|
||||
device_names, device_count = get_gpu_info()
|
||||
|
||||
@ -305,15 +305,13 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
|
||||
optim_state = torch.load(input_optimizer_file, weights_only=True)
|
||||
logger.info(f"Optimizer state loaded from {input_optimizer_file}")
|
||||
else:
|
||||
from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
|
||||
|
||||
ckpt_dir = (
|
||||
os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
|
||||
if f"{OPTIMIZER_NAME}" not in input_dir
|
||||
else input_dir
|
||||
)
|
||||
logger.info(f"Loading Optimizer from {ckpt_dir}")
|
||||
optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer)}
|
||||
optim_state = {"optimizer": optimizer.state_dict()}
|
||||
dist_cp.load(
|
||||
optim_state,
|
||||
checkpoint_id=ckpt_dir,
|
||||
@ -650,7 +648,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
|
||||
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
|
||||
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
|
||||
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
|
||||
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU
|
||||
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
|
||||
|
||||
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
|
||||
|
||||
@ -114,14 +114,6 @@ def is_transformer_engine_available():
|
||||
return _is_package_available("transformer_engine", "transformer-engine")
|
||||
|
||||
|
||||
def is_transformer_engine_mxfp8_available():
|
||||
if _is_package_available("transformer_engine", "transformer-engine"):
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
return te.fp8.check_mxfp8_support()[0]
|
||||
return False
|
||||
|
||||
|
||||
def is_lomo_available():
|
||||
return _is_package_available("lomo_optim")
|
||||
|
||||
@ -182,7 +174,7 @@ def is_bf16_available(ignore_tpu=False):
|
||||
if is_xpu_available():
|
||||
return torch.xpu.is_bf16_supported()
|
||||
if is_mps_available():
|
||||
return torch.backends.mps.is_macos_or_newer(14, 0)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@ -414,12 +406,7 @@ def is_npu_available(check_device=False):
|
||||
if importlib.util.find_spec("torch_npu") is None:
|
||||
return False
|
||||
|
||||
# NOTE: importing torch_npu may raise error in some envs
|
||||
# e.g. inside cpu-only container with torch_npu installed
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
except Exception:
|
||||
return False
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
if check_device:
|
||||
try:
|
||||
|
||||
@ -873,7 +873,7 @@ def finish_mpu_init():
|
||||
_set_random_seed(args.seed, args.data_parallel_random_init)
|
||||
|
||||
|
||||
# initialize megatron setup
|
||||
# intialize megatron setup
|
||||
def initialize(accelerator, extra_args_provider=None, args_defaults={}):
|
||||
accelerator.print("Initializing Megatron-LM")
|
||||
assert torch.cuda.is_available(), "Megatron requires CUDA."
|
||||
@ -1344,7 +1344,7 @@ class MegatronEngine(torch.nn.Module):
|
||||
padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
|
||||
prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)
|
||||
|
||||
# We need the sizes of these tensors for the broadcast
|
||||
# We need the sizes of these tensors for the boradcast
|
||||
sizes_list = [
|
||||
prompts_tokens_tensor.size(0), # Batch size
|
||||
prompts_tokens_tensor.size(1),
|
||||
@ -1353,7 +1353,7 @@ class MegatronEngine(torch.nn.Module):
|
||||
# First, broadcast the sizes.
|
||||
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)
|
||||
|
||||
# Now that we have the sizes, we can broadcast the tokens
|
||||
# Now that we have the sizes, we can boradcast the tokens
|
||||
# and length tensors.
|
||||
sizes = sizes_tensor.tolist()
|
||||
context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
|
||||
|
||||
@ -2136,10 +2136,6 @@ def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
|
||||
return torch.amp.GradScaler("hpu", **kwargs)
|
||||
elif is_xpu_available():
|
||||
return torch.amp.GradScaler("xpu", **kwargs)
|
||||
elif is_mps_available():
|
||||
if not is_torch_version(">=", "2.8.0"):
|
||||
raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
|
||||
return torch.amp.GradScaler("mps", **kwargs)
|
||||
else:
|
||||
if is_torch_version(">=", "2.3"):
|
||||
return torch.amp.GradScaler("cuda", **kwargs)
|
||||
|
||||
@ -32,7 +32,6 @@ from .imports import (
|
||||
is_torch_distributed_available,
|
||||
is_torch_xla_available,
|
||||
)
|
||||
from .versions import is_torch_version
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@ -317,8 +316,8 @@ def _gpu_gather(tensor):
|
||||
state = PartialState()
|
||||
gather_op = torch.distributed.all_gather_into_tensor
|
||||
|
||||
# NOTE: need manually synchronize to workaourd a INT64 collectives bug in oneCCL before torch 2.9.0
|
||||
if state.device.type == "xpu" and is_torch_version("<=", "2.8"):
|
||||
# FIXME: the below 2 lines are added to work-aound a bug related to INT64 collectives in oneCCL. Remove them once pytorch-2.9 is released.
|
||||
if state.device.type == "xpu":
|
||||
torch.xpu.synchronize()
|
||||
|
||||
def _gpu_gather_one(tensor):
|
||||
@ -520,7 +519,7 @@ def gather_tensor_shape(tensor):
|
||||
|
||||
def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
|
||||
"""
|
||||
Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
|
||||
Copys a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
|
||||
each worker doesn't need to know its shape when used (and tensor can be `None`)
|
||||
|
||||
Args:
|
||||
@ -732,7 +731,7 @@ def reduce(tensor, reduction="mean", scale=1.0):
|
||||
reduction (`str`, *optional*, defaults to `"mean"`):
|
||||
A reduction method. Can be of "mean", "sum", or "none"
|
||||
scale (`float`, *optional*):
|
||||
A default scaling value to be applied after the reduce, only valid on XLA.
|
||||
A default scaling value to be applied after the reduce, only valied on XLA.
|
||||
|
||||
Returns:
|
||||
The same data structure as `data` with all the tensors reduced.
|
||||
@ -788,7 +787,7 @@ def convert_to_fp32(tensor):
|
||||
|
||||
class ConvertOutputsToFp32:
|
||||
"""
|
||||
Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
|
||||
Decorator to apply to a function outputing tensors (like a model forward pass) that ensures the outputs in FP16
|
||||
precision will be convert back to FP32.
|
||||
|
||||
Args:
|
||||
|
||||
@ -148,34 +148,14 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
|
||||
|
||||
if is_hpu_available():
|
||||
import intel_transformer_engine.recipe as te_recipe
|
||||
|
||||
is_fp8_block_scaling_available = False
|
||||
message = "MXFP8 block scaling is not available on HPU."
|
||||
|
||||
else:
|
||||
import transformer_engine.common.recipe as te_recipe
|
||||
import transformer_engine.pytorch as te
|
||||
|
||||
is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
|
||||
|
||||
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
|
||||
if "fp8_format" in kwargs:
|
||||
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
|
||||
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
|
||||
use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
|
||||
|
||||
if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
|
||||
raise ValueError(f"MXFP8 block scaling is not available: {message}")
|
||||
|
||||
if use_mxfp8_block_scaling:
|
||||
if "amax_compute_algo" in kwargs:
|
||||
raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
|
||||
if "amax_history_len" in kwargs:
|
||||
raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
|
||||
fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
|
||||
else:
|
||||
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
|
||||
|
||||
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
|
||||
new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
|
||||
|
||||
if hasattr(model.forward, "__func__"):
|
||||
|
||||
@ -316,9 +316,6 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
AcceleratorState._reset_state(True)
|
||||
|
||||
env = self.fsdp_envs[fsdp_version].copy()
|
||||
with patch_environment(**env):
|
||||
plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)
|
||||
assert plugin.mixed_precision_policy == mp_policy
|
||||
with patch_environment(**env):
|
||||
plugin = FullyShardedDataParallelPlugin(
|
||||
mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, **{extra_arg: dtype}}
|
||||
|
||||
@ -625,7 +625,7 @@ class ToFSDP2Tester(unittest.TestCase):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
to_fsdp2_command(args)
|
||||
|
||||
assert "Config already specifies FSDP2, skipping conversion..." in cm.output[0]
|
||||
assert "Config already specfies FSDP2, skipping conversion..." in cm.output[0]
|
||||
|
||||
# Has to be the last test because it overwrites the config file
|
||||
def test_fsdp2_overwrite(self):
|
||||
|
||||
@ -76,7 +76,7 @@ class TestParallelismConfig:
|
||||
|
||||
return mesh
|
||||
|
||||
with patch("torch.distributed.device_mesh.init_device_mesh", side_effect=mock_init_mesh):
|
||||
with patch("accelerate.parallelism_config.init_device_mesh", side_effect=mock_init_mesh):
|
||||
yield mock_init_mesh
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -31,7 +31,6 @@ from accelerate.test_utils import (
|
||||
require_multi_device,
|
||||
require_torchao,
|
||||
require_transformer_engine,
|
||||
require_transformer_engine_mxfp8,
|
||||
run_first,
|
||||
)
|
||||
from accelerate.test_utils.testing import require_deepspeed, run_command
|
||||
@ -50,8 +49,6 @@ def can_convert_te_model(from_config=False):
|
||||
accelerator_kwargs = {}
|
||||
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
assert accelerator.fp8_enabled, "FP8 is not enabled"
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
|
||||
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.LayerNorm(32, bias=False), torch.nn.Linear(32, 16))
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
@ -112,26 +109,6 @@ class TestTransformerEngine(unittest.TestCase):
|
||||
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
|
||||
run_command(command)
|
||||
|
||||
@require_transformer_engine_mxfp8
|
||||
def test_can_prepare_model_with_mxfp8_block_scaling(self):
|
||||
with tempfile.TemporaryDirectory() as dir_name:
|
||||
config_file = Path(dir_name) / "config.yaml"
|
||||
config_file.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
distributed_type: "NO"
|
||||
num_processes: 1
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
backend: TE
|
||||
use_mxfp8_block_scaling: true
|
||||
"""
|
||||
)
|
||||
)
|
||||
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
|
||||
run_command(command)
|
||||
|
||||
@require_multi_device
|
||||
def test_can_prepare_model_multi_gpu(self):
|
||||
command = get_launch_command(num_processes=2, monitor_interval=0.1)
|
||||
@ -170,35 +147,6 @@ class TestTransformerEngine(unittest.TestCase):
|
||||
command += ["-m", "tests.test_fp8", "--test_te"]
|
||||
run_command(command)
|
||||
|
||||
@require_deepspeed
|
||||
@require_multi_device
|
||||
def test_can_prepare_model_multigpu_deepspeed_from_config(self):
|
||||
os.environ["ZERO_STAGE"] = str(1)
|
||||
with tempfile.TemporaryDirectory() as dir_name:
|
||||
config_file = Path(dir_name) / "config.yaml"
|
||||
config_file.write_text(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
distributed_type: "DEEPSPEED"
|
||||
deepspeed_config:
|
||||
gradient_clipping: 1.0
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 1
|
||||
deepspeed_multinode_launcher: standard
|
||||
num_processes: 2
|
||||
mixed_precision: fp8
|
||||
fp8_config:
|
||||
backend: TE
|
||||
"""
|
||||
)
|
||||
)
|
||||
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
|
||||
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
|
||||
run_command(command)
|
||||
|
||||
|
||||
@require_torchao
|
||||
@require_huggingface_suite
|
||||
|
||||
Reference in New Issue
Block a user