Mirror of https://github.com/vllm-project/vllm.git
[Core] Raise when non-multi-instance DP clients target a DP rank (#19227)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
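This change threads an optional `data_parallel_rank` argument through the async generation path and validates it early. The hunks below: the v0 async-engine tests gain a rejection test, the v1 `AsyncLLM` tests gain valid- and invalid-rank coverage, the v0 `_AsyncLLMEngine` rejects the argument outright, `DPAsyncMPClient` drops its now-redundant range check, and the v1 `Processor` takes over range validation.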
```diff
@@ -384,3 +384,25 @@ async def test_delayed_generator(async_engine, stop):
     assert final_output is not None
     assert len(final_output.outputs[0].token_ids) == 10
     assert final_output.finished
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_invalid_argument(async_engine):
+    scheduler_config = await async_engine.get_scheduler_config()
+
+    if scheduler_config.num_scheduler_steps != 1:
+        pytest.skip("no need to test this one with multistep")
+
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    # Targeting specific DP rank only supported in v1 multi-instance DP
+    with pytest.raises(ValueError):
+        async for _ in async_engine.generate("test",
+                                             sampling_params,
+                                             request_id=uid(),
+                                             data_parallel_rank=0):
+            pass
```
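The `async for ... pass` wrapper in that test is load-bearing: `generate` is an async generator, so the `ValueError` only surfaces once iteration begins, not when the generator object is created. A minimal standalone sketch of this behavior (the `generate` here is a stand-in, not vLLM's API):

```python
import asyncio


async def generate(prompt, data_parallel_rank=None):
    # Stand-in for the v0 guard: reject any explicit DP rank.
    if data_parallel_rank is not None:
        raise ValueError("Targeting data_parallel_rank only supported "
                         "in v1 client.")
    yield prompt


async def main():
    gen = generate("test", data_parallel_rank=0)  # no error raised yet
    try:
        async for _ in gen:  # the error only surfaces on first iteration
            pass
    except ValueError as e:
        print(f"caught: {e}")


asyncio.run(main())
```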
```diff
@@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch):
     assert len(engine.stat_loggers) == 1
     assert len(engine.stat_loggers[0]) == 1
     engine.stat_loggers[0][0].log.assert_called_once()
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        after.callback(engine.shutdown)
+
+        sampling_params = SamplingParams(max_tokens=100,
+                                         output_kind=RequestOutputKind.DELTA,
+                                         temperature=1.0,
+                                         seed=33)
+
+        # Test with valid DP rank.
+        async for _ in engine.generate(request_id="request-34",
+                                       prompt=TEXT_PROMPT,
+                                       sampling_params=sampling_params,
+                                       data_parallel_rank=0):
+            pass
+
+        # Test with out-of-range DP rank.
+        with pytest.raises(ValueError):
+            async for _ in engine.generate(request_id="request-35",
+                                           prompt=TEXT_PROMPT,
+                                           sampling_params=sampling_params,
+                                           data_parallel_rank=1):
+                pass
```
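The `ExitStack` guarantees `engine.shutdown()` runs even if an assertion fails partway through the test: callbacks registered on the stack fire in LIFO order when the `with` block exits on any path. A tiny illustration with a stand-in engine (the `FakeEngine` class is hypothetical):

```python
from contextlib import ExitStack


class FakeEngine:

    def shutdown(self):
        print("engine shut down")


with ExitStack() as after:
    engine = FakeEngine()
    # Registered immediately after construction, so shutdown runs on any
    # exit path: normal return, failed assert, or unexpected exception.
    after.callback(engine.shutdown)
    print("running test body")
# prints "running test body", then "engine shut down"
```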
```diff
@@ -29,12 +29,14 @@ if not current_platform.supports_v1(engine_args.create_model_config()):
                 allow_module_level=True)
 
 
-async def generate(engine: AsyncLLM,
-                   request_id: str,
-                   prompt: PromptType,
-                   output_kind: RequestOutputKind,
-                   max_tokens: int,
-                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
+async def generate(
+        engine: AsyncLLM,
+        request_id: str,
+        prompt: PromptType,
+        output_kind: RequestOutputKind,
+        max_tokens: int,
+        prompt_logprobs: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None) -> tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)
 
```
```diff
@@ -46,7 +48,8 @@ async def generate(engine: AsyncLLM,
         prompt_logprobs=prompt_logprobs)
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
-                                     sampling_params=sampling_params):
+                                     sampling_params=sampling_params,
+                                     data_parallel_rank=data_parallel_rank):
 
         num_tokens = len(out.outputs[0].token_ids)
         if output_kind == RequestOutputKind.DELTA:
```
```diff
@@ -89,8 +92,12 @@ async def test_load(output_kind: RequestOutputKind,
     for request_id in request_ids:
         tasks.append(
             asyncio.create_task(
-                generate(engine, request_id, prompt, output_kind,
-                         NUM_EXPECTED_TOKENS)))
+                generate(engine,
+                         request_id,
+                         prompt,
+                         output_kind,
+                         NUM_EXPECTED_TOKENS,
+                         data_parallel_rank=0)))
     # Confirm that we got all the EXPECTED tokens from the requests.
     done, pending = await asyncio.wait(tasks,
                                        return_when=asyncio.FIRST_EXCEPTION)
```
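Pinning every load-test request to rank 0 exercises the targeted-dispatch path under concurrency, and `asyncio.wait(..., return_when=FIRST_EXCEPTION)` makes a single failed request end the wait immediately. A self-contained sketch of that pattern (the `ok`/`boom` coroutines are illustrative, not vLLM code):

```python
import asyncio


async def ok():
    await asyncio.sleep(0.5)
    return "done"


async def boom():
    await asyncio.sleep(0.1)
    raise ValueError("bad DP rank")


async def main():
    tasks = [asyncio.create_task(ok()), asyncio.create_task(boom())]
    # Returns as soon as boom() raises; ok() is still pending.
    done, pending = await asyncio.wait(tasks,
                                       return_when=asyncio.FIRST_EXCEPTION)
    for t in pending:
        t.cancel()
    for t in done:
        if t.exception() is not None:
            print(f"first failure: {t.exception()}")


asyncio.run(main())
```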
```diff
@@ -494,6 +494,10 @@ class _AsyncLLMEngine(LLMEngine):
         if arrival_time is None:
             arrival_time = time.time()
 
+        if data_parallel_rank is not None:
+            raise ValueError("Targeting data_parallel_rank only supported "
+                             "in v1 client.")
+
         if (isinstance(prompt, dict)
                 and prompt.get("prompt_embeds", None) is not None
                 and not prompt.get("prompt_token_ids", None)):
```
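Worth noting: the v0 guard above is unconditional. The v0 engine rejects any explicit `data_parallel_rank`, in range or not, since targeting a specific DP rank is only meaningful for the v1 multi-instance DP client; failing fast beats silently ignoring the argument.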
```diff
@@ -1000,9 +1000,6 @@ class DPAsyncMPClient(AsyncMPClient):
     ) -> CoreEngine:
         if dp_rank is not None:
             # engines are already in rank order
-            if dp_rank < 0 or dp_rank >= len(self.core_engines):
-                raise ValueError(f"Requested DP rank {dp_rank} is out of "
-                                 f"range [0, {len(self.core_engines)})")
             return self.core_engines[dp_rank]
 
         if not self.lb_engines:
```
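The range check removed here is not lost: as the `Processor` hunk below shows, validation moves to request admission, so an out-of-range rank is rejected once, up front, before engine selection, and `get_core_engine` can simply index into the rank-ordered `core_engines` list.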
```diff
@@ -226,6 +226,12 @@ class Processor:
         if prompt_adapter_request is not None:
             raise ValueError("V1 does not support prompt_adapter_request.")
 
+        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
+        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
+                                                   data_parallel_size):
+            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
+                             f"is out of range [0, {data_parallel_size}).")
+
         if arrival_time is None:
             arrival_time = time.time()
```
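The check uses Python's chained comparison to express the half-open interval `[0, data_parallel_size)`. A standalone sketch of the same logic (the `validate_dp_rank` helper name is hypothetical, not part of vLLM):

```python
from typing import Optional


def validate_dp_rank(data_parallel_rank: Optional[int],
                     data_parallel_size: int) -> None:
    # None means "no targeting": the client is free to load-balance.
    # An explicit rank must fall in the half-open range [0, size).
    if data_parallel_rank is not None and not (
            0 <= data_parallel_rank < data_parallel_size):
        raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                         f"is out of range [0, {data_parallel_size}).")


validate_dp_rank(None, 2)  # ok: no rank requested
validate_dp_rank(1, 2)     # ok: in range
try:
    validate_dp_rank(2, 2)  # upper bound is exclusive
except ValueError as e:
    print(e)
```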