[Frontend] [Core] Support for sharded tensorized models (#4990)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Sanger Steel <sangersteel@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Travis Johnson, committed by GitHub, 2024-06-12 15:13:52 -06:00
commit 51602eefd3 (parent 5cc50a531f)
6 changed files with 264 additions and 110 deletions

examples/tensorize_vllm_model.py

@@ -3,18 +3,12 @@ import dataclasses
import json
import os
import uuid
from functools import partial
from tensorizer import stream_io
from vllm import LLM
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
serialize_vllm_model)
tensorize_vllm_model)
# yapf conflicts with isort for this docstring
# yapf: disable
@@ -61,6 +55,12 @@ Which downloads the model tensors from your S3 bucket and deserializes them.
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
To support distributed tensor-parallel models, each model shard will be
serialized to a separate file. The tensorizer_uri is then specified as a string
template with a format specifier such as '%03d' that will be rendered with the
shard's rank. Sharded models serialized with this script will be named as
model-rank-%03d.tensors
For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.
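A quick illustration of the naming scheme described above (the bucket and prefix below are made-up examples; only the '%03d' template and the per-rank expansion come from this change): each tensor-parallel rank fills the printf-style specifier with its own rank, giving one file per shard.

# Hypothetical base path; the template itself matches the docstring above.
template = "s3://my-bucket/vllm/facebook/opt-13b/abc123/model-rank-%03d.tensors"
for rank in range(2):  # tensor_parallel_size == 2 -> two shard files
    print(template % rank)
# s3://my-bucket/vllm/facebook/opt-13b/abc123/model-rank-000.tensors
# s3://my-bucket/vllm/facebook/opt-13b/abc123/model-rank-001.tensors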
@@ -168,77 +168,72 @@ def parse_args():
def deserialize():
llm = LLM(model=args.model,
load_format="tensorizer",
tensor_parallel_size=args.tensor_parallel_size,
model_loader_extra_config=tensorizer_config
)
return llm
if __name__ == '__main__':
args = parse_args()
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
s3_endpoint = (getattr(args, 's3_endpoint', None)
or os.environ.get("S3_ENDPOINT_URL", None))
credentials = {
"s3_access_key_id": s3_access_key_id,
"s3_secret_access_key": s3_secret_access_key,
"s3_endpoint": s3_endpoint
}
model_ref = args.model
_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
model_name = model_ref.split("/")[1]
model_ref = args.model
keyfile = args.keyfile if args.keyfile else None
model_name = model_ref.split("/")[1]
if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = \
TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
init_distributed_environment(world_size=1, rank=0, local_rank=0)
initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None
if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None
if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
**credentials)
serialize_vllm_model(engine, tensorizer_config, keyfile)
elif args.command == "deserialize":
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
engine_args = EngineArgs.from_cli_args(
argparse.Namespace(**eng_args_dict)
)
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
if engine_args.tensor_parallel_size > 1:
model_path = f"{base_path}/model-rank-%03d.tensors"
else:
model_path = f"{base_path}/model.tensors"
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=keyfile,
**credentials)
tensorize_vllm_model(engine_args, tensorizer_config)
elif args.command == "deserialize":
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
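For completeness, deserializing a sharded, tensorized model outside this script follows the same pattern as the deserialize() helper above. This is a minimal sketch with assumed local paths and a two-way tensor-parallel split; S3 URIs additionally need the credential fields shown in the credentials dict.

from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

# Assumed example paths/model; the '%03d' template must match what was serialized.
tensorizer_config = TensorizerConfig(
    tensorizer_uri="/mnt/models/pythia-1.4b/model-rank-%03d.tensors")
llm = LLM(model="EleutherAI/pythia-1.4b",
          load_format="tensorizer",
          tensor_parallel_size=2,  # must match the number of serialized shards
          model_loader_extra_config=tensorizer_config)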

tests/tensorizer_loader/test_tensorizer.py

@@ -1,21 +1,27 @@
import json
import os
import pathlib
import subprocess
from unittest.mock import MagicMock, patch
import openai
import pytest
import ray
import torch
from tensorizer import EncryptionParams
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer,
is_vllm_tensorized,
load_with_tensorizer,
open_stream,
serialize_vllm_model)
serialize_vllm_model,
tensorize_vllm_model)
from ..conftest import VllmRunner, cleanup
from ..utils import ServerRunner
# yapf conflicts with isort for this docstring
@@ -42,6 +48,20 @@ def is_curl_installed():
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \
.model \
.llm_engine \
.model_executor \
.driver_worker \
.model_runner \
.model
def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random()
pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
with open(keyfile_path, 'wb') as f:
f.write(encryption_params.key)
@pytest.fixture(autouse=True)
def tensorizer_config():
@@ -88,12 +108,17 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
write_keyfile(key_path)
outputs = vllm_model.generate(prompts, sampling_params)
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
serialize_vllm_model(vllm_model.model.llm_engine,
config_for_serializing,
encryption_key_path=key_path)
config_for_serializing = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path
)
serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing)
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
@@ -145,7 +170,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(vllm_model.model.llm_engine,
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))
with vllm_runner(
@@ -180,7 +205,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(vllm_model.model.llm_engine,
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))
model_loader_extra_config = {
@@ -224,7 +249,9 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))
def test_tensorizer_with_tp(vllm_runner):
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
@@ -238,8 +265,62 @@ def test_tensorizer_with_tp(vllm_runner):
s3_endpoint="object.ord1.coreweave.com",
),
tensor_parallel_size=2,
disable_custom_all_reduce=True,
)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tmp_path):
model_ref = "EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model
base_model = vllm_runner(
model_ref,
disable_custom_all_reduce=True,
enforce_eager=True,
)
outputs = base_model.generate(prompts, sampling_params)
base_model.model.llm_engine.model_executor.shutdown()
del base_model
cleanup()
ray.shutdown()
# load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
key_path = tmp_path / (model_ref + ".key")
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path,
)
tensorize_vllm_model(
engine_args=EngineArgs(
model=model_ref,
tensor_parallel_size=2,
disable_custom_all_reduce=True,
enforce_eager=True,
),
tensorizer_config=tensorizer_config,
)
assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
cleanup()
ray.shutdown()
loaded_vllm_model = vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
assert outputs == deserialized_outputs
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
model_ref = "facebook/opt-125m"
@@ -248,7 +329,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(vllm_model.model.llm_engine, config)
serialize_vllm_model(get_torch_model(vllm_model), config)
assert is_vllm_tensorized(config)
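The test changes above reflect an API change in this commit: serialize_vllm_model now takes the underlying torch.nn.Module rather than an LLMEngine, which is why the tests reach into the engine with the get_torch_model helper. A minimal single-GPU sketch of the same pattern outside the test harness (model name and output path are assumptions):

from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          serialize_vllm_model)

llm = LLM(model="facebook/opt-125m")
# Same attribute chain as get_torch_model(); valid for a single (non-TP) worker.
model = llm.llm_engine.model_executor.driver_worker.model_runner.model
serialize_vllm_model(model, TensorizerConfig(tensorizer_uri="/tmp/opt-125m.tensors"))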

vllm/model_executor/model_loader/loader.py

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
tensorizer_weights_iterator)
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
@@ -392,6 +392,12 @@ class TensorizerLoader(BaseModelLoader):
cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config)
if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank
self.tensorizer_config.tensorizer_uri = \
self.tensorizer_config.tensorizer_uri \
% get_tensor_model_parallel_rank()
if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config,
@@ -402,6 +408,16 @@
vision_language_config,
cache_config)
@staticmethod
def save_model(
model: torch.nn.Module,
tensorizer_config: TensorizerConfig,
) -> None:
serialize_vllm_model(
model=model,
tensorizer_config=tensorizer_config,
)
class ShardedStateLoader(BaseModelLoader):
"""

vllm/model_executor/model_loader/tensorizer.py

@@ -2,11 +2,11 @@ import argparse
import dataclasses
import io
import os
import re
import time
import typing
from dataclasses import dataclass
from functools import partial
from typing import Generator, Optional, Tuple, Type, Union
from typing import BinaryIO, Generator, Optional, Tuple, Type, Union
import torch
from torch import nn
@@ -14,6 +14,7 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
@@ -48,8 +49,7 @@ logger = init_logger(__name__)
@dataclass
class TensorizerConfig:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
tensorizer_uri: str
vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = None
@@ -60,6 +60,12 @@ class TensorizerConfig:
model_class: Optional[Type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None
_is_sharded: bool = False
def __post_init__(self):
# check if the configuration is for a sharded vLLM model
self._is_sharded = isinstance(self.tensorizer_uri, str) \
and re.search(r'%0\dd', self.tensorizer_uri) is not None
def _construct_tensorizer_args(self) -> "TensorizerArgs":
tensorizer_args = {
@@ -78,13 +84,12 @@ class TensorizerConfig:
self,
parallel_config: "ParallelConfig",
) -> None:
if (parallel_config.tensor_parallel_size > 1
and self.tensorizer_uri is not None):
if parallel_config.tensor_parallel_size > 1 \
and not self._is_sharded:
raise ValueError(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`.")
"For a sharded model, tensorizer_uri should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard")
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if (model_config.quantization is not None
@@ -102,8 +107,8 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig,
@dataclass
class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
bytes, os.PathLike, int]
vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = None
@@ -332,6 +337,7 @@ class TensorizerAgent:
) as stream, TensorDeserializer(
stream,
dtype=self.tensorizer_config.dtype,
device=f'cuda:{torch.cuda.current_device()}',
**self.tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(self.model)
end = time.perf_counter()
@@ -400,33 +406,70 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
return False
def get_pretensorized_vllm_model(engine: "LLMEngine") -> nn.Module:
model = (engine.model_executor.driver_worker.model_runner.model)
def serialize_vllm_model(
model: nn.Module,
tensorizer_config: TensorizerConfig,
) -> nn.Module:
model.register_parameter(
"vllm_tensorized_marker",
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
return model
def serialize_vllm_model(engine: "LLMEngine",
tensorizer_config : TensorizerConfig,
encryption_key_path: Optional[str] = None) \
-> nn.Module:
model = get_pretensorized_vllm_model(engine)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
encryption_params = None
if encryption_key_path is not None:
encryption_params = EncryptionParams.random()
with _write_stream(encryption_key_path,
**tensorizer_args.stream_params) as stream:
stream.write(encryption_params.key)
with _write_stream(tensorizer_args.tensorizer_uri,
**tensorizer_args.stream_params) as stream:
encryption_params = None
if (keyfile := tensorizer_config.encryption_keyfile) is not None:
with open(keyfile, "rb") as f:
key = f.read()
encryption_params = EncryptionParams(key=key)
output_file = tensorizer_args.tensorizer_uri
if tensorizer_config._is_sharded:
from vllm.distributed import get_tensor_model_parallel_rank
output_file = output_file % get_tensor_model_parallel_rank()
with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
serializer.close()
logger.info("Successfully serialized model to %s",
str(tensorizer_args.tensorizer_uri))
logger.info("Successfully serialized model to %s", str(output_file))
return model
def tensorize_vllm_model(engine_args: EngineArgs,
tensorizer_config: TensorizerConfig,
generate_keyfile: bool = True):
"""Utility to load a model and then serialize it with Tensorizer
Intended to be used separately from running a vLLM server since it
creates its own Engine instance.
"""
engine_config = engine_args.create_engine_config()
tensorizer_config.verify_with_model_config(engine_config.model_config)
tensorizer_config.verify_with_parallel_config(
engine_config.parallel_config)
# generate the encryption key before creating the engine to support sharding
if generate_keyfile and (keyfile :=
tensorizer_config.encryption_keyfile) is not None:
encryption_params = EncryptionParams.random()
with _write_stream(
keyfile,
s3_access_key_id=tensorizer_config.s3_access_key_id,
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
s3_endpoint=tensorizer_config.s3_endpoint,
) as stream:
stream.write(encryption_params.key)
engine = LLMEngine.from_engine_args(engine_args)
if tensorizer_config._is_sharded:
# if the engine is a distributed engine (for tensor parallel) then each
# worker shard needs to serialize its part of the model.
engine.model_executor._run_workers(
"save_tensorized_model",
tensorizer_config=tensorizer_config,
)
else:
# with a single worker, we can get to the underlying model directly
serialize_vllm_model(
engine.model_executor.driver_worker.model_runner.model,
tensorizer_config,
)
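A usage sketch for the new tensorize_vllm_model entry point, mirroring the tensor-parallel test above (model and paths are assumptions): a '%03d'-style specifier in tensorizer_uri marks the config as sharded via _is_sharded, so each tensor-parallel rank writes its own shard, and with generate_keyfile=True (the default) a random encryption key is written to encryption_keyfile before the engine starts.

from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          tensorize_vllm_model)

# Assumed example model and output paths.
tensorize_vllm_model(
    engine_args=EngineArgs(model="EleutherAI/pythia-1.4b",
                           tensor_parallel_size=2),
    tensorizer_config=TensorizerConfig(
        tensorizer_uri="/mnt/models/pythia-1.4b/model-rank-%03d.tensors",
        encryption_keyfile="/mnt/models/pythia-1.4b/model.key"))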

vllm/worker/model_runner.py

@@ -20,6 +20,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -222,6 +223,16 @@ class ModelRunner:
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: TensorizerConfig,
) -> None:
from vllm.model_executor.model_loader.loader import TensorizerLoader
TensorizerLoader.save_model(
self.model,
tensorizer_config=tensorizer_config,
)
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size

vllm/worker/worker.py

@@ -15,6 +15,7 @@ from vllm.distributed import (broadcast_tensor_dict,
set_custom_all_reduce)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
@@ -132,6 +133,13 @@ class Worker(WorkerBase):
max_size=max_size,
)
def save_tensorized_model(
self,
tensorizer_config: TensorizerConfig,
) -> None:
self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config, )
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many