mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Compare commits
4 Commits
d31f7844f8
...
woosuk/tes
Author | SHA1 | Date | |
---|---|---|---|
69c9a01538 | |||
8935ca208d | |||
dddad8a81c | |||
7f783b8a4a |
@ -649,5 +649,65 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
req.cache_salt = "test_salt"
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
|
||||
assert engine_prompt.get("cache_salt") == "test_salt"
|
||||
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_serving_chat_data_parallel_rank_extraction():
    """Verify the X-data-parallel-rank header is forwarded to engine.generate.

    Two requests go through the same serving instance: the first carries the
    header with value "2" and must produce ``data_parallel_rank == 2``; the
    second carries no header and must produce ``data_parallel_rank is None``.
    """
    engine = MagicMock(spec=MQLLMEngineClient)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
        model_config=MockModelConfig(),
    )
    chat_server = OpenAIServingChat(
        engine,
        MockModelConfig(),
        serving_models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )

    # Case 1: the raw request carries the X-data-parallel-rank header.
    request_with_rank = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )
    raw_with_rank = MagicMock()
    raw_with_rank.headers = {"X-data-parallel-rank": "2"}
    raw_with_rank.state = MagicMock()

    # Exceptions from the call are deliberately suppressed; only the
    # arguments handed to engine.generate are inspected afterwards.
    with suppress(Exception):
        await chat_server.create_chat_completion(request_with_rank, raw_with_rank)

    forwarded = engine.generate.call_args.kwargs
    assert "data_parallel_rank" in forwarded
    assert forwarded["data_parallel_rank"] == 2

    # Case 2: the raw request has no such header, so the rank must be None.
    request_without_rank = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 2+2?"}],
    )
    raw_without_rank = MagicMock()
    raw_without_rank.headers = {}
    raw_without_rank.state = MagicMock()

    with suppress(Exception):
        await chat_server.create_chat_completion(request_without_rank, raw_without_rank)

    forwarded = engine.generate.call_args.kwargs
    assert "data_parallel_rank" in forwarded
    assert forwarded["data_parallel_rank"] is None
|
||||
|
@ -386,6 +386,24 @@ async def get_server_load_metrics(request: Request):
|
||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
|
||||
@router.get("/get_server_info")
async def get_server_info(raw_request: Request):
    """Return the stringified vLLM config and the data-parallel size as JSON.

    The response has two keys: ``"vllm_config"`` (``str()`` of the config
    stored on the application state) and ``"dp_size"``. The original
    docstring notes that the DP size is exposed for a router.
    """
    vllm_config = raw_request.app.state.vllm_config

    # dp_size falls back to 1 when the config has no parallel_config, or the
    # parallel_config has no data_parallel_size attribute.
    parallel_config = getattr(vllm_config, "parallel_config", None)
    dp_size = getattr(parallel_config, "data_parallel_size", 1)

    return JSONResponse(content={
        "vllm_config": str(vllm_config),
        "dp_size": dp_size,
    })
|
||||
|
||||
|
||||
@router.get("/ping", response_class=Response)
|
||||
@router.post("/ping", response_class=Response)
|
||||
async def ping(raw_request: Request) -> Response:
|
||||
|
@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
# Extract data_parallel_rank from header (router can inject it)
|
||||
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
priority=request.priority,
|
||||
prompt_text=prompt_text,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
@ -141,6 +141,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Extract data_parallel_rank from header (router can inject it)
|
||||
data_parallel_rank = self._get_data_parallel_rank(raw_request)
|
||||
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
@ -224,6 +228,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
priority=request.priority,
|
||||
prompt_text=prompt_text,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
@ -1297,6 +1297,21 @@ class OpenAIServing:
|
||||
|
||||
return raw_request.headers.get("X-Request-Id", default)
|
||||
|
||||
@staticmethod
|
||||
def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
|
||||
"""Pulls the data parallel rank from a header, if provided"""
|
||||
if raw_request is None:
|
||||
return None
|
||||
|
||||
rank_str = raw_request.headers.get("X-data-parallel-rank")
|
||||
if rank_str is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return int(rank_str)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(
|
||||
logprob: Logprob,
|
||||
|
@ -36,9 +36,9 @@ def kernel_warmup(worker: "Worker"):
|
||||
max_tokens = worker.scheduler_config.max_num_batched_tokens
|
||||
deep_gemm_warmup(model, max_tokens)
|
||||
|
||||
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
||||
if has_flashinfer() and current_platform.has_device_capability(90):
|
||||
flashinfer_autotune(worker.model_runner)
|
||||
# # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
|
||||
# if has_flashinfer() and current_platform.has_device_capability(90):
|
||||
# flashinfer_autotune(worker.model_runner)
|
||||
|
||||
# FlashInfer attention warmup
|
||||
# Only warmup if the model has FlashInfer attention groups
|
||||
|
Reference in New Issue
Block a user