mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[V1][PP] Support pp with ray backend in V1 (#1800)
### What this PR does / why we need it?
Support pipeline parallel with ray backend in V1Engine.
Fixes #1751
### Does this PR introduce _any_ user-facing change?
Users could specify ray as distributed backend when inferencing with pp
### How was this patch tested?
CI passed with new added test.
- vLLM version: v0.9.2
- vLLM main:
32142b3c62
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@ -6,7 +6,6 @@ pytest >= 6.0
|
||||
pytest-asyncio
|
||||
pytest-mock
|
||||
lm-eval
|
||||
ray
|
||||
types-jsonschema
|
||||
xgrammar
|
||||
zmq
|
||||
@ -14,3 +13,5 @@ types-psutil
|
||||
pytest-cov
|
||||
regex
|
||||
sentence_transformers
|
||||
ray>=2.47.1
|
||||
protobuf==4.25.6
|
||||
|
@ -24,6 +24,7 @@ MODELS = [
|
||||
|
||||
TENSOR_PARALLELS = [2]
|
||||
PIPELINE_PARALLELS = [2]
|
||||
DIST_EXECUTOR_BACKEND = ["mp", "ray"]
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@ -34,10 +35,13 @@ prompts = [
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
|
||||
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
|
||||
def test_models(model: str, tp_size: int, pp_size: int) -> None:
|
||||
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
|
||||
def test_models(model: str, tp_size: int, pp_size: int,
|
||||
distributed_executor_backend: str) -> None:
|
||||
with VllmRunner(model,
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.7) as vllm_model:
|
||||
vllm_model.generate_greedy(prompts, 64)
|
||||
|
@ -400,19 +400,13 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
layer = self.layer_no_quant
|
||||
mock_vanilla_prefill.return_value = MagicMock()
|
||||
|
||||
def mock_tensor(data, device=None, **kwargs):
|
||||
if device == "npu":
|
||||
return metadata.attn_mask
|
||||
return torch.tensor(data, **kwargs)
|
||||
|
||||
with patch("torch.tensor", side_effect=mock_tensor):
|
||||
output = self.impl_192.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
output = self.impl_192.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_vanilla_prefill.assert_called_once()
|
||||
assert output.shape == (10, 8 * 192)
|
||||
|
@ -396,8 +396,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
if self.head_size == 192:
|
||||
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
|
||||
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
|
||||
cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
|
||||
cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
|
||||
cu_seqlen_q = torch.tensor(cu_seqlen_q,
|
||||
device=query.device)
|
||||
cu_seqlen_k = torch.tensor(cu_seqlen_k,
|
||||
device=query.device)
|
||||
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
|
||||
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
|
||||
max_seqlen_q = torch.max(attn_metadata.query_lens)
|
||||
|
@ -233,7 +233,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
||||
2048,
|
||||
dtype=torch.bool),
|
||||
diagonal=1).to("npu")
|
||||
diagonal=1).to(self.device)
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.speculative_config.method == "ngram":
|
||||
self.drafter = NgramProposer(self.vllm_config)
|
||||
@ -1120,6 +1120,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_ids = self.input_ids[:padded_batch_size]
|
||||
positions = self.positions[:padded_batch_size]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
assert self.intermediate_tensors is not None
|
||||
for k, v in intermediate_tensors.items():
|
||||
self.intermediate_tensors[k][:num_input_tokens].copy_(
|
||||
v[:num_input_tokens], non_blocking=True)
|
||||
intermediate_tensors = IntermediateTensors({
|
||||
k: v[:num_input_tokens]
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
# Run forward pass
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
|
Reference in New Issue
Block a user