!223 Update cli and model related prompt
Merge pull request !223 from 金勇旭/prompt
@@ -19,12 +19,13 @@ import subprocess
 import sys
 import random

-from openmind.utils.constants import Command
+from openmind.utils.constants import COMMANDS
 from openmind.cli.chat import run_chat
 from openmind.cli.env import run_env
 from openmind.archived.cli_legacy.model_cli import run_pull, run_push, run_rm, run_list
 from openmind.archived.cli_legacy.pipeline_cli import run_pipeline
 from openmind.utils import is_torch_available
+from openmind.utils.arguments_utils import print_formatted_table

 # Compatible with MindSpore
 if is_torch_available():
@@ -44,51 +45,66 @@ def get_device_count():
     return 0


-def main():
-    command_cli = sys.argv[1]
-    if command_cli == Command.TRAIN:
-        if get_device_count() >= 1:
-            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
-            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
-            command = [
-                "torchrun",
-                "--nnodes",
-                os.environ.get("NNODES", "1"),
-                "--node_rank",
-                os.environ.get("RANK", "0"),
-                "--nproc_per_node",
-                os.environ.get("NPROC_PER_NODE", str(get_device_count())),
-                "--master_addr",
-                master_addr,
-                "--master_port",
-                master_port,
-                train.__file__,
-            ] + sys.argv[2::]
-            subprocess.run(command)
-        else:
-            raise ValueError("There is no npu devices to launch finetune workflow")
-    elif command_cli == Command.LIST:
-        run_list()
-    elif command_cli == Command.EVAL:
-        run_eval()
-    elif command_cli == Command.PULL:
-        run_pull()
-    elif command_cli == Command.PUSH:
-        run_push()
-    elif command_cli == Command.RM:
-        run_rm()
-    elif command_cli == Command.CHAT:
-        run_chat()
-    elif command_cli == Command.RUN:
-        run_pipeline()
-    elif command_cli == Command.ENV:
-        run_env()
-    elif command_cli == Command.DEPLOY:
-        run_deploy()
-    elif command_cli == Command.EXPORT:
-        run_export()
-    else:
-        raise ValueError(f"Currently command {command_cli} is not supported")
+def run_train():
+    if get_device_count() >= 1:
+        master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
+        master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
+        command = [
+            "torchrun",
+            "--nnodes",
+            os.environ.get("NNODES", "1"),
+            "--node_rank",
+            os.environ.get("RANK", "0"),
+            "--nproc_per_node",
+            os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+            "--master_addr",
+            master_addr,
+            "--master_port",
+            master_port,
+            train.__file__,
+        ] + sys.argv[2::]
+        subprocess.run(command)
+    else:
+        raise ValueError("There is no npu devices to launch finetune workflow")
+
+
+def print_help():
+    header = ["Commands", "Description"]
+    commands_info = [[info.cmd, info.desc] for info in COMMANDS.values()]
+    print_formatted_table(commands_info, header)
+
+
+def main():
+    command_cli = sys.argv[1] if len(sys.argv) != 1 else None
+    if command_cli == COMMANDS["TRAIN"].cmd:
+        run_train()
+    elif command_cli == COMMANDS["LIST"].cmd:
+        run_list()
+    elif command_cli == COMMANDS["EVAL"].cmd:
+        run_eval()
+    elif command_cli == COMMANDS["PULL"].cmd:
+        run_pull()
+    elif command_cli == COMMANDS["PUSH"].cmd:
+        run_push()
+    elif command_cli == COMMANDS["RM"].cmd:
+        run_rm()
+    elif command_cli == COMMANDS["CHAT"].cmd:
+        run_chat()
+    elif command_cli == COMMANDS["RUN"].cmd:
+        run_pipeline()
+    elif command_cli == COMMANDS["ENV"].cmd:
+        run_env()
+    elif command_cli == COMMANDS["DEPLOY"].cmd:
+        run_deploy()
+    elif command_cli == COMMANDS["EXPORT"].cmd:
+        run_export()
+    elif not command_cli:
+        print_help()
+    else:
+        print_help()
+        raise ValueError(
+            f"Currently command {command_cli} is not supported. Please refer to the table above to provide the correct command."
+        )


 if __name__ == "__main__":
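The hunk above splits the old monolithic main(): the launch logic moves into run_train(), and main() becomes a pure dispatcher that prints the command table when called with no arguments. For illustration, a minimal standalone sketch of the torchrun argv that run_train() assembles when the distributed environment variables are unset; the device count of 8, the entry script path, and the trailing "train.yaml" argument are all assumptions, not values from this diff:

# Illustrative sketch only, not part of the diff.
import os

def sketch_torchrun_argv(entry_script, cli_rest, device_count):
    # Mirrors run_train(): every knob falls back to a single-node default.
    return [
        "torchrun",
        "--nnodes", os.environ.get("NNODES", "1"),
        "--node_rank", os.environ.get("RANK", "0"),
        "--nproc_per_node", os.environ.get("NPROC_PER_NODE", str(device_count)),
        "--master_addr", os.environ.get("MASTER_ADDR", "127.0.0.1"),
        # The diff draws a random port from 20001-29999; fixed here for clarity.
        "--master_port", os.environ.get("MASTER_PORT", "29500"),
        entry_script,
    ] + cli_rest

print(sketch_torchrun_argv("openmind/flow/train.py", ["train.yaml"], 8))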
@@ -44,10 +44,20 @@ from openmind.flow.model.adapter import apply_adapter
 from openmind.flow.model.sequence_parallel.seq_utils import apply_sequence_parallel
 from openmind.integrations.transformers.bitsandbytes import patch_bnb
 from openmind.utils.loader_utils import get_platform_loader
+from openmind.utils.arguments_utils import print_formatted_table

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+def print_model_info():
+    supported_models_info = list()
+    for model_id, model_metadata in SUPPORTED_MODELS.items():
+        model_info = {"model_id": model_id}
+        model_info.update({platform: path for platform, path in model_metadata.path.items()})
+        supported_models_info.append(model_info)
+    print_formatted_table(supported_models_info, "keys")
+
+
 def try_download_from_hub() -> str:
     args = get_args()
@@ -58,7 +68,18 @@ def try_download_from_hub() -> str:
         raise ValueError("Please set 'model_id' or 'model_name_or_path' to load model.")

     if args.model_id is not None and args.model_name_or_path is None:
-        args.model_name_or_path = SUPPORTED_MODELS[args.model_id].path[openmind_platform]
+        try:
+            args.model_name_or_path = SUPPORTED_MODELS[args.model_id].path[openmind_platform]
+        except KeyError as e:
+            print_model_info()
+            if e.args[0] == args.model_id:
+                raise ValueError(
+                    "The model_id is not in supported models. Please refer to the table above to provide the correct model_id."
+                ) from e
+            else:
+                raise ValueError(
+                    "The model is not supported for download on the current platform. Please refer to the table above to provide the correct environment variable for `OPENMIND_PLATFORM`."
+                ) from e

     if os.path.exists(args.model_name_or_path):
         return args.model_name_or_path
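Both lookups in the new try block can raise KeyError: SUPPORTED_MODELS[args.model_id] fails for an unknown model_id, and .path[openmind_platform] fails when the model exists but has no path on the current platform. The branch on e.args[0] tells the two apart, because args[0] of a KeyError is the key that was missing. A self-contained sketch of that mechanism; the toy registry below is invented for illustration and is not openMind's real model table:

# Toy stand-in for SUPPORTED_MODELS[...].path; contents are made up.
registry = {"qwen2-7b": {"modelers": "openMind/qwen2-7b"}}

for model_id, platform in [("llama3-8b", "modelers"), ("qwen2-7b", "huggingface")]:
    try:
        path = registry[model_id][platform]
    except KeyError as e:
        if e.args[0] == model_id:
            print(f"unknown model_id: {model_id}")  # first lookup failed
        else:
            print(f"{model_id} has no path for platform: {platform}")  # second lookup failed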
@@ -19,6 +19,8 @@ import argparse
 import os
 import yaml

+from tabulate import tabulate
+

 def _trans_args_list_to_dict(args_list: list) -> dict:
@@ -87,3 +89,7 @@ def safe_load_yaml(path):
         content = yaml.safe_load(file)

     return content
+
+
+def print_formatted_table(data, header, missingval="N/A"):
+    print(tabulate(data, header, missingval=missingval, tablefmt="fancy_grid"))
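print_formatted_table is a thin wrapper around tabulate, and its missingval="N/A" default pairs with print_model_info above: each row there is a dict, the "keys" header mode takes column names from the dict keys, and a model with no path on some platform simply lacks that key. A quick sketch of the behavior; the sample rows are invented:

from tabulate import tabulate

# Rows as dicts with headers="keys": column names come from the dict keys,
# and a missing key renders as the missingval placeholder.
rows = [
    {"model_id": "qwen2-7b", "modelers": "openMind/qwen2-7b"},
    {"model_id": "llama3-8b"},  # no path on this platform -> printed as N/A
]
print(tabulate(rows, "keys", missingval="N/A", tablefmt="fancy_grid"))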
@@ -15,20 +15,32 @@ from enum import Enum
 import json
 import os
 import stat
+from dataclasses import dataclass


-class Command:
-    LIST = "list"
-    EVAL = "eval"
-    PULL = "pull"
-    PUSH = "push"
-    RM = "rm"
-    CHAT = "chat"
-    RUN = "run"
-    ENV = "env"
-    DEPLOY = "deploy"
-    TRAIN = "train"
-    EXPORT = "export"
+@dataclass
+class CommandInfo:
+    cmd: str
+    desc: str
+
+
+COMMANDS = {
+    "LIST": CommandInfo(
+        "list", "Query and list locally downloaded models in the model cache directory and specified download directory"
+    ),
+    "EVAL": CommandInfo("eval", "Evaluate models"),
+    "PULL": CommandInfo(
+        "pull", "Download a specified model, dataset, or space to the cache directory or local directory"
+    ),
+    "PUSH": CommandInfo("push", "Upload the content in the specified directory to the specified repository"),
+    "RM": CommandInfo("rm", "Remove a specified model"),
+    "CHAT": CommandInfo("chat", "Start a multi-turn dialog"),
+    "RUN": CommandInfo("run", "Perform single-run inference"),
+    "ENV": CommandInfo("env", "List the current operating environment (installed dependency libraries)"),
+    "DEPLOY": CommandInfo("deploy", "Deploy server"),
+    "TRAIN": CommandInfo("train", "Train models"),
+    "EXPORT": CommandInfo("export", "Merge LoRA adapters and export model"),
+}


 class Stages:
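Replacing the bare string constants of the old Command class with CommandInfo entries gives the CLI a single source of truth: dispatch in main() compares against .cmd, while print_help() renders .cmd and .desc, so adding a command takes one registry entry. A standalone sketch of the pattern, mirroring the dataclass shape in this diff with a trimmed-down table:

from dataclasses import dataclass

@dataclass
class CommandInfo:
    cmd: str
    desc: str

COMMANDS = {
    "TRAIN": CommandInfo("train", "Train models"),
    "CHAT": CommandInfo("chat", "Start a multi-turn dialog"),
}

# Dispatch and help both read from the same table.
assert COMMANDS["TRAIN"].cmd == "train"
for info in COMMANDS.values():
    print(f"{info.cmd:<8} {info.desc}")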