mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[BugFix] Fix KVConnector TP worker aggregation (#21473)
Signed-off-by: Nick Hill <nhill@redhat.com>
@@ -16,7 +16,8 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                          has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -342,19 +343,20 @@ class Worker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
+            if not has_kv_transfer_group():
+                return None
+
             # In case of PP with kv transfer, we need to pass through the
             # finished_sending and finished_recving buffers.
-            empty_output = EMPTY_MODEL_RUNNER_OUTPUT
+            new_output = EMPTY_MODEL_RUNNER_OUTPUT
             if output.finished_sending or output.finished_recving:
-                empty_output = copy.copy(empty_output)
-                empty_output.finished_sending = output.finished_sending
-                empty_output.finished_recving = output.finished_recving
-            output = empty_output
+                new_output = copy.copy(new_output)
+                new_output.finished_sending = output.finished_sending
+                new_output.finished_recving = output.finished_recving
+            output = new_output
 
         assert isinstance(output, ModelRunnerOutput)
-        # return output only from the driver worker
-        return output if self.is_driver_worker else None
+        return output
 
     def profile(self, is_start: bool = True):
         if self.profiler is None:
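
Why the driver-only return was removed: with a KV connector under tensor parallelism, each TP rank finishes its share of a KV-cache transfer independently, so the executor needs the finished_sending/finished_recving sets from every rank, not just the driver, before it can declare a request's transfer complete. Below is a minimal sketch of the kind of cross-rank aggregation this change enables, assuming a hypothetical KVTransferAggregator class and duck-typed per-rank outputs; the real executor-side aggregation logic is not part of this diff.

from collections import Counter

class KVTransferAggregator:
    """Hypothetical helper: a request's KV transfer counts as finished
    only once every TP rank has reported it (ranks may report the same
    request in different steps)."""

    def __init__(self, world_size: int):
        self.world_size = world_size
        self._send_counts: Counter = Counter()
        self._recv_counts: Counter = Counter()

    def update(self, worker_outputs) -> tuple[set, set]:
        """Fold in one step's outputs, one per TP rank (all non-None now
        that non-driver ranks no longer return None)."""
        finished_sending: set = set()
        finished_recving: set = set()
        for out in worker_outputs:
            for req_id in (out.finished_sending or ()):
                self._send_counts[req_id] += 1
                if self._send_counts[req_id] == self.world_size:
                    # All ranks have reported this request: it is done.
                    del self._send_counts[req_id]
                    finished_sending.add(req_id)
            for req_id in (out.finished_recving or ()):
                self._recv_counts[req_id] += 1
                if self._recv_counts[req_id] == self.world_size:
                    del self._recv_counts[req_id]
                    finished_recving.add(req_id)
        return finished_sending, finished_recving

With the old code, non-driver ranks returned None, so counts like these could never reach world_size and a transfer could be treated as complete while some ranks were still sending or receiving.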