!59 [deploy] Merge the repo_id and model_path input arguments

Merge pull request !59 from 张烨槟/dev
Author: 张烨槟
Date: 2024-12-05 12:05:22 +00:00
Committed by: i-robot
Parent: 4c0f4a9993
Commit: e4f3b55e20
2 changed files with 9 additions and 36 deletions
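This commit collapses the two previous ways of pointing deploy at a model: the positional hub repo id and the separate --model_path option now share a single input argument, which download_from_repo resolves to a local directory before either backend runs. A minimal, runnable sketch of the merged-argument pattern, assuming (the diff itself does not show it) that download_from_repo passes an existing local path through unchanged and downloads anything else as a repo id; the parser name and the resolve_model helper are illustrative only:

import argparse
import os


def resolve_model(repo_id_or_path: str) -> str:
    if os.path.isdir(repo_id_or_path):
        # Already a local directory: use it directly (old --model_path behavior).
        return os.path.abspath(repo_id_or_path)
    # Otherwise treat the value as a hub repo id; the real code hands it to
    # download_from_repo, which returns the path of the downloaded snapshot.
    return "<cache dir for {}>".format(repo_id_or_path)


parser = argparse.ArgumentParser(prog="openmind-cli deploy")
parser.add_argument("model", help="hub repo id or local model path")
args = parser.parse_args(["some_org/some_model"])
print(resolve_model(args.model))  # <cache dir for some_org/some_model>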

File 1 (deploy subcommand; original path not shown):

@@ -16,6 +16,8 @@ import textwrap
 from ..subcommand import SubCommand
 from .mindie import DeployMindie
 from .lmdeploy import DeployLMDeploy
+from ...utils.constants import DYNAMIC_ARG
+from openmind.legacy.pipelines.pipeline_utils import download_from_repo
 
 
 class Deploy(SubCommand):
@@ -53,36 +55,25 @@ class Deploy(SubCommand):
             default="mindie",
             help="inference backend, choosing from mindie and lmdeploy",
         )
-        self._parser.add_argument(
-            "--model_path",
-            type=str,
-            help="path of the model",
-        )
         self._parser.add_argument(
             "--port",
             type=int,
             default=1025,
             help="port for the service-oriented deployment",
         )
         self._parser.add_argument(
             "--max_seq_len",
             type=int,
             help="maximum length of the sequence",
         )
         self._parser.add_argument(
             "--npu_device_ids",
             type=str,
             help="npu ids allocated to the model instance",
         )
         self._parser.add_argument(
             "--yaml_path",
             type=str,
             default=None,
             help="Path to the YAML configuration file",
         )
 
     def _deploy_cmd(self, args: argparse.Namespace) -> None:
         """Using mindieservice to perform inference via curl"""
+        args_dict = vars(args)
+        args_dict.pop("func")
+        args.model_id = args_dict.pop(DYNAMIC_ARG)
+        args.host_model_path = download_from_repo(args.model_id)
         if args.infer_backend == "mindie":
             DeployMindie(args).deploy()
         elif args.infer_backend == "lmdeploy":
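The func/DYNAMIC_ARG handling in _deploy_cmd above relies on vars(args) returning the dictionary that backs the Namespace itself, so popping a key genuinely removes the attribute from args before the namespace is forwarded to a backend class. A runnable sketch of this argparse subcommand-dispatch pattern; the program name and argument names are placeholders, not the real CLI surface:

import argparse

parser = argparse.ArgumentParser(prog="openmind-cli")
sub = parser.add_subparsers()
deploy = sub.add_parser("deploy")
deploy.add_argument("model_id_or_path")  # stand-in for the dynamic positional arg
deploy.add_argument("--port", type=int, default=1025)
deploy.set_defaults(func=lambda a: print("deploying", a.model_id_or_path))

args = parser.parse_args(["deploy", "some_org/some_model", "--port", "8000"])
args_dict = vars(args)           # the same dict object that backs the namespace
handler = args_dict.pop("func")  # removing "func" here also removes args.func
handler(args)                    # prints: deploying some_org/some_model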

File 2 (lmdeploy backend; original path not shown):

@@ -11,13 +11,9 @@
 # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 # See the Mulan PSL v2 for more details.
-import os
 import argparse
 import subprocess
 
 from ...utils import logging
-from ...utils.hub import OM_HUB_CACHE
-from ...utils.hub import OpenMindHub
-from ...utils.constants import DYNAMIC_ARG
 
 logger = logging.get_logger()
 logging.set_verbosity_info()
@@ -27,10 +23,8 @@ class DeployLMDeploy:
     def __init__(self, args: argparse.Namespace):
         self.args = args
         self.port = args.port
-        self.host_model_path = None
-        args_dict = vars(args)
-        args_dict.pop("func")
-        self.model_id = args_dict.pop(DYNAMIC_ARG)
+        self.model_id = args.model_id
+        self.host_model_path = args.host_model_path
 
     def __check_requirements_for_lmdeploy(self):
         import torch
@@ -47,20 +41,8 @@ class DeployLMDeploy:
     def deploy(self):
         self.__check_requirements_for_lmdeploy()
-        self._get_host_model_path()
         self._start_lmdeploy_service()
 
-    def _get_host_model_path(self):
-        if self.args.model_path:
-            self.host_model_path = self.args.model_path
-        if not self.host_model_path:
-            self.host_model_path = os.path.join(OM_HUB_CACHE, "--".join(self.model_id.split("/")))
-            self._pull_model()
-        self.host_model_path = os.path.abspath(self.host_model_path)
-
-    def _pull_model(self):
-        self.host_model_path = OpenMindHub.snapshot_download(self.model_id, cache_dir=self.host_model_path)
-
     def _start_lmdeploy_service(self):
         command = [
             "lmdeploy",