[Deprecation] Remove prompt_token_ids arg fallback in LLM.generate and LLM.embed (#18800)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
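The change removes the deprecated `prompt_token_ids=` keyword fallback from `LLM.generate` and the pooling entrypoints (`LLM.encode`/`LLM.embed`); pre-tokenized inputs are now passed as part of the prompt itself. A minimal migration sketch based on the `TokensPrompt` and dict-prompt usage visible in the diff below (the model name and token IDs here are placeholders):

```python
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="distilbert/distilgpt2")  # placeholder model
params = SamplingParams(temperature=0.0, max_tokens=16)
prompt_ids = [1, 2, 3]  # placeholder token IDs

# Before (no longer accepted):
#   llm.generate(prompt_token_ids=prompt_ids, sampling_params=params)
# After: pass the token IDs as part of the prompt argument itself.
outputs = llm.generate(TokensPrompt(prompt_token_ids=prompt_ids),
                       sampling_params=params)
print(outputs[0].outputs[0].text)

# The dict form works the same way, e.g. for pooling models:
# llm.encode({"prompt_token_ids": prompt_ids})
```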
@@ -2,7 +2,7 @@
# We can use this script to compute baseline accuracy on GSM for transformers.
#
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.4
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]

usage() {
echo``
@@ -3,7 +3,7 @@
# We use this for fp8, which HF does not support.
#
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.4
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]

usage() {
echo``
@@ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
RUN cd /vllm-workspace \
&& rm -rf vllm \
&& python3 -m pip install -e tests/vllm_test_utils \
&& python3 -m pip install lm-eval[api]==0.4.4 \
&& python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \
&& python3 -m pip install pytest-shard

# -----------------------
@@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio
Install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

Load and run the model in `vllm`:
@@ -18,7 +18,7 @@ pip install llmcompressor
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

## Quantization Process
@@ -19,7 +19,7 @@ pip install llmcompressor
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

## Quantization Process
@@ -20,7 +20,7 @@ for more installation details.
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

## Quantization Process
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.inputs import TokensPrompt
from vllm.v1.metrics.reader import Counter, Vector

try:
@@ -137,7 +138,8 @@ def main():
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
if not args.custom_mm_prompts:
outputs = llm.generate(
prompt_token_ids=prompt_ids, sampling_params=sampling_params
TokensPrompt(prompt_token_ids=prompt_ids),
sampling_params=sampling_params,
)
else:
outputs = llm.chat(prompts, sampling_params=sampling_params)
@@ -85,7 +85,7 @@ def format_output(title: str, output: str):

def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
outputs = llm.generate(prompt, sampling_params=sampling_params)
return outputs[0].outputs[0].text
@@ -27,7 +27,7 @@ mistral_common[image,audio] >= 1.8.2 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb>=1.38.11, <2 # required for mteb test
transformers==4.52.4
tokenizers==0.21.1
@@ -32,7 +32,8 @@ num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
# TODO: Use lm-eval[api]==0.4.10 once released
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb[bm25s]>=1.38.11, <2 # required for mteb test
transformers==4.55.2
tokenizers==0.21.1
@@ -408,7 +408,7 @@ lightning-utilities==0.14.3
# torchmetrics
llvmlite==0.44.0
# via numba
lm-eval==0.4.8
lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
# via -r requirements/test.in
lxml==5.3.0
# via
@@ -18,10 +18,9 @@ def text_llm():
enforce_eager=True,
seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -88,10 +87,9 @@ def vision_llm():
seed=0,
)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -158,10 +156,9 @@ def thinking_llm():
seed=0,
)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -35,10 +35,9 @@ def llm():
enforce_eager=True,
seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -26,10 +26,9 @@ def llm():
enforce_eager=True,
seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -5,11 +5,9 @@ import weakref

import pytest

from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory

from ...models.utils import check_embeddings_close

MODEL_NAME = "intfloat/multilingual-e5-small"

PROMPTS = [
@@ -48,57 +46,13 @@ def llm():
enforce_eager=True,
seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()

def assert_outputs_match(o1: list[PoolingRequestOutput],
o2: list[PoolingRequestOutput]):
check_embeddings_close(
embeddings_0_lst=[o.outputs.data for o in o1],
embeddings_1_lst=[o.outputs.data for o in o2],
name_0="hf",
name_1="vllm",
tol=1e-2,
)

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=prompt_token_ids,
pooling_params=pooling_params)

v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
pooling_params=pooling_params)
assert_outputs_match(v1_output, v2_output)

@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=TOKEN_IDS,
pooling_params=pooling_params)

v2_output = llm.encode(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
pooling_params=pooling_params,
)
assert_outputs_match(v1_output, v2_output)

@pytest.mark.skip_global_cleanup
def test_multiple_pooling_params(llm: LLM):
pooling_params = [
@@ -5,7 +5,7 @@ import weakref

import pytest

from vllm import LLM, RequestOutput, SamplingParams
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "distilbert/distilgpt2"
@@ -41,50 +41,13 @@ def llm():
gpu_memory_utilization=0.10,
enforce_eager=True)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()

def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)

v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
sampling_params=sampling_params)

v2_output = llm.generate(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)

@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
sampling_params = [
@@ -48,10 +48,9 @@ def llm(request, monkeypatch_module):
max_num_seqs=128,
enforce_eager=True)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -36,10 +36,9 @@ def llm():
trust_remote_code=True,
seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -33,10 +33,9 @@ def llm():
enforce_eager=True,
seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
yield weakref.proxy(llm)

del llm
del llm

cleanup_dist_env_and_memory()
@@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
print(outputs[0][1])
@@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,

# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
print(outputs[0][1])
@@ -46,5 +46,5 @@ def test_lm_head(
vllm_model.apply_model(check_model)

print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
vllm_model.generate_greedy(["Hello my name is"],
max_tokens=10)[0][1])
@@ -127,13 +127,15 @@ def test_structured_output(
temperature=1.0,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
(f"Give an example JSON for an employee profile that fits this "
f"schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}")
] * 2,
sampling_params=sampling_params,
use_tqdm=True)

prompt = ("Give an example JSON for an employee profile that fits this "
"schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}")
outputs = llm.generate(
[prompt] * 2,
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
@@ -191,20 +193,24 @@ def test_structured_output(
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):

prompt = (f"Give an example JSON for an employee profile that "
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")
llm.generate(
prompts=[(f"Give an example JSON for an employee profile that "
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")] * 2,
[prompt] * 2,
sampling_params=sampling_params,
use_tqdm=True)
use_tqdm=True,
)
else:
outputs = llm.generate(prompts=(
"Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True)
prompt = (f"Give an example JSON object for a grade that "
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")
outputs = llm.generate(
prompt,
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
@@ -227,10 +233,9 @@ def test_structured_output(
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
@@ -261,10 +266,9 @@ def test_structured_output(
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate(
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
@@ -301,7 +305,6 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
prompts=
("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short "
"as possible."),
@@ -316,11 +319,11 @@ def test_structured_output(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))

prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
outputs = llm.generate(
prompts=[
(f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
] * 2,
[prompt] * 2,
sampling_params=sampling_params,
use_tqdm=True,
)
@@ -343,11 +346,13 @@ def test_structured_output(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))

outputs = llm.generate(
prompts=("The best language for type-safe systems programming is "
"(Make the response as short as possible.) "),
("The best language for type-safe systems programming is "
"(Make the response as short as possible.) "),
sampling_params=sampling_params,
use_tqdm=True)
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
@@ -367,12 +372,14 @@ def test_structured_output(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate(prompts=(
"Generate a JSON with the brand, model and car_type of the most "
"iconic car from the 90's. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True)

outputs = llm.generate(
("Generate a JSON with the brand, model and car_type of the most "
"iconic car from the 90's. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
@@ -411,10 +418,11 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(json=json_schema))

outputs = llm.generate(
prompts=("Generate a description of a frog using 50 characters. "
"Make the response as short as possible."),
("Generate a description of a frog using 50 characters. "
"Make the response as short as possible."),
sampling_params=sampling_params,
use_tqdm=True)
use_tqdm=True,
)

assert outputs is not None
@@ -498,7 +506,7 @@ Make the response as short as possible.
"""

# Change this once other backends support structural_tag
outputs = llm.generate(prompts=prompt,
outputs = llm.generate(prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
@@ -639,15 +647,13 @@ def test_structured_output_auto_mode(
f"{unsupported_json_schema}. Make the response as short as possible.")
# This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically.
outputs = llm.generate(prompts=prompts,
outputs = llm.generate(prompts,
sampling_params=sampling_params,
use_tqdm=True)
# Make sure `auto` backend handling doesn't mess up sampling_params
# and that we can reuse it without error.
outputs.extend(
llm.generate(prompts=prompts,
sampling_params=sampling_params,
use_tqdm=True))
llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True))

assert outputs is not None
for output in outputs:
@@ -705,7 +711,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
max_tokens=256,
guided_decoding=guided_params)

outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
outputs = llm.generate(prompt, sampling_params=sampling_params)
assert outputs is not None
generated_text = outputs[0].outputs[0].text
assert generated_text is not None
@@ -3,15 +3,13 @@

import itertools
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
cast, overload)
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

import cloudpickle
import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated
from typing_extensions import TypeVar

import vllm.envs as envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
@@ -40,7 +38,6 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -54,7 +51,7 @@ from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
from vllm.utils import Counter, Device, is_list_of
from vllm.v1.sample.logits_processor import LogitsProcessor

if TYPE_CHECKING:
@@ -157,18 +154,6 @@ class LLM:
serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
"""

DEPRECATE_LEGACY: ClassVar[bool] = True
"""A flag to toggle whether to deprecate the legacy generate/encode API."""

@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True

yield

cls.DEPRECATE_LEGACY = False

def __init__(
self,
model: str,
@@ -325,99 +310,14 @@ class LLM:
return SamplingParams.from_optional(**self.default_sampling_params)
return SamplingParams()

@overload
def generate(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...

@overload # LEGACY: single (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...

@overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: list[str],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...

@overload # LEGACY: single (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
*,
prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...

@overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: Optional[list[str]] = None,
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
*,
prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...

@overload # LEGACY: single or multi token ids [pos-only]
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[RequestOutput]:
...

@deprecate_kwargs(
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'prompts' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, list[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
priority: Optional[list[int]] = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
@@ -460,15 +360,6 @@ class LLM:
"Try passing `--runner generate` to use the model as a "
"generative model.")

if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, list[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)

if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
@@ -483,10 +374,10 @@ class LLM:

# Add any modality specific loras to the corresponding prompts
lora_request = self._get_modality_specific_lora_reqs(
parsed_prompts, lora_request)
prompts, lora_request)

self._validate_and_add_requests(
prompts=parsed_prompts,
prompts=prompts,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
@@ -498,7 +389,7 @@ class LLM:
return self.engine_class.validate_outputs(outputs, RequestOutput)

def _get_modality_specific_lora_reqs(
self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
self, prompts: Union[PromptType, Sequence[PromptType]],
lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
# Grab the lora config off the vllm config on the engine,
# since this is the same for both v0 & v1.
@@ -511,35 +402,33 @@ class LLM:
or (lora_config and lora_config.default_mm_loras is None)):
return lora_request

if not isinstance(parsed_prompts, Sequence):
parsed_prompts = [parsed_prompts]
if not isinstance(prompts, Sequence):
prompts = [prompts]

optional_loras = ([lora_request] * len(parsed_prompts)
optional_loras = ([lora_request] * len(prompts)
if not isinstance(lora_request, Sequence) else
lora_request)

return [
self._resolve_single_prompt_mm_lora(
parsed_prompt,
prompt,
opt_lora_req,
lora_config.default_mm_loras,
) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
optional_loras)
) for prompt, opt_lora_req in zip(prompts, optional_loras)
]

def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
lora_request: Optional[LoRARequest],
default_mm_loras: Optional[dict[str,
str]]):
if (not default_mm_loras or not isinstance(parsed_prompt, dict)
or "multi_modal_data" not in parsed_prompt):
if (not default_mm_loras or not isinstance(prompt, dict)
or "multi_modal_data" not in prompt):
return lora_request

parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)
prompt = cast(Union[TextPrompt, TokensPrompt], prompt)

intersection = set(
parsed_prompt["multi_modal_data"].keys()).intersection(
default_mm_loras.keys())
intersection = set(prompt["multi_modal_data"].keys()) \
.intersection(default_mm_loras.keys())
if not intersection:
return lora_request
if len(intersection) > 1:
@@ -933,11 +822,9 @@ class LLM:
lora_request=lora_request,
)

@overload
def encode(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
@@ -946,107 +833,6 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...

@overload # LEGACY: single (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...

@overload # LEGACY: multi (prompt + optional token ids)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: list[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...

@overload # LEGACY: single (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: list[int],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...

@overload # LEGACY: multi (token ids + optional prompt)
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: Optional[list[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: list[list[int]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...

@overload # LEGACY: single or multi token ids [pos-only]
@deprecated("'prompt_token_ids' will become part of 'prompts'")
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...

@deprecate_kwargs(
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'prompts' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, list[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: Optional[PoolingTask] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input
prompts.
@@ -1108,15 +894,6 @@ class LLM:
raise ValueError(
f"pooling_task must be one of {self.supported_tasks}.")

if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, list[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)

if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
@@ -1134,7 +911,7 @@ class LLM:
tokenization_kwargs)

self._validate_and_add_requests(
prompts=parsed_prompts,
prompts=prompts,
params=pooling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
@@ -1148,7 +925,6 @@ class LLM:
def embed(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
@@ -1198,7 +974,6 @@ class LLM:
def classify(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
pooling_params: Optional[Union[PoolingParams,
@@ -1348,7 +1123,7 @@ class LLM:
_validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)

parsed_prompts = []
prompts = list[PromptType]()

input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
@@ -1372,10 +1147,10 @@ class LLM:
else:
pooling_params_list.append(pooling_params)

parsed_prompts.append(engine_prompt)
prompts.append(engine_prompt)

self._validate_and_add_requests(
prompts=parsed_prompts,
prompts=prompts,
params=pooling_params_list,
use_tqdm=use_tqdm,
lora_request=lora_request,
@@ -1585,48 +1360,6 @@ class LLM:
assert isinstance(self.llm_engine, V1LLMEngine)
return self.llm_engine.get_metrics()

# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
):
# skip_tokenizer_init is now checked in engine

if prompts is None and prompt_token_ids is None:
raise ValueError(
"Either prompts or prompt_token_ids must be provided.")
if prompts is not None and prompt_token_ids is not None \
and len(prompts) != len(prompt_token_ids):
raise ValueError(
"The lengths of prompts and prompt_token_ids must be the same."
)

if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None:
prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
]
if prompts is not None:
num_requests = len(prompts)
elif prompt_token_ids is not None:
num_requests = len(prompt_token_ids)
parsed_prompts: list[PromptType] = []
for i in range(num_requests):
item: PromptType

if prompts is not None:
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError

parsed_prompts.append(item)

return parsed_prompts

def _validate_and_add_requests(
self,
prompts: Union[PromptType, Sequence[PromptType]],