# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
from transformers import AutoModelForMaskedLM

from accelerate import PartialState, prepare_pippy
from accelerate.test_utils import torch_device
from accelerate.utils import set_seed


synchronize_func = getattr(torch, torch_device, torch.cuda).synchronize
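# Note: `torch_device` (from accelerate.test_utils) is the active backend's
# name ("cuda", "xpu", "hpu", ...), so the getattr above picks that backend's
# `synchronize` and falls back to `torch.cuda.synchronize` if no matching
# torch module exists.
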
# Set the random seed to have reproducible outputs
set_seed(42)

# Create an example model
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

# Input configs
# Create example inputs for the model
input = torch.randint(
    low=0,
    high=model.config.vocab_size,
    size=(1, 512),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)

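# (The tensor above is only an example input: `prepare_pippy` uses it via
# `example_args` to trace the model and decide where to split it, so its
# batch size of 1 need not match the inference batch size used later.)
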
# Create a pipeline stage from the model
# Using `auto` is equivalent to letting `device_map="auto"` figure
# out device mapping and will also split the model according to the
# number of total GPUs available if it fits on one GPU
model = prepare_pippy(model, split_points="auto", example_args=(input,))

# You can pass `gather_output=True` to have the output from the model
# available on all GPUs
# model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)

# Create new inputs of the expected size: the batch dimension (here 2) should
# match the number of processes, as the batch is split across the pipeline stages
input = torch.randint(
    low=0,
    high=model.config.vocab_size,
    size=(2, 512),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)

# Move the inputs to the first device
input = input.to(torch_device)

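# (`torch_device` is a device-type string, so `.to(torch_device)` places the
# batch on this process's current accelerator, where the first pipeline stage
# expects its inputs.)
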
# Benchmark: time the first batch, then take an average over 5 more batches
# Measure the first batch
synchronize_func()
start_time = time.time()
with torch.no_grad():
    output = model(input)
synchronize_func()
end_time = time.time()
first_batch = end_time - start_time

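# (Each timed region is bracketed by synchronize calls so that asynchronous
# kernel launches have actually finished before the clock is read.)
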
# Now that the device backend is initialized (warm-up done), measure the average
synchronize_func()
start_time = time.time()
for i in range(5):
    with torch.no_grad():
        output = model(input)
synchronize_func()
end_time = time.time()

# The outputs are only on the final process by default
if PartialState().is_last_process:
    output = torch.stack(tuple(output[0]))
    print(f"Time of first pass: {first_batch}")
    print(f"Average time per batch: {(end_time - start_time) / 5}")

PartialState().destroy_process_group()
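
# To try this example, launch it with one process per device, e.g. (assuming
# at least two GPUs/XPUs and that this file is saved as bert.py):
#   accelerate launch --num_processes 2 bert.py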