Add Model Revision Support (#1014)

Co-authored-by: Jasmond Loh <Jasmond.Loh@hotmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Author: Jasmond L
Date: 2023-09-14 06:20:02 +08:00
Parent: 9841d48a10
Commit: ab019eea75
20 changed files with 75 additions and 35 deletions


@@ -38,6 +38,9 @@ class ModelConfig:
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
"""
@@ -52,6 +55,7 @@ class ModelConfig:
load_format: str,
dtype: str,
seed: int,
revision: Optional[str],
max_model_len: Optional[int] = None,
) -> None:
self.model = model
@@ -61,8 +65,9 @@ class ModelConfig:
self.download_dir = download_dir
self.load_format = load_format
self.seed = seed
self.revision = revision
self.hf_config = get_config(model, trust_remote_code)
self.hf_config = get_config(model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_load_format()
self._verify_tokenizer_mode()
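
For reference, a minimal sketch of how the updated ModelConfig constructor can be invoked after this change. The argument values are illustrative placeholders, not taken from the commit, and the import path assumes the usual vllm layout:

    from vllm.config import ModelConfig

    config = ModelConfig(
        model="facebook/opt-125m",         # placeholder model id
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        load_format="auto",
        dtype="auto",
        seed=0,
        revision="main",                   # branch name, tag name, or commit id
        max_model_len=None,
    )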


@@ -28,6 +28,7 @@ class EngineArgs:
max_num_batched_tokens: int = 2560
max_num_seqs: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
def __post_init__(self):
if self.tokenizer is None:
@@ -49,6 +50,13 @@ class EngineArgs:
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument(
'--revision',
type=str,
default=None,
help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
@@ -159,7 +167,8 @@ class EngineArgs:
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
self.dtype, self.seed, self.max_model_len)
self.dtype, self.seed, self.revision,
self.max_model_len)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space)
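
The new flag feeds the same path; a hedged sketch of driving it programmatically, with placeholder values and the usual vllm import path assumed:

    from vllm.engine.arg_utils import EngineArgs

    # CLI equivalent of the flag added above: --revision main
    engine_args = EngineArgs(model="facebook/opt-125m", revision="main")
    # engine_args.revision is then forwarded into ModelConfig as shown in the hunk above.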


@@ -74,6 +74,7 @@ class LLMEngine:
f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"download_dir={model_config.download_dir!r}, "
@@ -92,7 +93,8 @@ class LLMEngine:
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision)
self.seq_counter = Counter()
# Create the parallel GPU workers.
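
A short sketch of the tokenizer side, mirroring the call above; the tokenizer name and revision are placeholders, and the import path is assumed:

    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer("facebook/opt-125m",
                              tokenizer_mode="auto",
                              trust_remote_code=False,
                              revision="main")  # same revision as the model weights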


@@ -38,6 +38,8 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
seed: The seed to initialize the random number generator for sampling.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
"""
def __init__(
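
The docstring above implies the offline LLM entry point exposes the same knob; a hedged sketch, assuming LLM.__init__ forwards revision down to EngineArgs (the signature change itself is not shown in this excerpt):

    from vllm import LLM

    # Pin weights, tokenizer, and config to one revision instead of the repo default.
    llm = LLM(model="facebook/opt-125m", revision="main")  # placeholder values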


@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
else:
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.load_format)
model_config.load_format, model_config.revision)
model = model.cuda()
return model.eval()
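
Each per-model load_weights change below follows one pattern: accept a trailing revision argument and pass it through to hf_model_weights_iterator. A condensed sketch of that shared shape (the per-architecture weight mapping is elided, and the import path is assumed):

    from typing import Optional

    from vllm.model_executor.weight_utils import hf_model_weights_iterator

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            # ... per-architecture sharding and name mapping, unchanged by this commit ...
            state_dict[name].copy_(loaded_weight)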


@@ -288,7 +288,8 @@ class AquilaForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +306,7 @@ class AquilaForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue


@@ -303,13 +303,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue


@@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to
# load lm_head.weight in parallel.


@@ -420,7 +420,8 @@ class FalconForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
@@ -452,7 +453,7 @@ class FalconForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "query_key_value" in name:
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
loaded_weight_size = loaded_weight.size()


@@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.


@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.


@@ -222,11 +222,12 @@ class GPTJForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "attn.bias" in name or "attn.masked_bias" in name:
continue


@@ -231,11 +231,12 @@ class GPTNeoXForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue


@@ -233,12 +233,13 @@ class InternLMForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue


@@ -271,7 +271,8 @@ class LlamaForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,7 +289,7 @@ class LlamaForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue


@@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "Wqkv" in name:
# NOTE(woosuk): MPT's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].


@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto"):
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "lm_head.weight" in name:
continue


@@ -251,13 +251,14 @@ class QWenLMHeadModel(nn.Module):
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue


@@ -83,6 +83,7 @@ def prepare_hf_model_weights(
cache_dir: Optional[str] = None,
use_safetensors: bool = False,
fall_back_to_pt: bool = True,
revision: Optional[str] = None,
):
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
@@ -94,7 +95,8 @@ def prepare_hf_model_weights(
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
tqdm_class=Disabledtqdm,
revision=revision)
else:
hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
@@ -107,7 +109,8 @@ def prepare_hf_model_weights(
return prepare_hf_model_weights(model_name_or_path,
cache_dir=cache_dir,
use_safetensors=False,
fall_back_to_pt=False)
fall_back_to_pt=False,
revision=revision)
if len(hf_weights_files) == 0:
raise RuntimeError(
@@ -120,6 +123,7 @@ def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]:
use_safetensors = False
use_np_cache = False
@@ -140,7 +144,8 @@ def hf_model_weights_iterator(
model_name_or_path,
cache_dir=cache_dir,
use_safetensors=use_safetensors,
fall_back_to_pt=fall_back_to_pt)
fall_back_to_pt=fall_back_to_pt,
revision=revision)
if use_np_cache:
# Currently np_cache only support *.bin checkpoints
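
With revision threaded into snapshot_download above, a caller can iterate the weights of a pinned snapshot; the model id and revision below are placeholders:

    from vllm.model_executor.weight_utils import hf_model_weights_iterator

    for name, tensor in hf_model_weights_iterator("facebook/opt-125m",
                                                  cache_dir=None,
                                                  load_format="auto",
                                                  revision="main"):
        print(name, tuple(tensor.shape))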


@@ -1,3 +1,5 @@
from typing import Optional
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
@@ -12,10 +14,12 @@ _CONFIG_REGISTRY = {
}
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code)
model, trust_remote_code=trust_remote_code, revision=revision)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
@@ -29,5 +33,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
raise e
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)
config = config_class.from_pretrained(model, revision=revision)
return config
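
Finally, the config path accepts the same pin; a one-line sketch with placeholder values and the usual vllm import path assumed:

    from vllm.transformers_utils.config import get_config

    hf_config = get_config("facebook/opt-125m", trust_remote_code=False, revision="main")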