Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 18:13:46 +08:00)
* enable regional_compilation benchmark on xpu

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* Apply style fixes

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
78 lines · 2.9 KiB · Python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.utils.benchmark import Compare, Timer
from transformers import AutoConfig, AutoModelForCausalLM

from accelerate.test_utils.testing import get_backend
from accelerate.utils import compile_regions
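
# This benchmark compares eager execution (baseline), full `torch.compile`,
# and accelerate's regional compilation (`compile_regions`) on non-gated
# Llama-architecture models, timing both compilation and steady-state
# inference on whatever accelerator `get_backend()` detects (e.g. CUDA or XPU).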

torch.set_float32_matmul_precision("high")

COMPILE_ITERS = 2
INFERENCE_ITERS = 100

BASELINE = "Baseline"
COMPILE_TIME = "Compile time"
INFERENCE_TIME = "Inference time"
FULL_COMPILATION = "Full compilation"
REGIONAL_COMPILATION = "Regional compilation"

INFERENCE_STMT = "model(input_ids, use_cache=False)"
COMPILE_STMT = f"torch._dynamo.reset(); torch._inductor.utils.clear_inductor_caches(); {INFERENCE_STMT}"
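# Resetting dynamo and clearing the inductor caches inside COMPILE_STMT forces
# a cold start, so compile-time measurements reflect real compilation latency
# rather than cache hits.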

torch_device_type, _, _ = get_backend()

results = []
for model_id in [
    # non-gated llama models
    "NousResearch/Llama-3.2-1B",
    "NousResearch/Hermes-3-Llama-3.2-3B",
    "NousResearch/Hermes-3-Llama-3.1-8B",
    "NousResearch/Nous-Hermes-Llama2-13b",
]:
    with torch.device(torch_device_type):
        config = AutoConfig.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_config(config).to(dtype=torch.float16).eval()
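
    # torch.compile wraps the whole forward graph, while compile_regions
    # compiles each repeated block (e.g. the decoder layers) separately,
    # which typically cuts compile time sharply at little cost in runtime
    # speedup.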
    full_compilation_model = torch.compile(model)
    regional_compilation_model = compile_regions(model)
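
    # Benchmark matrix: the eager baseline is timed for inference only; each
    # compiled variant is timed for both cold-start compilation and inference.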
    for model, sub_label, description, stmt, iters in [
        (model, BASELINE, INFERENCE_TIME, INFERENCE_STMT, INFERENCE_ITERS),
        (full_compilation_model, FULL_COMPILATION, COMPILE_TIME, COMPILE_STMT, COMPILE_ITERS),
        (full_compilation_model, FULL_COMPILATION, INFERENCE_TIME, INFERENCE_STMT, INFERENCE_ITERS),
        (regional_compilation_model, REGIONAL_COMPILATION, COMPILE_TIME, COMPILE_STMT, COMPILE_ITERS),
        (regional_compilation_model, REGIONAL_COMPILATION, INFERENCE_TIME, INFERENCE_STMT, INFERENCE_ITERS),
    ]:
        for batch_size, sequence_length in [(1, 128), (4, 128)]:
            input_ids = torch.randint(
                0, 1000, size=(batch_size, sequence_length), dtype=torch.int64, device=torch_device_type
            )
            results.append(
                Timer(
                    label=model_id,
                    sub_label=sub_label,
                    description=f"{description} ({batch_size}x{sequence_length})",
                    globals={"model": model, "input_ids": input_ids},
                    stmt=stmt,
                ).timeit(number=iters)
            )
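
# Compare collates the Timer measurements into one table, grouped by model
# (label), variant (sub_label), and measurement/shape (description).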
compare = Compare(results)
compare.colorize()
compare.print()