# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
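"""Factory helpers for building ``InferenceEngineV2`` instances, either from a
DeepSpeed Inference-V2 checkpoint or directly from a HuggingFace model."""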
import json
import logging
import os
import pickle

from packaging import version

from .engine_v2 import InferenceEngineV2
from .config_v2 import RaggedInferenceEngineConfig
from .checkpoint import HuggingFaceCheckpointEngine
from .logging import inference_logger
from .model_implementations import (
    OPTPolicy,
    Llama2Policy,
    MistralPolicy,
    MixtralPolicy,
    FalconPolicy,
    PhiPolicy,
    Phi3Policy,
    QwenPolicy,
    Qwen2Policy,
    Qwen2MoePolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata

def build_engine_from_ds_checkpoint(path: str,
                                    engine_config: RaggedInferenceEngineConfig,
                                    debug_level: int = logging.INFO) -> InferenceEngineV2:
"""
|
|
Creates an engine from a checkpoint saved by ``InferenceEngineV2``.
|
|
|
|
Arguments:
|
|
path: Path to the checkpoint. This does not need to point to any files in particular,
|
|
just the directory containing the checkpoint.
|
|
engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
|
|
debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
|
|
value is ``logging.INFO``.
|
|
|
|
Returns:
|
|
Fully initialized inference engine ready to serve queries.
|
|
"""
    inference_logger(level=debug_level)

    # Load metadata; to grab the policy name, all ranks just check the file
    # written by rank 0.
    metadata_filename = make_metadata_filename(path, 0, engine_config.tensor_parallel.tp_size)
    with open(metadata_filename, "r") as metadata_file:
        metadata = json.load(metadata_file)
    metadata = ModelMetadata.parse_raw(metadata)

    # Get the policy
    try:
        policy_cls: InferenceV2Policy = POLICIES[metadata.policy]
    except KeyError:
        raise ValueError(f"Unknown policy {metadata.policy} for model {path}")

    # Load the model config
    with open(os.path.join(path, "ds_model_config.pkl"), "rb") as config_file:
        model_config = pickle.load(config_file)
    policy = policy_cls(model_config, inf_checkpoint_path=path)

    return InferenceEngineV2(policy, engine_config)

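# Example usage (illustrative sketch; the checkpoint path below is an
# assumption, not something this module provides):
#
#   config = RaggedInferenceEngineConfig()
#   engine = build_engine_from_ds_checkpoint("/path/to/ds_checkpoint", config)
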
def build_hf_engine(path: str,
                    engine_config: RaggedInferenceEngineConfig,
                    debug_level: int = logging.INFO) -> InferenceEngineV2:
"""
|
|
Build an InferenceV2 engine for HuggingFace models. This can accept both a HuggingFace
|
|
model name or a path to an Inference-V2 checkpoint.
|
|
|
|
Arguments:
|
|
path: Path to the checkpoint. This does not need to point to any files in particular,
|
|
just the directory containing the checkpoint.
|
|
engine_config: Engine configuration. See ``RaggedInferenceEngineConfig`` for details.
|
|
debug_level: Logging level to use. Unless you are actively seeing issues, the recommended
|
|
value is ``logging.INFO``.
|
|
|
|
Returns:
|
|
Fully initialized inference engine ready to serve queries.
|
|
"""
    if os.path.exists(os.path.join(path, "ds_model_config.pkl")):
        return build_engine_from_ds_checkpoint(path, engine_config, debug_level=debug_level)
    else:
        # Set up logging
        inference_logger(level=debug_level)

        # Get the HF checkpoint engine
        checkpoint_engine = HuggingFaceCheckpointEngine(path)

        # Get the model config from HF AutoConfig
        model_config = checkpoint_engine.model_config

        # Get the policy
        # TODO: generalize this to other models
        if model_config.model_type == "opt":
            if not model_config.do_layer_norm_before:
                raise ValueError(
                    "Detected OPT-350m model. This model is not currently supported. If this is not the 350m model, please open an issue: https://github.com/deepspeedai/DeepSpeed-MII/issues"
                )
            policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "llama":
            policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "mistral":
            # Ensure we're using the correct version of transformers for Mistral
            import transformers
            assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \
                f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}"
            policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "mixtral":
            # Ensure we're using the correct version of transformers for Mixtral
            import transformers
            assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \
                f"Mixtral requires transformers >= 4.36.1, you have version {transformers.__version__}"
            policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "falcon":
            policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "phi":
            policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "phi3":
            policy = Phi3Policy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "qwen":
            policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "qwen2":
            policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
        elif model_config.model_type == "qwen2_moe":
            policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
        else:
            raise ValueError(f"Unsupported model type {model_config.model_type}")

        return InferenceEngineV2(policy, engine_config)
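# Example usage (illustrative sketch; the model name below is an assumption,
# any supported HuggingFace model identifier or checkpoint path should work):
#
#   config = RaggedInferenceEngineConfig()
#   engine = build_hf_engine("mistralai/Mistral-7B-v0.1", config)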