[V1][PP] Support pp with ray backend in V1 (#1800)

### What this PR does / why we need it?
Support pipeline parallelism with the Ray distributed executor backend in the V1 engine.

Fixes #1751

### Does this PR introduce _any_ user-facing change?
Users can now specify `ray` as the distributed executor backend when running inference with pipeline parallelism.
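
For illustration, a minimal offline-inference sketch (the model name, parallel sizes, and prompt are placeholders; it assumes an Ascend environment with vllm-ascend and Ray installed, and tp=2 x pp=2 needs 4 devices):

```python
from vllm import LLM, SamplingParams

# Placeholder settings: 2-way tensor parallel x 2-way pipeline parallel
# (4 devices total), executed through Ray instead of the default
# multiprocessing executor.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    distributed_executor_backend="ray",
    enforce_eager=True,
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=64))
for out in outputs:
    print(out.outputs[0].text)
```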

### How was this patch tested?
CI passed with the newly added test.


- vLLM version: v0.9.2
- vLLM main: 32142b3c62

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-07-23 14:52:52 +08:00
Committed by: GitHub
Commit: 3aa3b46bfe (parent 9a3bdf2162)
5 changed files with 32 additions and 18 deletions

@@ -6,7 +6,6 @@ pytest >= 6.0
 pytest-asyncio
 pytest-mock
 lm-eval
-ray
 types-jsonschema
 xgrammar
 zmq
@@ -14,3 +13,5 @@ types-psutil
 pytest-cov
 regex
 sentence_transformers
+ray>=2.47.1
+protobuf==4.25.6

@@ -24,6 +24,7 @@ MODELS = [
 TENSOR_PARALLELS = [2]
 PIPELINE_PARALLELS = [2]
+DIST_EXECUTOR_BACKEND = ["mp", "ray"]
 
 prompts = [
     "Hello, my name is",
@@ -34,10 +35,13 @@ prompts = [
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
 @pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
-def test_models(model: str, tp_size: int, pp_size: int) -> None:
+@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
+def test_models(model: str, tp_size: int, pp_size: int,
+                distributed_executor_backend: str) -> None:
     with VllmRunner(model,
                     tensor_parallel_size=tp_size,
                     pipeline_parallel_size=pp_size,
+                    distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(prompts, 64)
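
As a convenience sketch, the Ray-backed cases of the expanded matrix can be selected by test-id substring; the test module path is not shown in this diff, so name-based filtering from the repository root is assumed:

```python
import pytest

# Run only the Ray parametrizations of the new test; "-k" matches the
# "ray" parameter id added by the parametrize call above.
raise SystemExit(pytest.main(["-q", "-k", "test_models and ray"]))
```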

@@ -400,19 +400,13 @@ class TestAscendAttentionBackendImpl(TestBase):
         layer = self.layer_no_quant
         mock_vanilla_prefill.return_value = MagicMock()
 
-        def mock_tensor(data, device=None, **kwargs):
-            if device == "npu":
-                return metadata.attn_mask
-            return torch.tensor(data, **kwargs)
-
-        with patch("torch.tensor", side_effect=mock_tensor):
-            output = self.impl_192.forward(layer,
-                                           query,
-                                           key,
-                                           value,
-                                           kv_cache,
-                                           metadata,
-                                           trace_flag=False)
+        output = self.impl_192.forward(layer,
+                                       query,
+                                       key,
+                                       value,
+                                       kv_cache,
+                                       metadata,
+                                       trace_flag=False)
 
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)

@@ -396,8 +396,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
             if self.head_size == 192:
                 cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
                 cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-                cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
-                cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
+                cu_seqlen_q = torch.tensor(cu_seqlen_q,
+                                           device=query.device)
+                cu_seqlen_k = torch.tensor(cu_seqlen_k,
+                                           device=query.device)
                 cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
                 cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
                 max_seqlen_q = torch.max(attn_metadata.query_lens)
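
The hunk above replaces the hard-coded `"npu"` device with `query.device`, so the cumulative-sequence-length tensors are created wherever the attention inputs already live. A self-contained sketch of the same cu_seqlen construction, with illustrative lengths on CPU standing in for `attn_metadata.query_lens` / `attn_metadata.seq_lens`:

```python
import torch

# Illustrative per-request lengths; in the backend these come from
# attn_metadata.query_lens and attn_metadata.seq_lens.
query_lens = torch.tensor([3, 5, 2])
seq_lens = torch.tensor([7, 9, 4])

# Prepend 0 and take the running sum so cu_seqlen[i] is the start offset
# of request i in the packed query/key layout, on the inputs' device.
cu_seqlen_q = torch.cumsum(
    torch.tensor([0] + query_lens.tolist(), device=query_lens.device), dim=0)
cu_seqlen_k = torch.cumsum(
    torch.tensor([0] + seq_lens.tolist(), device=seq_lens.device), dim=0)

print(cu_seqlen_q)  # tensor([ 0,  3,  8, 10])
print(cu_seqlen_k)  # tensor([ 0,  7, 16, 20])
```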

@@ -233,7 +233,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                         2048,
                                                         dtype=torch.bool),
-                                             diagonal=1).to("npu")
+                                             diagonal=1).to(self.device)
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
@@ -1120,6 +1120,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             input_ids = self.input_ids[:padded_batch_size]
             positions = self.positions[:padded_batch_size]
 
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            assert intermediate_tensors is not None
+            assert self.intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                self.intermediate_tensors[k][:num_input_tokens].copy_(
+                    v[:num_input_tokens], non_blocking=True)
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
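
For context, the block added above is the V1 pipeline-parallel hand-off: the first rank starts from token ids, while later ranks copy the activations received from the previous stage into persistent buffers and feed the model a view truncated to the live tokens. A simplified, self-contained analogue with illustrative names and shapes (`staging` and `received` stand in for `self.intermediate_tensors` and the incoming `intermediate_tensors`):

```python
import torch

# Hypothetical sizes: the persistent buffer covers the maximum padded batch,
# while the payload from the previous stage covers this step's tokens.
max_num_tokens, hidden_size = 16, 8
num_input_tokens = 10

# Persistent per-rank staging buffer (analogue of self.intermediate_tensors).
staging = {"hidden_states": torch.zeros(max_num_tokens, hidden_size)}

# Payload received from the previous pipeline stage.
received = {"hidden_states": torch.randn(num_input_tokens, hidden_size)}

# Copy the received activations into the front of the staging buffer,
# then hand the model a view truncated to the live tokens, mirroring the
# IntermediateTensors({k: v[:num_input_tokens] ...}) construction above.
for k, v in received.items():
    staging[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True)
model_inputs = {k: v[:num_input_tokens] for k, v in staging.items()}

assert model_inputs["hidden_states"].shape == (num_input_tokens, hidden_size)
```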