Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-11 22:44:28 +08:00)

Compare commits: v1.5.0...pippy-inte (6 commits)
| SHA1 |
|---|
| 36ab515545 |
| b866044193 |
| 4a1bdc8feb |
| ae0d44d3dd |
| 363e17e768 |
| e304372470 |
src/accelerate/inference.py (new file, 150 lines)

```python
import math
from functools import partial
from typing import Literal

import torch
from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage
from torch import nn

from .big_modeling import dispatch_model
from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    infer_auto_device_map,
    is_pippy_available,
)


ParallelMode = Literal["sequential", "pipeline_parallel"]


class InferenceHandler:
    """
    Base class for handling different backends for `device_map="auto"` inference.

    Supports the native accelerate version as well as utilizing PiPPy.
    """

    def __init__(self, device_map: str = "auto", parallel_mode: ParallelMode = "sequential"):
        self.model = None
        self.device_map = device_map
        self.state = PartialState()
        self.parallel_mode = parallel_mode
        if parallel_mode == "pipeline_parallel" and not is_pippy_available():
            raise RuntimeError("PiPPy is not installed, but is required for pipeline parallel inference.")

        # Attrs for pipeline parallel inference
        self.pipeline = None
        self.scheduler = None

    @staticmethod
    def generate_device_map(model: nn.Module, parallel_mode: str, num_processes: int = 1):
        if parallel_mode == "sequential":
            # No change or adjustment is needed
            return infer_auto_device_map(
                model,
                no_split_module_classes=model._no_split_modules,
                clean_result=False,
            )
        elif parallel_mode == "pipeline_parallel":
            # Calculate the maximum size of the model weights
            model_size, shared = calculate_maximum_sizes(model)
            # Split memory evenly across the number of devices
            memory = (model_size + shared[0]) / num_processes
            memory = convert_bytes(memory)
            value, ending = memory.split(" ")

            # Add ~10% headroom to deal with placement issues on `pippy`
            memory = math.ceil(float(value)) * 1.1
            memory = f"{memory} {ending}"
            # Create a device map based on the per-device memory budget
            max_memory = {i: memory for i in range(num_processes)}
            return infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=model._no_split_modules,
                clean_result=False,
            )
        else:
            raise ValueError(
                f"Unknown parallel mode: {parallel_mode}. Expected either `sequential` or `pipeline_parallel`."
            )
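
    # A worked example of the budget above (illustrative numbers, not part of
    # the original file): if `convert_bytes` reports "2.0 GB" per process, then
    # value, ending == "2.0", "GB" and math.ceil(2.0) * 1.1 == 2.2, so every
    # device in `max_memory` is budgeted "2.2 GB" of weights.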

    def prepare_pippy(self, model: nn.Module):
        """
        Prepares a model for inference based on `device_map` for pipeline parallelism.
        """
        if self.device_map == "auto":
            self.device_map = self.generate_device_map(model, "pipeline_parallel", self.state.num_processes)
        # Get all the split points, one per device
        split_points = []
        for i in range(1, self.state.num_processes):
            split_points.append(next(k for k, v in self.device_map.items() if v == i))

        # Annotate the model with the split points
        for split_point in split_points:
            annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING})

        model._original_call = model.__call__
        model._original_forward = model.forward
        model.__call__ = partial(self.__call__, model=model)
        return model

    def prepare_native(self, model: nn.Module):
        """
        Prepares a model for inference based on `device_map` for sequential parallelism.
        """
        if self.device_map == "auto":
            self.device_map = self.generate_device_map(model, "sequential", self.state.num_processes)

        model = dispatch_model(model, self.device_map)
        return model

    def prepare(self, model: nn.Module):
        """
        Prepares a model for inference based on `device_map` and `parallel_mode`.
        """
        if self.parallel_mode == "sequential":
            model = self.prepare_native(model)
        elif self.parallel_mode == "pipeline_parallel":
            # Send the model to the device for now
            model.to(self.device)
            model = self.prepare_pippy(model)
        return model

    @property
    def device(self):
        return self.state.device

    def __call__(self, model, *args, **kwargs):
        if model is None:
            raise RuntimeError("Model must be prepared before inference is performed. Please call `prepare` first.")

        with torch.inference_mode():
            if self.parallel_mode == "sequential":
                return model(*args, **kwargs)
            elif self.parallel_mode == "pipeline_parallel":
                if self.pipeline is None:
                    # We need to do a quick first trace over the model
                    self.pipeline = Pipe.from_tracing(
                        model,
                        num_chunks=self.state.num_processes,
                        example_args=args,
                        example_kwargs=kwargs,
                    )
                if self.scheduler is None:
                    # Create a schedule runtime
                    self.scheduler = PipelineStage(
                        self.pipeline,
                        self.state.local_process_index,
                        device=self.state.device,
                    )
                if self.state.is_local_main_process:
                    # Convert kwargs to a tuple; this has been fixed upstream on main
                    args = tuple(kwargs.values())
                    return self.scheduler(*args)
                else:
                    return self.scheduler()
```
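For orientation, a minimal usage sketch of the handler above. It is illustrative only: `MyModel` is a hypothetical module, and the pipeline-parallel path assumes a distributed launch (e.g. `accelerate launch` with one process per GPU).

```python
import torch

from accelerate.inference import InferenceHandler

# `MyModel` is a hypothetical nn.Module with `_no_split_modules` defined,
# as `generate_device_map` expects.
model = MyModel()

handler = InferenceHandler(device_map="auto", parallel_mode="pipeline_parallel")
model = handler.prepare(model)  # moves, annotates, and patches the model

# After `prepare`, calls route through `handler.__call__`, which traces the
# pipeline on first use; each rank then runs its own stage.
inputs = torch.randn(1, 16)  # shape is illustrative; match your model's forward
output = model(inputs)
```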
```diff
@@ -67,6 +67,7 @@ from .imports import (
     is_mps_available,
     is_npu_available,
     is_pandas_available,
+    is_pippy_available,
     is_rich_available,
     is_sagemaker_available,
     is_tensorboard_available,
```
```diff
@@ -110,6 +110,10 @@ def is_deepspeed_available():
     return _is_package_available("deepspeed")
 
 
+def is_pippy_available():
+    return _is_package_available("torchpippy")
+
+
 def is_bf16_available(ignore_tpu=False):
     "Checks if bf16 is supported, optionally ignoring the TPU"
     if is_tpu_available():
```
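The new helper mirrors the existing `is_*_available` checks, so call sites can guard PiPPy imports. A small sketch (the install hint is an assumption inferred from the `torchpippy` package name checked above):

```python
from accelerate.utils import is_pippy_available

if is_pippy_available():
    from pippy.IR import Pipe  # safe: the `torchpippy` package is importable
else:
    # Assumed install command, inferred from the package name
    raise ImportError("Pipeline parallel inference requires `pip install torchpippy`.")
```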
```diff
@@ -21,7 +21,7 @@ import os
 import re
 import shutil
 import tempfile
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
```
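The switch to `OrderedDict` makes the insertion (traversal) order of the returned map an explicit part of its contract, which matters because `prepare_pippy` picks the first submodule assigned to each rank as that rank's split point. A toy illustration (the keys are made up):

```python
from collections import OrderedDict

# Traversal order decides which module becomes a split point.
device_map = OrderedDict([("h.0", 0), ("h.1", 0), ("h.2", 1), ("h.3", 1)])
split_point = next(k for k, v in device_map.items() if v == 1)
assert split_point == "h.2"
```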
```diff
@@ -923,6 +923,7 @@ def infer_auto_device_map(
     dtype: Optional[Union[str, torch.dtype]] = None,
     special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
     verbose: bool = False,
+    clean_result: bool = True,
 ):
     """
     Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
```
```diff
@@ -956,6 +957,8 @@ def infer_auto_device_map(
             all weights).
         verbose (`bool`, *optional*, defaults to `False`):
             Whether or not to provide debugging statements as the function builds the device_map.
+        clean_result (`bool`, *optional*, defaults to `True`):
+            Clean the resulting device_map by grouping all submodules that go on the same device together.
     """
     # Get default / clean up max_memory
     max_memory = get_max_memory(max_memory)
```
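To make the new flag concrete, a sketch of the two output shapes (module names are invented; the exact keys depend on the model's module tree):

```python
device_map = infer_auto_device_map(model, clean_result=True)
# e.g. {"transformer": 0, "lm_head": 1}: submodules on one device are grouped

device_map = infer_auto_device_map(model, clean_result=False)
# e.g. {"transformer.h.0": 0, "transformer.h.1": 0, "lm_head": 1}: one entry
# per placed submodule, which is what `InferenceHandler.generate_device_map`
# relies on to locate per-rank split points
```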
```diff
@@ -985,7 +988,7 @@ def infer_auto_device_map(
             "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
         )
 
-    device_map = {}
+    device_map = OrderedDict()
     current_device = 0
     current_memory_used = 0
 
```
```diff
@@ -1153,7 +1156,9 @@ def infer_auto_device_map(
             current_memory_used += module_size
             device_map[name] = devices[current_device]
 
-    return clean_device_map(device_map)
+    if clean_result:
+        device_map = clean_device_map(device_map)
+    return device_map
 
 
 def check_device_map(model: nn.Module, device_map: Dict[str, Union[int, str, torch.device]]):
```