Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 18:13:46 +08:00)
enable regional_compilation benchmark on xpu (#3592)
* enable regional_compilation benchmark on xpu

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* Apply style fixes

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
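The substance of the change is device portability: instead of hardcoding "cuda", the benchmark now derives the accelerator type from get_backend(), so the same script runs on XPU (and other supported backends) unchanged. A minimal standalone sketch of that pattern follows; the tensor shapes are illustrative, not taken from the benchmark:

import torch

from accelerate.test_utils.testing import get_backend

# get_backend() returns a 3-tuple whose first element is the device type
# string (e.g. "cuda" or "xpu"); per the diff, only that element is used here.
torch_device_type, _, _ = get_backend()

# Construct tensors on the detected accelerator instead of a hardcoded "cuda".
with torch.device(torch_device_type):
    weights = torch.randn(8, 8, dtype=torch.float16)

input_ids = torch.randint(0, 1000, size=(1, 128), dtype=torch.int64, device=torch_device_type)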
@@ -16,6 +16,7 @@ import torch
 from torch.utils.benchmark import Compare, Timer
 from transformers import AutoConfig, AutoModelForCausalLM
 
+from accelerate.test_utils.testing import get_backend
 from accelerate.utils import compile_regions
 
 
@@ -33,6 +34,8 @@ REGIONAL_COMPILATION = "Regional compilation"
 INFRENCE_STMT = "model(input_ids, use_cache=False)"
 COMPILE_STMT = f"torch._dynamo.reset(); torch._inductor.utils.clear_inductor_caches(); {INFRENCE_STMT}"
 
+torch_device_type, _, _ = get_backend()
+
 results = []
 for model_id in [
     # non-gated llama models
@@ -41,7 +44,7 @@ for model_id in [
     "NousResearch/Hermes-3-Llama-3.1-8B",
     "NousResearch/Nous-Hermes-Llama2-13b",
 ]:
-    with torch.device("cuda"):
+    with torch.device(torch_device_type):
         config = AutoConfig.from_pretrained(model_id)
         model = AutoModelForCausalLM.from_config(config).to(dtype=torch.float16).eval()
 
@@ -56,7 +59,9 @@ for model_id in [
         (regional_compilation_model, REGIONAL_COMPILATION, INFRENCE_TIME, INFRENCE_STMT, INFERENCE_ITERS),
     ]:
         for batch_size, sequence_length in [(1, 128), (4, 128)]:
-            input_ids = torch.randint(0, 1000, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda")
+            input_ids = torch.randint(
+                0, 1000, size=(batch_size, sequence_length), dtype=torch.int64, device=torch_device_type
+            )
             results.append(
                 Timer(
                     label=model_id,
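For context, the file these hunks patch benchmarks full-model torch.compile against accelerate.utils.compile_regions (which compiles the model's repeated blocks individually rather than the whole graph), collecting torch.utils.benchmark measurements into results. A condensed sketch of how that loop plausibly fits together with the hunks' device handling applied; the Timer keyword arguments beyond label, the *_model variable on the full-compilation side, and the final Compare call are assumptions about parts of the file the diff does not show:

full_compilation_model = torch.compile(model)
regional_compilation_model = compile_regions(model)

# Condensed to the inference-time entries; the diff shows the list also
# carries compile-time entries built from COMPILE_STMT and COMPILE_ITERS.
for compiled_model, compile_mode, measure, stmt, iters in [
    (full_compilation_model, FULL_COMPILATION, INFRENCE_TIME, INFRENCE_STMT, INFERENCE_ITERS),
    (regional_compilation_model, REGIONAL_COMPILATION, INFRENCE_TIME, INFRENCE_STMT, INFERENCE_ITERS),
]:
    for batch_size, sequence_length in [(1, 128), (4, 128)]:
        # Inputs land on the backend-detected device, matching the hunk above.
        input_ids = torch.randint(
            0, 1000, size=(batch_size, sequence_length), dtype=torch.int64, device=torch_device_type
        )
        results.append(
            Timer(
                label=model_id,
                sub_label=f"batch={batch_size}, seq={sequence_length}",
                description=f"{compile_mode} ({measure})",
                stmt=stmt,
                globals={"model": compiled_model, "input_ids": input_ids},
            ).timeit(iters)
        )

# Compare (imported alongside Timer) tabulates all collected measurements.
Compare(results).print()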