Compare commits

...

11 Commits

4 changed files with 200 additions and 0 deletions


@@ -0,0 +1,98 @@
import math
from types import MethodType
from typing import Literal

from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    infer_auto_device_map,
    send_to_device,
)


ParallelMode = Literal["sequential", "pipeline_parallel"]


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None):
    """
    Calculates the device map for `model` with an offset for PiPPy
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
    model_size, shared = calculate_maximum_sizes(model)

    # Split into `n` chunks for each GPU
    memory = (model_size + shared[0]) / num_processes
    memory = convert_bytes(memory)
    value, ending = memory.split(" ")

    # Add a chunk to deal with potential extra shared memory instances
    memory = math.ceil(float(value)) * 1.1
    memory = f"{memory} {ending}"
    device_map = infer_auto_device_map(
        model,
        max_memory={i: memory for i in range(num_processes)},
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def build_pipeline(model, split_points, args, kwargs) -> PipelineStage:
    """
    Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
    in the `args` and `kwargs` the model expects, placed on the CPU.
    """
    # We need to annotate the split points in the model for PiPPy
    state = PartialState()
    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
    pipe = Pipe.from_tracing(model, num_chunks=state.num_processes, example_args=args, example_kwargs=kwargs)
    stage = PipelineStage(pipe, state.local_process_index, device=state.device)
    return stage


def pippy_forward(forward, *args, **kwargs):
    state = PartialState()
    output = None
    if state.num_processes == 1:
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        forward(*args, **kwargs)
    elif state.is_last_process:
        output = forward()
    else:
        forward()
    return output


def prepare_pippy(model, split_points="auto", no_split_module_classes=[], example_args=(), example_kwargs={}):
    """
    Wraps `model` for pipeline parallelism
    """
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if split_points == "auto":
        device_map = generate_device_map(model, state.num_processes, no_split_module_classes=no_split_module_classes)
        split_points = []
        for i in range(1, state.num_processes):
            split_points.append(next(k for k, v in device_map.items() if v == i))
    stage = build_pipeline(model, split_points, example_args, example_kwargs)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.forward, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # model_forward = MethodType(forward, model)
    # forward.__wrapped__ = model_forward
    model.forward = forward
    return model

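For orientation, here is a minimal usage sketch of the new `prepare_pippy` entry point, mirroring the pattern in the test file below. It assumes a script launched over two GPU processes (for example with `accelerate launch --num_processes 2`); the GPT-2 model choice, sequence length, and final printout are illustrative, not part of the diff:

import torch
from transformers import GPT2Config, GPT2ForSequenceClassification

from accelerate import PartialState
from accelerate.inference import prepare_pippy

state = PartialState()
config = GPT2Config()
model = GPT2ForSequenceClassification(config)

# Example inputs are traced on CPU; one chunk per process is an assumption taken from the tests below.
example_inputs = torch.randint(0, config.vocab_size, (state.num_processes, 1024))
model = prepare_pippy(model, example_args=(example_inputs,), no_split_module_classes=model._no_split_modules)

# Only the last process receives a real output.
with torch.no_grad():
    output = model(example_inputs.to("cuda"))
if state.is_last_process:
    print("Output received on the last process:", output is not None)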

@@ -0,0 +1,97 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import (
    BertConfig,
    BertForMaskedLM,
    GPT2Config,
    GPT2ForSequenceClassification,
    T5Config,
    T5ForConditionalGeneration,
)

from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.utils import DistributedType, send_to_device, set_seed


model_to_config = {
    "t5": (T5ForConditionalGeneration, T5Config, 1024),
    "bert": (BertForMaskedLM, BertConfig, 512),
    "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}


def get_model_and_data(model_name, device, num_processes: int = 2):
    initializer, config, seq_len = model_to_config[model_name]
    config = config()
    model = initializer(config)
    return model, torch.randint(
        low=0,
        high=config.vocab_size,
        size=(num_processes, seq_len),
        device=device,
        dtype=torch.int64,
        requires_grad=False,
    )


def test_gpt2():
    set_seed(42)
    state = PartialState()
    model, inputs = get_model_and_data("gpt2", "cpu", state.num_processes)
    model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
    # For inference args need to be a tuple
    inputs = inputs.to("cuda")
    with torch.no_grad():
        output = model(inputs)
    # Zach: Check that we just grab the real outputs we need at the end
    if not state.is_last_process:
        assert output is None, "Output was not generated on just the last process!"
    else:
        assert output is not None, "Output was not generated in the last process!"


def test_t5():
    set_seed(42)
    state = PartialState()
    model, inputs = get_model_and_data("t5", "cpu", state.num_processes)
    example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
    model = prepare_pippy(
        model,
        no_split_module_classes=model._no_split_modules,
        example_kwargs=example_inputs,
    )
    # For inference args need to be a tuple
    inputs = send_to_device(example_inputs, "cuda:0")
    with torch.no_grad():
        output = model(*inputs.values())
    # Zach: Check that we just grab the real outputs we need at the end
    if not state.is_last_process:
        assert output is None, "Output was not generated on just the last process!"
    else:
        assert output is not None, "Output was not generated in the last process!"


if __name__ == "__main__":
    state = PartialState()
    state.print("Testing pippy integration...")
    if state.distributed_type == DistributedType.MULTI_GPU:
        state.print("Testing GPT2...")
        test_gpt2()
        state.print("Testing T5...")
        test_t5()
    else:
        print("Less than two GPUs found, not running tests!")

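As these tests assert, `pippy_forward` only returns a real value on the last process. If a caller needed the result on every rank, one option would be to broadcast it from that rank; the sketch below is an assumption layered on top of the PR (the helper `broadcast_output` is hypothetical) and presumes an already-initialized `torch.distributed` process group on a single node:

import torch.distributed as dist

from accelerate import PartialState


def broadcast_output(output):
    # Hypothetical helper: share the last rank's output with every process.
    state = PartialState()
    if state.num_processes == 1:
        return output
    # broadcast_object_list fills the list in place on non-source ranks.
    container = [output]
    dist.broadcast_object_list(container, src=state.num_processes - 1)
    return container[0]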

@@ -69,6 +69,7 @@ from .imports import (
    is_msamp_available,
    is_npu_available,
    is_pandas_available,
    is_pippy_available,
    is_rich_available,
    is_sagemaker_available,
    is_tensorboard_available,


@@ -126,6 +126,10 @@ def is_deepspeed_available():
    return _is_package_available("deepspeed")


def is_pippy_available():
    return _is_package_available("torchpippy")


def is_bf16_available(ignore_tpu=False):
    "Checks if bf16 is supported, optionally ignoring the TPU"
    if is_tpu_available():
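Since `torchpippy` is an optional dependency, downstream code would presumably guard on this check before importing the integration. A small sketch, with illustrative error-message wording:

from accelerate.utils import is_pippy_available

if not is_pippy_available():
    # prepare_pippy relies on the optional torchpippy package.
    raise ImportError("Pipeline-parallel inference requires torchpippy; install it with `pip install torchpippy`.")

from accelerate.inference import prepare_pippy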