Compare commits: v0.27.0...fix-genera (11 commits)
| SHA1 |
|---|
| c8f6f79199 |
| 7ca4bccf5d |
| 8792a8c5af |
| e3f6b99b68 |
| df7779aa44 |
| 77f8e92b94 |
| 449eb8d9ef |
| 9eef9dd4b8 |
| 06f04a998f |
| 2767bb19ac |
| e713e28eaa |
src/accelerate/inference.py (new file, 98 lines)

@@ -0,0 +1,98 @@
import math
from types import MethodType
from typing import Literal

from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    infer_auto_device_map,
    send_to_device,
)


ParallelMode = Literal["sequential", "pipeline_parallel"]


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None):
    """
    Calculates the device map for `model` with an offset for PiPPy
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
    model_size, shared = calculate_maximum_sizes(model)

    # Split into `n` chunks for each GPU
    memory = (model_size + shared[0]) / num_processes
    memory = convert_bytes(memory)
    value, ending = memory.split(" ")

    # Add a chunk to deal with potential extra shared memory instances
    memory = math.ceil(float(value)) * 1.1
    memory = f"{memory} {ending}"
    device_map = infer_auto_device_map(
        model,
        max_memory={i: memory for i in range(num_processes)},
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def build_pipeline(model, split_points, args, kwargs) -> PipelineStage:
    """
    Attaches the split points to the model based on `split_points` and generates a `PipelineStage`. Requires passing
    in the `args` and `kwargs` the model needs, with all tensors placed on the CPU.
    """
    # We need to annotate the split points in the model for PiPPy
    state = PartialState()
    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
    pipe = Pipe.from_tracing(model, num_chunks=state.num_processes, example_args=args, example_kwargs=kwargs)
    stage = PipelineStage(pipe, state.local_process_index, device=state.device)

    return stage


def pippy_forward(forward, *args, **kwargs):
    state = PartialState()
    output = None
    if state.num_processes == 1:
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        forward(*args, **kwargs)
    elif state.is_last_process:
        output = forward()
    else:
        forward()
    return output


def prepare_pippy(model, split_points="auto", no_split_module_classes=[], example_args=(), example_kwargs={}):
    """
    Wraps `model` for pipeline parallelism
    """
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if split_points == "auto":
        device_map = generate_device_map(model, state.num_processes, no_split_module_classes=no_split_module_classes)
        split_points = []
        for i in range(1, state.num_processes):
            split_points.append(next(k for k, v in device_map.items() if v == i))
    stage = build_pipeline(model, split_points, example_args, example_kwargs)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.forward, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # model_forward = MethodType(forward, model)
    # forward.__wrapped__ = model_forward
    model.forward = forward
    return model
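For orientation, here is a minimal usage sketch of the `prepare_pippy` helper added above, modeled on the GPT-2 test later in this diff. It is not part of the diff itself: the model, the input shape, and the assumption of a two-process `accelerate launch` are illustrative.

```python
# Sketch only (not part of the diff): wrap a model for pipeline-parallel inference.
# Assumes `torchpippy` is installed and the script is started with a multi-process
# launcher, e.g. `accelerate launch --num_processes 2 this_script.py`.
import torch
from transformers import GPT2Config, GPT2ForSequenceClassification

from accelerate import PartialState
from accelerate.inference import prepare_pippy

state = PartialState()
config = GPT2Config()
model = GPT2ForSequenceClassification(config)

# Example inputs are traced on CPU; the first dimension is chunked across stages
example_input = torch.randint(0, config.vocab_size, (state.num_processes, 1024))
model = prepare_pippy(model, example_args=(example_input,), no_split_module_classes=model._no_split_modules)

with torch.no_grad():
    output = model(example_input.to(state.device))

# Only the last pipeline stage receives the real output; other ranks get `None`
if state.is_last_process:
    print(type(output))
```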
@@ -0,0 +1,97 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import (
    BertConfig,
    BertForMaskedLM,
    GPT2Config,
    GPT2ForSequenceClassification,
    T5Config,
    T5ForConditionalGeneration,
)

from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.utils import DistributedType, send_to_device, set_seed


model_to_config = {
    "t5": (T5ForConditionalGeneration, T5Config, 1024),
    "bert": (BertForMaskedLM, BertConfig, 512),
    "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}


def get_model_and_data(model_name, device, num_processes: int = 2):
    initializer, config, seq_len = model_to_config[model_name]
    config = config()
    model = initializer(config)
    return model, torch.randint(
        low=0,
        high=config.vocab_size,
        size=(num_processes, seq_len),
        device=device,
        dtype=torch.int64,
        requires_grad=False,
    )


def test_gpt2():
    set_seed(42)
    state = PartialState()
    model, inputs = get_model_and_data("gpt2", "cpu", state.num_processes)
    model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
    # For inference, the example args need to be a tuple
    inputs = inputs.to("cuda")
    with torch.no_grad():
        output = model(inputs)
    # Check that the real output is only returned on the last process
    if not state.is_last_process:
        assert output is None, "Output was generated on a process other than the last one!"
    else:
        assert output is not None, "Output was not generated on the last process!"


def test_t5():
    set_seed(42)
    state = PartialState()
    model, inputs = get_model_and_data("t5", "cpu", state.num_processes)
    example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
    model = prepare_pippy(
        model,
        no_split_module_classes=model._no_split_modules,
        example_kwargs=example_inputs,
    )
    # Move the example inputs to the GPU before the forward pass
    inputs = send_to_device(example_inputs, "cuda:0")
    with torch.no_grad():
        output = model(*inputs.values())
    # Check that the real output is only returned on the last process
    if not state.is_last_process:
        assert output is None, "Output was generated on a process other than the last one!"
    else:
        assert output is not None, "Output was not generated on the last process!"


if __name__ == "__main__":
    state = PartialState()
    state.print("Testing pippy integration...")
    if state.distributed_type == DistributedType.MULTI_GPU:
        state.print("Testing GPT2...")
        test_gpt2()
        state.print("Testing T5...")
        test_t5()
    else:
        print("Less than two GPUs found, not running tests!")
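As the assertions above rely on, the wrapped forward only returns real outputs on the last pipeline stage. If every rank needs the result, one possible follow-up (not part of this diff, and assuming `torch.distributed` has already been initialized by the launcher) is to broadcast the object from the last rank:

```python
# Sketch only: make the last stage's output visible on every rank.
# Assumes torch.distributed is initialized (e.g. under `accelerate launch`).
import torch.distributed as dist

from accelerate import PartialState


def broadcast_from_last(output):
    state = PartialState()
    if state.num_processes == 1:
        return output
    # Every rank passes a one-element list; only the last rank fills it in
    container = [output if state.is_last_process else None]
    dist.broadcast_object_list(container, src=state.num_processes - 1)
    return container[0]
```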
@@ -69,6 +69,7 @@ from .imports import (
    is_msamp_available,
    is_npu_available,
    is_pandas_available,
    is_pippy_available,
    is_rich_available,
    is_sagemaker_available,
    is_tensorboard_available,
@@ -126,6 +126,10 @@ def is_deepspeed_available():
    return _is_package_available("deepspeed")


def is_pippy_available():
    return _is_package_available("torchpippy")


def is_bf16_available(ignore_tpu=False):
    "Checks if bf16 is supported, optionally ignoring the TPU"
    if is_tpu_available():
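Since PiPPy is an optional dependency, callers would typically guard on this new helper before importing the pipeline-parallel utilities. A minimal sketch, not part of the diff:

```python
# Sketch only: gate PiPPy-dependent code on the new availability check.
from accelerate.utils import is_pippy_available

if not is_pippy_available():
    raise ImportError(
        "Pipeline-parallel inference requires the `torchpippy` package, e.g. `pip install torchpippy`."
    )

from accelerate.inference import prepare_pippy  # noqa: E402  (import after the check)
```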