Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Frontend] generation_config.json for maximum tokens (#12242)

Signed-off-by: Matthew Hendrey <matthew.hendrey@gmail.com>
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: shangmingc <caishangming@linux.alibaba.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Yuan Tang <terrytangyuan@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -103,6 +103,116 @@ def test_serving_chat_should_set_correct_max_tokens():

    assert mock_engine.generate.call_args.args[1].max_tokens == 10

    # Setting server's max_tokens in the generation_config.json
    # lower than context_window - prompt_tokens
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 10  # Setting server-side max_tokens limit
    }

    # Reinitialize the engine with new settings
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)

    # Test Case 1: No max_tokens specified in request
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 15

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 5

    # Setting server's max_tokens in the generation_config.json
    # higher than context_window - prompt_tokens
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "max_tokens": 200  # Setting server-side max_tokens limit
    }

    # Reinitialize the engine with new settings
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     models,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     request_logger=None)

    # Test Case 1: No max_tokens specified, defaults to context_window
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # Test Case 2: Request's max_tokens set higher than server accepts
    req.max_tokens = 100

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 93

    # Test Case 3: Request's max_tokens set lower than server accepts
    req.max_tokens = 5

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 5


def test_serving_chat_could_load_correct_generation_config():
@@ -910,12 +910,18 @@ class ModelConfig:
            "top_k",
            "top_p",
            "min_p",
            "max_new_tokens",
        ]
        if any(p in config for p in available_params):
            diff_sampling_param = {
                p: config.get(p)
                for p in available_params if config.get(p) is not None
            }
            # Huggingface definition of max_new_tokens is equivalent
            # to vLLM's max_tokens
            if "max_new_tokens" in diff_sampling_param:
                diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
                    "max_new_tokens")
        else:
            diff_sampling_param = {}
        return diff_sampling_param
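As a standalone illustration of the branch above, a small sketch that applies the same renaming to a plain dict standing in for the parsed generation config; the dict contents are made up.

    config = {"temperature": 0.7, "max_new_tokens": 128}  # stand-in for the parsed file
    available_params = [
        "repetition_penalty", "temperature", "top_k", "top_p", "min_p",
        "max_new_tokens",
    ]

    diff_sampling_param = {
        p: config.get(p)
        for p in available_params if config.get(p) is not None
    }
    if "max_new_tokens" in diff_sampling_param:
        # HuggingFace's max_new_tokens maps onto vLLM's max_tokens.
        diff_sampling_param["max_tokens"] = diff_sampling_param.pop("max_new_tokens")

    print(diff_sampling_param)  # {'temperature': 0.7, 'max_tokens': 128}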
@@ -939,7 +939,9 @@ class EngineArgs:
            "Defaults to None, will use the default generation config in vLLM. "
            "If set to 'auto', the generation config will be automatically "
            "loaded from model. If set to a folder path, the generation config "
            "will be loaded from the specified folder path.")
            "will be loaded from the specified folder path. If "
            "`max_new_tokens` is specified, then it sets a server-wide limit "
            "on the number of output tokens for all requests.")

        parser.add_argument("--enable-sleep-mode",
                            action="store_true",
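To make the new sentence in the help text concrete, a hedged example of a generation config folder whose generation_config.json carries max_new_tokens; the folder name and values are placeholders, not part of the change.

    import json
    from pathlib import Path

    # Hypothetical folder to pass as the generation config path.
    gen_config_dir = Path("./my_generation_config")
    gen_config_dir.mkdir(exist_ok=True)

    # HuggingFace-style generation_config.json. Per the help text above,
    # max_new_tokens becomes a server-wide cap on output tokens per request.
    (gen_config_dir / "generation_config.json").write_text(
        json.dumps({"temperature": 0.7, "max_new_tokens": 128}, indent=2))

The server would then be pointed at this folder through the generation-config option whose help string is edited above.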
@@ -380,13 +380,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
    ) -> BeamSearchParams:
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
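The generator expression filters out None, so a missing request value or an absent server limit simply drops out of the comparison; a tiny sketch with illustrative numbers, not values taken from the code above.

    # All three limits present: the smallest wins.
    assert min(v for v in (93, 100, 10) if v is not None) == 10

    # No server limit configured and no request value: only the
    # context-derived default remains in play.
    assert min(v for v in (93, None, None) if v is not None) == 93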
@@ -406,11 +410,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
@@ -740,13 +749,17 @@ class CompletionRequest(OpenAIBaseModel):
            default_sampling_params: Optional[dict] = None
    ) -> BeamSearchParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get("temperature", 1.0)
@@ -764,11 +777,16 @@ class CompletionRequest(OpenAIBaseModel):
            logits_processor_pattern: Optional[str],
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Use minimum of context window, user request & server limit.
        max_tokens = min(
            val for val in (default_max_tokens, max_tokens,
                            default_sampling_params.get("max_tokens", None))
            if val is not None)

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
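From the client side, the effect of the server-wide limit might look like the sketch below; the base URL, API key, and model name are placeholders, and the clamping itself happens on the server as implemented in the hunks above.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    completion = client.chat.completions.create(
        model="my-model",  # placeholder model name
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=50,  # if generation_config.json caps output at 10 tokens,
    )                   # the server clamps this request to 10
    print(completion.choices[0].message.content)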