[Core] Raise when non-multi-instance DP clients target a DP rank (#19227)

Signed-off-by: Jon Swenson <jmswen@gmail.com>
This commit is contained in:
jmswen
2025-06-06 04:03:01 -07:00
committed by GitHub
parent 7661e92ef8
commit 7353492a47
6 changed files with 77 additions and 12 deletions

View File

@ -384,3 +384,25 @@ async def test_delayed_generator(async_engine, stop):
assert final_output is not None
assert len(final_output.outputs[0].token_ids) == 10
assert final_output.finished
@pytest.mark.asyncio(scope="module")
async def test_invalid_argument(async_engine):
    """Passing ``data_parallel_rank`` to the v0 async engine must raise.

    Targeting a specific DP rank is only supported by the v1
    multi-instance DP client, so the legacy engine is expected to
    reject the argument with ``ValueError``.
    """
    config = await async_engine.get_scheduler_config()
    if config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    params = SamplingParams(temperature=0, min_tokens=10, max_tokens=10)

    # The request should fail before any token is produced.
    with pytest.raises(ValueError):
        async for _ in async_engine.generate("test",
                                             params,
                                             request_id=uid(),
                                             data_parallel_rank=0):
            pass

View File

@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch):
assert len(engine.stat_loggers) == 1
assert len(engine.stat_loggers[0]) == 1
engine.stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
    """An in-range ``data_parallel_rank`` is accepted; out-of-range raises.

    Runs against a v1 AsyncLLM (dp size 1 here, so rank 0 is valid and
    rank 1 is out of range [0, 1)).
    """
    with monkeypatch.context() as ctx:
        ctx.setenv("VLLM_USE_V1", "1")
        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        try:
            params = SamplingParams(
                max_tokens=100,
                output_kind=RequestOutputKind.DELTA,
                temperature=1.0,
                seed=33,
            )

            # A rank inside the valid range should generate normally.
            async for _ in engine.generate(request_id="request-34",
                                           prompt=TEXT_PROMPT,
                                           sampling_params=params,
                                           data_parallel_rank=0):
                pass

            # A rank outside the valid range must be rejected.
            with pytest.raises(ValueError):
                async for _ in engine.generate(request_id="request-35",
                                               prompt=TEXT_PROMPT,
                                               sampling_params=params,
                                               data_parallel_rank=1):
                    pass
        finally:
            # Mirrors the original ExitStack callback: shutdown runs
            # before the monkeypatch context is undone.
            engine.shutdown()

View File

@ -29,12 +29,14 @@ if not current_platform.supports_v1(engine_args.create_model_config()):
allow_module_level=True)
async def generate(engine: AsyncLLM,
request_id: str,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
async def generate(
engine: AsyncLLM,
request_id: str,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None,
data_parallel_rank: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)
@ -46,7 +48,8 @@ async def generate(engine: AsyncLLM,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
sampling_params=sampling_params,
data_parallel_rank=data_parallel_rank):
num_tokens = len(out.outputs[0].token_ids)
if output_kind == RequestOutputKind.DELTA:
@ -89,8 +92,12 @@ async def test_load(output_kind: RequestOutputKind,
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))
generate(engine,
request_id,
prompt,
output_kind,
NUM_EXPECTED_TOKENS,
data_parallel_rank=0)))
# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)

View File

@ -494,6 +494,10 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None:
arrival_time = time.time()
if data_parallel_rank is not None:
raise ValueError("Targeting data_parallel_rank only supported "
"in v1 client.")
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):

View File

@ -1000,9 +1000,6 @@ class DPAsyncMPClient(AsyncMPClient):
) -> CoreEngine:
if dp_rank is not None:
# engines are already in rank order
if dp_rank < 0 or dp_rank >= len(self.core_engines):
raise ValueError(f"Requested DP rank {dp_rank} is out of "
f"range [0, {len(self.core_engines)})")
return self.core_engines[dp_rank]
if not self.lb_engines:

View File

@ -226,6 +226,12 @@ class Processor:
if prompt_adapter_request is not None:
raise ValueError("V1 does not support prompt_adapter_request.")
data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
data_parallel_size):
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
f"is out of range [0, {data_parallel_size}).")
if arrival_time is None:
arrival_time = time.time()