Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 18:13:46 +08:00)
Parallelism config + TP + HSDP + BYODM (Bring Your Own Device Mesh) (#3682)
* Feat: init
* Feat: add validation + init from kwargs
* Fix: minor fixes
* Feat: more cleanup
* Minor refactor
* remove import
* adding support for pre-configured device mesh
* adding device mesh to fsdp2
* moving mesh dim defn to parralismconfig
* tests
* WIP device mesh/accelerator validation
* WIP more tests
* Test Driven Development (TDD)
* fixing build_device_mesh
* FSDP dim names
* adding example
* WIP
* fixing HSDP
* Feat: add back old options
* working example
* debugging
* adding parallelism config to partialstate
* Feat: revert ddp changes
* Revert DDP
* Feat: (untested) update mesh dims and some minor tweaks
* adding dp_cp dims
* updating comments
* WIP
* wip 2
* reverting
* storing state in accelerator rather than acceleratorstate
* Fix: minor tweaks
* wip example update
* Fixes for non-fsdp2 case
* Feat: ensure ddp/tp only works
* updating example
* updating example
* updating examples, fixing state
* fixed state
* comments
* fixing partial state check
* linting
* comments
* removing fn
* WIP: fix tp
* comments
* removing return
* reverting upcast
* add guards
* guards for empty self.parallelism_config
* use len on tuple to check if empty
* Feat: cleanup example
* Feat: some cleanup of example
* Feat: add trackio
* Fix: improve trackio
* Feat: TP works
* Feat: some fsdp2 improv
* Feat: working examples
* handle clipping for tensor parallel
* Implicit replicate
* Refactor: move to separate file + cleanup + basic comments
* Fix: add unadded files, fix circular import
* Feat: better readme
* Feat: add blog + ultrascale links
* Tmp: should_save_model now returns only true
* Fix: remove implicit_replication and style
* Fix: remove optional
* add guard on parallelism_config.tp_enabled
* fix import
* fixing empty parallelism_config
* fix import path for test patch
* fixing patch

---------

Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -2,6 +2,31 @@
This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance, or accuracy.

### FSDP2 + ND Parallelism

With `ParallelismConfig`, you can use 🤗 accelerate to train models with n-dimensional parallelism. This builds on top of 🤗 transformers, which we utilize for tensor parallelism sharding.
Accelerate then takes care of everything else, such as data parallelism, FSDP, and more to come.

The script `nd_parallel.py` showcases how you can do this. We enable you to configure 3 different parallel dimensions (for now 👀), sketched in code right after this list:

- `dp_replicate_size`: how many replicas of the model to create; each replica is trained on a different subset of the data, and gradients are averaged at the end of each step (same as DDP in PyTorch)
- `dp_shard_size`: across how many devices the model is sharded; this uses FSDP2 to shard the model, so each device holds a different part of the model
- `tp_size`: how many devices to use for tensor parallelism, using the implementation from 🤗 transformers
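
In code, these dimensions map directly onto `ParallelismConfig` (a minimal sketch mirroring `nd_parallel.py`; the sizes below are placeholders):

```python
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig

# 2 replicas x 2 FSDP2 shards x 2 TP ranks -> requires 8 processes in total
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
accelerator = Accelerator(parallelism_config=pc)
```
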
For example, with 8 GPUs, you can run the script as below. Note that the product of the parallel sizes must equal the number of processes (here 2 x 2 x 2 = 8):

```bash
accelerate launch --num-processes 8 nd_parallel.py \
    --dp-replicate-size 2 \
    --dp-shard-size 2 \
    --tp-size 2
```

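Under the hood, this corresponds to a 3-D device mesh. An illustrative sketch using PyTorch's `init_device_mesh` (accelerate builds the mesh for you via `ParallelismConfig.build_device_mesh`; the dimension names here are illustrative):

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs tiled as (dp_replicate=2, dp_shard=2, tp=2)
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp"))
```
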
<Tip>
Only use TP intra-node; the maximum TP size you should need is therefore 8. You can also lower it, as FSDP (`--dp-shard-size`) can be faster for smaller models with
shorter sequence lengths. If you still cannot fit the model into memory, utilize `--dp-shard-size` as much as you can. Then, to scale up and utilize all of your GPUs, fill the rest
with `--dp-replicate-size`. This is only a general guideline; you can (and should) experiment with different parallelism configurations to find the best one for your model and hardware. You can learn more about the general strategies for parallelism in our [blog](TODO), or if you want to dive deep, read the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
</Tip>

We plan to add more parallelisms in the future, with context parallelism coming soon and pipeline parallelism planned.

### FSDP2 + ao Float8Linear

In the file `fsdp2_fp8.py`, we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
@@ -34,3 +59,4 @@ The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length
```bash
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
```
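
For context, a minimal sketch of how `AORecipeKwargs` and `Float8LinearConfig` fit together (see `fsdp2_fp8.py` for the full version; the exact kwargs shown are assumptions):

```python
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig

# torchao swaps eligible nn.Linear layers for Float8Linear during accelerator.prepare()
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs(config=Float8LinearConfig())],
)
```
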

examples/fsdp2/nd_parallel.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with ND parallelism using accelerate's ParallelismConfig
"""

import argparse
import warnings

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import (
    PerformanceTracker,
    create_collate_fn,
    get_dataset,
    setup_tokenizer,
)


MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dp-replicate-size", type=int, default=1)
    parser.add_argument("--dp-shard-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--sequence-length", type=int, default=1024)
    parser.add_argument("--num-steps", type=int, default=1000)
    parser.add_argument("--save-dir", type=str, default="./outputs")
    parser.add_argument("--checkpoint-frequency", type=int, default=100)
    parser.add_argument("--model-name", type=str, default=MODEL_ID)

    return parser.parse_args()


def forward(model, batch, optimizer, accelerator):
    # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
    loss_reduce_grp = (
        accelerator.torch_device_mesh["dp_cp"].get_group() if accelerator.parallelism_config.dp_cp_dim_names else None
    )
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
    return loss


def train(args):
    parallelism_config = ParallelismConfig(
        dp_replicate_size=args.dp_replicate_size,
        dp_shard_size=args.dp_shard_size,
        tp_size=args.tp_size,
    )

    # FSDP needs extra configuration, so we properly shard the model
    if parallelism_config.dp_shard_enabled:
        fsdp2_plugin = FullyShardedDataParallelPlugin(
            fsdp_version=2,
            auto_wrap_policy="transformer_based_wrap",
            transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
        )

    accelerator = Accelerator(
        log_with=["wandb"],
        mixed_precision="bf16",
        parallelism_config=parallelism_config,
        fsdp_plugin=fsdp2_plugin if parallelism_config.dp_shard_enabled else None,
    )
    accelerator.init_trackers("nd_parallel_training")

    # If TP was enabled, we need to tell transformers to prepare the model for us
    model_kwargs = (
        {"tp_size": args.tp_size, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh}
        if args.tp_size > 1
        else {}
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        **model_kwargs,
    )
    tokenizer = setup_tokenizer(args.model_name)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
    dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    total_num_steps = min(args.num_steps, len(dataloader))
    performance_tracker = PerformanceTracker(warmup_steps=5)

    accelerator.print("Starting training...")
    for step, batch in enumerate(dataloader):
        if step >= total_num_steps:
            break

        loss = forward(model, batch, optimizer, accelerator)

        # We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
        metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)

        print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
        if "warmup_completed" in metrics:
            accelerator.print("Warm up completed! Starting performance tracking...")
        elif metrics:
            print_msg += performance_tracker.get_print_message(metrics, with_memory=True)

        if step % 10 == 0 or step == total_num_steps - 1:
            accelerator.print(print_msg)

        if step % args.checkpoint_frequency == 0 and step > 0 and parallelism_config.dp_shard_enabled:
            accelerator.print(f"Saving checkpoint at step {step}...")
            accelerator.save_state(args.save_dir + f"/checkpoint-{step}")

        accelerator.log({"loss": loss.item()})

    accelerator.print("Training completed!")

    model.save_pretrained(args.save_dir + f"/{args.model_name}")
    accelerator.print(f"Model saved to {args.save_dir}/{args.model_name}")
    accelerator.end_training()


if __name__ == "__main__":
    set_seed(42)
    args = parse_args()
    if args.dp_shard_size == 1:
        # We currently don't support saving with `save_state` when using only
        # tensor parallelism; FSDP must be enabled
        warnings.warn(
            "Accelerator.save_state() is not yet supported with pure tensor parallel training. Training will work, but intermediate checkpoints will not be saved."
        )
    train(args)
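
Combining `--dp-replicate-size` with `--dp-shard-size` in the script above is what gives you the HSDP (hybrid sharded data parallelism) from the commit title: the model is sharded within each replica group and replicated across groups. A minimal configuration sketch (sizes are placeholders):

```python
from accelerate.parallelism_config import ParallelismConfig

# HSDP on 8 GPUs: shard across 4 devices, replicate that group twice, no TP
hsdp = ParallelismConfig(dp_replicate_size=2, dp_shard_size=4)
```
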

examples/fsdp2/nd_parallel_prepared_device_mesh.py (new file, 187 lines)
@@ -0,0 +1,187 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with ND parallelism using accelerate's ParallelismConfig and a pre-configured device mesh
"""

import argparse

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.state import PartialState
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from accelerate.utils.fsdp_utils import save_fsdp_optimizer
from utils import PerformanceTracker, create_collate_fn, get_dataset, gpu_memory_usage_all, setup_tokenizer


MODEL_ID = "NousResearch/Llama-3.2-1B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--fsdp2-cls-name-to-wrap", type=str, default="LlamaDecoderLayer")
    parser.add_argument("--dp-replicate-size", type=int, default=1)
    parser.add_argument("--dp-shard-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--sequence-length", type=int, default=128)
    parser.add_argument("--model-save-dir", type=str, default="./outputs")
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="Whether to save the model after training."
    )
    parser.add_argument(
        "--save-optimizer",
        action="store_true",
        default=False,
        help="Whether to save the optimizer state after training.",
    )
    return parser.parse_args()


def main():
    """
    Main function to train the model.
    """
    args = parse_args()

    set_seed(42)

    if args.model:
        model_id = args.model
    else:
        model_id = MODEL_ID

    model_kwargs = {}
    accelerator_kwargs = {}

    parallelism_config = ParallelismConfig(
        dp_replicate_size=args.dp_replicate_size,
        dp_shard_size=args.dp_shard_size,
        tp_size=args.tp_size,
    )

    # Bring your own device mesh: build it up front instead of letting the Accelerator do it
    device_mesh = parallelism_config.build_device_mesh("cuda")

    if args.tp_size > 1:
        model_kwargs["tp_size"] = args.tp_size
        model_kwargs["tp_plan"] = "auto"
        model_kwargs["device_mesh"] = device_mesh

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        **model_kwargs,
    )

    # Seed the shared state with the pre-built mesh before the Accelerator is created
    PartialState(device_mesh=device_mesh, parallelism_config=parallelism_config)

    if parallelism_config.dp_shard_enabled:
        fsdp2_plugin = FullyShardedDataParallelPlugin(
            fsdp_version=2,
            cpu_ram_efficient_loading=False,
            auto_wrap_policy="transformer_based_wrap",
            transformer_cls_names_to_wrap=[args.fsdp2_cls_name_to_wrap],
            reshard_after_forward=True,
            activation_checkpointing=True,
            state_dict_type="FULL_STATE_DICT",
        )
        accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin

    accelerator = Accelerator(
        mixed_precision="no",
        **accelerator_kwargs,
    )

    accelerator.print("Memory usage after model load")
    accelerator.print(gpu_memory_usage_all())
    accelerator.print("=" * 20)
    tokenizer = setup_tokenizer(model_id)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

    model, optimizer = accelerator.prepare(model, optimizer)
    accelerator.print("Memory usage after model prepare")
    accelerator.print(gpu_memory_usage_all())
    accelerator.print("=" * 20)

    dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
    dataloader = accelerator.prepare(dataloader)

    model.train()

    total_num_steps = min(100, len(dataloader))
    performance_tracker = PerformanceTracker(warmup_steps=10)

    accelerator.print("Starting training...")
    for step, batch in enumerate(dataloader):
        if step >= total_num_steps:
            break

        outputs = model(**batch)
        loss = outputs.loss

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        # Average the loss across all ranks for logging (TP ranks already hold identical losses)
        dist.all_reduce(loss, op=dist.ReduceOp.AVG)

        batch_tokens = batch["input_ids"].shape[1]
        metrics = performance_tracker.step(batch_tokens)

        print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
        log_metrics = {"loss": loss.item()}

        if "warmup_completed" in metrics:
            accelerator.print("Warm up completed! Starting performance tracking...")
        elif metrics:
            print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
            print_msg += (
                f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
                f"alloc={metrics['peak_memory_alloc']:.1f}, "
                f"reserved={metrics['peak_memory_reserved']:.1f}"
            )
        if step % 10 == 0 or step == total_num_steps - 1:
            accelerator.print(print_msg)

        accelerator.log(log_metrics)

    accelerator.wait_for_everyone()
    accelerator.end_training()
    accelerator.print("Training completed!")
    if parallelism_config.dp_shard_enabled and args.save_optimizer:
        accelerator.print("Saving optimizer state...")
        save_fsdp_optimizer(
            fsdp2_plugin,
            accelerator,
            optimizer,
            model,
            args.model_save_dir + "/opt",
        )
        accelerator.print("Optimizer state saved.")
    if args.save_model:
        accelerator.print("Saving model state...")
        model.save_pretrained(args.model_save_dir)
        accelerator.print(f"Model saved to {args.model_save_dir}")


if __name__ == "__main__":
    main()
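
For reference, the "bring your own device mesh" flow in this script distills to a few calls (a condensed sketch of the code above; sizes are placeholders):

```python
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.state import PartialState

pc = ParallelismConfig(dp_replicate_size=1, dp_shard_size=2, tp_size=2)
mesh = pc.build_device_mesh("cuda")  # build the mesh yourself...
# (load the model here, passing the mesh to transformers when tp_size > 1)
PartialState(device_mesh=mesh, parallelism_config=pc)  # ...then hand it to accelerate
accelerator = Accelerator()
```
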

examples/fsdp2/utils.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Common utilities for FSDP2 examples.
"""

import time

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator


def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int) -> Dataset:
    """
    Load and prepare TinyStories dataset.

    Args:
        accelerator (Accelerator): Accelerate accelerator instance
        tokenizer (AutoTokenizer): Hugging Face tokenizer
        seq_len (int): Sequence length for the dataset

    Returns:
        Dataset: Packed dataset
    """
    raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")

    def tokenize_function(examples):
        tokenized_batch = tokenizer(
            examples["text"],
            padding=False,
            truncation=True,
            max_length=seq_len,
            return_tensors=None,
        )
        tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
        return tokenized_batch

    with accelerator.main_process_first():
        tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    def create_packed_sequences(examples):
        # Flatten all tokens, then cut them into chunks of (seq_len + 1)
        all_tokens = []
        for input_ids in examples["input_ids"]:
            all_tokens.extend(input_ids)

        num_sequences = len(all_tokens) // (seq_len + 1)
        packed_input_ids = []
        packed_labels = []

        for i in range(num_sequences):
            start_idx = i * (seq_len + 1)
            end_idx = start_idx + (seq_len + 1)
            # Labels are the inputs shifted by one token
            full_sequence = all_tokens[start_idx:end_idx]
            packed_input_ids.append(full_sequence[:-1])
            packed_labels.append(full_sequence[1:])

        return {"input_ids": packed_input_ids, "labels": packed_labels}

    with accelerator.main_process_first():
        packed_dataset = tokenized_dataset.map(
            create_packed_sequences,
            batched=True,
            remove_columns=tokenized_dataset.column_names,
            batch_size=1000,
        )

    return packed_dataset.shuffle(seed=42)


def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
    """
    Get the number of flops per token for the model.

    Args:
        model (AutoModelForCausalLM): Model to get the flops for
        seq_len (int): Sequence length
    """
    cfg = model.config
    head_dim = cfg.hidden_size // cfg.num_attention_heads

    # MLP: 3 matmuls (gate, up and down projections), 2 FLOPs per MAC, x3 for forward + backward
    mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size

    # Attention projections (Q, K, V and output), without the dot product
    attn_flops = 12 * cfg.hidden_size * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)

    # Attention dot product - this scales quadratically with sequence length
    attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len

    # we also ignore embeddings and layernorms, etc.
    return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers


def create_collate_fn():
    """Create a collate function for batching."""

    def collate_fn(batch):
        input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
        labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
        return {"input_ids": input_ids, "labels": labels}

    return collate_fn


class PerformanceTracker:
    """Track training performance metrics."""

    def __init__(self, warmup_steps: int = 10):
        self.warmup_steps = warmup_steps
        self.reset()

    def reset(self):
        """Reset all tracking variables."""
        self.start_time = None
        self.num_tokens = 0
        self.is_in_warmup = True
        self.step_count = 0

    def step(self, batch_tokens: int) -> dict:
        """
        Update performance tracking with a new step.

        Args:
            batch_tokens (int): Number of tokens in current batch

        Returns:
            dict: Performance metrics if past warmup, empty dict otherwise
        """
        self.step_count += 1

        if self.step_count == self.warmup_steps:
            self.start_time = time.perf_counter()
            self.num_tokens = 0
            self.is_in_warmup = False
            return {"warmup_completed": True}

        if not self.is_in_warmup and self.start_time is not None:
            self.num_tokens += batch_tokens
            total_time = time.perf_counter() - self.start_time
            steps_from_warmup = self.step_count - self.warmup_steps

            if total_time > 0 and steps_from_warmup > 0:
                memory_stats = gpu_memory_usage_all()
                return {
                    "tokens_per_second": self.num_tokens / total_time,
                    "steps_per_second": steps_from_warmup / total_time,
                    "total_tokens": self.num_tokens,
                    "total_time": total_time,
                    **memory_stats,
                }

        return {}

    def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
        print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
        if with_memory:
            print_msg += (
                f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
                f"alloc={metrics['peak_memory_alloc']:.1f}, "
                f"reserved={metrics['peak_memory_reserved']:.1f}"
            )
        return print_msg


def setup_tokenizer(model_id: str) -> AutoTokenizer:
    """Setup tokenizer with proper padding token."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def gpu_memory_usage_all(device=0):
    """Return peak GPU memory stats in GiB and reset the peak counters."""
    device = torch.device(f"cuda:{device}")
    _BYTES_IN_GIB = 1024**3
    peak_memory_active = torch.cuda.memory_stats(device).get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
    peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
    peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
    memory_stats = {
        "peak_memory_active": peak_memory_active,
        "peak_memory_alloc": peak_memory_alloc,
        "peak_memory_reserved": peak_memory_reserved,
    }
    torch.cuda.reset_peak_memory_stats(device)

    return memory_stats
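
As a usage note, `get_model_flops_per_token` pairs with `PerformanceTracker` to estimate MFU. A sketch under stated assumptions (the peak-FLOPs figure assumes dense bf16 on an H100 SXM; `model` and `batch` come from a training loop like `nd_parallel.py`):

```python
flops_per_token = get_model_flops_per_token(model, seq_len=8192)
metrics = performance_tracker.step(batch["input_ids"].shape[1])
if metrics and "tokens_per_second" in metrics:
    achieved_flops = flops_per_token * metrics["tokens_per_second"]
    mfu = achieved_flops / 989e12  # assumed peak: ~989 TFLOPS dense bf16 per H100 SXM
    print(f"Estimated MFU: {mfu:.1%}")
```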