# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import json
import logging
import os
import pickle
from packaging import version

from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
from .checkpoint import HuggingFaceCheckpointEngine
from .logging import inference_logger
from .model_implementations import (
    OPTPolicy,
    Llama2Policy,
    MistralPolicy,
    MixtralPolicy,
    FalconPolicy,
    PhiPolicy,
    Phi3Policy,
    QwenPolicy,
    Qwen2Policy,
    Qwen2MoePolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata


def build_engine_from_ds_checkpoint(path: str,
                                    engine_config: RaggedInferenceEngineConfig,
                                    debug_level: int = logging.INFO) -> InferenceEngineV2:
    """
    Creates an engine from a checkpoint saved by ``InferenceEngineV2``.

    Arguments:
        path: Path to the checkpoint. This does not need to point to any files in particular,
            just the directory containing the checkpoint.
        engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
        debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
            value is ``logging.INFO``.

    Returns:
        Fully initialized inference engine ready to serve queries.
    """
    inference_logger(level=debug_level)

    # Load the metadata. To grab the policy name, all ranks just read the rank-0
    # metadata file.
    metadata_filename = make_metadata_filename(path, 0, engine_config.tensor_parallel.tp_size)
    metadata = json.load(open(metadata_filename, "r"))
    metadata = ModelMetadata.parse_raw(metadata)

    # Get the policy
    try:
        policy_cls: InferenceV2Policy = POLICIES[metadata.policy]
    except KeyError:
        raise ValueError(f"Unknown policy {metadata.policy} for model {path}")

    # Load the model config
    model_config = pickle.load(open(os.path.join(path, "ds_model_config.pkl"), "rb"))
    policy = policy_cls(model_config, inf_checkpoint_path=path)

    return InferenceEngineV2(policy, engine_config)
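
# Usage sketch (illustrative only, not part of the module): building an engine
# from a saved Inference-V2 checkpoint directory. The path and tensor-parallel
# size below are hypothetical placeholders; the dict form relies on the config
# being a pydantic model that accepts nested dicts, which is an assumption here.
#
#   engine_config = RaggedInferenceEngineConfig(tensor_parallel={"tp_size": 1})
#   engine = build_engine_from_ds_checkpoint("/path/to/ds_checkpoint", engine_config)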


def build_hf_engine(path: str,
                    engine_config: RaggedInferenceEngineConfig,
                    debug_level: int = logging.INFO) -> InferenceEngineV2:
    """
    Build an InferenceV2 engine for HuggingFace models. This accepts either a HuggingFace
    model name or a path to an Inference-V2 checkpoint.

    Arguments:
        path: Path to the checkpoint. This does not need to point to any files in particular,
            just the directory containing the checkpoint.
        engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
        debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
            value is ``logging.INFO``.

    Returns:
        Fully initialized inference engine ready to serve queries.
    """
    if os.path.exists(os.path.join(path, "ds_model_config.pkl")):
        return build_engine_from_ds_checkpoint(path, engine_config, debug_level=debug_level)
    else:
        # Set up logging
        inference_logger(level=debug_level)

        # Get the HF checkpoint engine
        checkpoint_engine = HuggingFaceCheckpointEngine(path)

        # Get the model config from HF AutoConfig
        model_config = checkpoint_engine.model_config

        # Get the policy
        # TODO: generalize this to other models
        if model_config.model_type == "opt":
            if not model_config.do_layer_norm_before:
                raise ValueError(
                    "Detected OPT-350m model. This model is not currently supported. If this is not the 350m model, please open an issue: https://github.com/deepspeedai/DeepSpeed-MII/issues"
                )
            policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "llama":
            policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "mistral":
            # Ensure we're using the correct version of transformers for mistral
            import transformers
            assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
                f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
            policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "mixtral":
            # Ensure we're using the correct version of transformers for mixtral
            import transformers
            assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \
                f"Mixtral requires transformers >= 4.36.1, you have version {transformers.__version__}"
            policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "falcon":
            policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "phi":
            policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "phi3":
            policy = Phi3Policy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "qwen":
            policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "qwen2":
            policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "qwen2_moe":
            policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
        else:
            raise ValueError(f"Unsupported model type {model_config.model_type}")

        return InferenceEngineV2(policy, engine_config)
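
# Usage sketch (illustrative only, not part of the module): building an engine
# straight from a HuggingFace model name, letting the function auto-detect that
# the path is not a ds_model_config.pkl checkpoint. The model name and config
# values are placeholder assumptions, not recommendations.
#
#   engine_config = RaggedInferenceEngineConfig(tensor_parallel={"tp_size": 1})
#   engine = build_hf_engine("mistralai/Mistral-7B-v0.1", engine_config)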