# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
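
# Example: fine-tune a causal LM with the Trainer using torch-native n-d parallelism
# driven by accelerate's ParallelismConfig. A launch sketch (file names in angle
# brackets are illustrative placeholders, not part of the repo):
#   accelerate launch --config_file <your_config>.yaml <this_script>.py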
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

from accelerate.utils import ParallelismConfig
from utils import get_dataset


MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence-length", type=int, default=4096)
    parser.add_argument("--checkpoint-frequency", type=int, default=100)
    parser.add_argument("--model-name", type=str, default=MODEL_ID)
    parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
    parser.add_argument("--device-type", type=str, default="auto")
    return parser.parse_args()


def main():
    # If ParallelismConfig is not initialized with explicit arguments, it reads its
    # settings from the environment variables set via the accelerate config.
    pc = ParallelismConfig()
    args = parse_args()

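    # torch.accelerator (PyTorch 2.6+) reports the current accelerator backend
    # (e.g. "cuda", "xpu"), keeping the example device agnostic.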
    if args.device_type == "auto":
        args.device_type = torch.accelerator.current_accelerator().type

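    # When tensor parallelism is requested, let transformers compute its automatic
    # TP plan and shard the model over the mesh built from the parallelism config.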
    model_kwargs = {}
    if pc.tp_enabled:
        model_kwargs["tp_plan"] = "auto"
        model_kwargs["device_mesh"] = pc.build_device_mesh(args.device_type)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(args.model_name, use_cache=False, **model_kwargs)

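    # Llama-style checkpoints often ship without a pad token; reuse EOS for padding.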
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

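    # get_dataset (defined in the example's utils.py) is expected to return the
    # corpus packed into fixed-length sequences of --sequence-length tokens.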
    packed_dataset = get_dataset(tokenizer, args.sequence_length)

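    # parallelism_config hands the same config to the Trainer's accelerate backend,
    # so training runs with the parallelism layout described above.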
    training_args = TrainingArguments(
        output_dir=args.save_dir,
        parallelism_config=pc,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        logging_steps=5,
        save_steps=args.checkpoint_frequency,
        learning_rate=5e-5,
        remove_unused_columns=False,
        bf16=True,
    )

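    # processing_class supersedes the deprecated `tokenizer` argument in recent
    # transformers releases.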
    trainer = Trainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=packed_dataset,
    )

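    # Train (checkpointing every --checkpoint-frequency steps via save_steps) and
    # write the final model to --save-dir.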
    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()