Mirror of https://github.com/vllm-project/vllm.git
[Bugfix] Fix sequence parallelism bug when enable pipeline parallelism (#24021)
Signed-off-by: cascade812 <cascade812@outlook.com>
@@ -235,7 +235,6 @@ def _compare_sp(
         'level': 3,
         'custom_ops': ["+rms_norm"],
         'compile_sizes': [4, 8],
-        'splitting_ops': [],
         'pass_config': {
             'enable_sequence_parallelism': True,
             'enable_fusion': enable_fusion,
@@ -251,6 +250,8 @@ _compare_sp(
         *common_args,
         "--tensor-parallel-size",
         str(tp_size),
+        "--pipeline-parallel-size",
+        str(pp_size),
         "--distributed-executor-backend",
         distributed_backend,
         "--compilation_config",
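For context, `--compilation_config` receives the dictionary above serialized as JSON. A rough standalone sketch of the value the test builds (keys and sizes are taken from the hunk; `enable_fusion` is parameterized by the test, so the value below is only an illustrative choice):

import json

# Compilation config assembled by _compare_sp; the SP pass plus the static
# compile sizes are what exercise residual scattering, now combined with
# the newly added --pipeline-parallel-size argument.
compilation_config = {
    "level": 3,
    "custom_ops": ["+rms_norm"],
    "compile_sizes": [4, 8],
    "pass_config": {
        "enable_sequence_parallelism": True,
        "enable_fusion": False,  # illustrative; the test sweeps this flag
    },
}

print(json.dumps(compilation_config))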
@@ -663,14 +663,29 @@ class GroupCoordinator:
         tensor_dict: dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict

         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
@@ -699,14 +714,23 @@ class GroupCoordinator:
         # `send_object_list` has serialization & deserialization,
         # all happening on CPU. Therefore, we can use the CPU group.
         self.send_object(metadata_list, dst=dst)
-        for tensor in tensor_list:
+
+        tensor_keys = [
+            k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
+        ]
+        assert len(tensor_keys) == len(tensor_list)
+
+        for key, tensor in zip(tensor_keys, tensor_list):
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue

             # send-allgather: send only a slice, then do allgather.
-            if (all_gather_group is not None
-                    and tensor.numel() % all_gather_size == 0):
+            use_all_gather = (all_gather_group is not None
+                              and tensor.numel() % all_gather_size == 0)
+            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                if all_gather_tensors else use_all_gather
+            if use_all_gather:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

             if tensor.is_cpu:
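The slice-then-all-gather trick above assumes the tensor is identical on every rank of `all_gather_group`. A minimal single-process sketch of the idea (plain torch, no torch.distributed; the rank loop stands in for actual processes):

import torch

world_size = 2
full = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # replicated on all ranks

# Sender side: each rank keeps only its own slice before sending.
slices = [full.reshape(world_size, -1)[rank] for rank in range(world_size)]

# Receiver side: "all-gather" the slices and restore the original shape.
reconstructed = torch.cat(slices).reshape(full.shape)
assert torch.equal(reconstructed, full)

# If a tensor is NOT replicated (e.g. a residual already scattered by sequence
# parallelism), the slices come from different data and the reconstruction is
# wrong, which is exactly what all_gather_tensors={"residual": False} prevents.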
@@ -725,14 +749,29 @@
         self,
         src: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None

         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
@@ -766,6 +805,8 @@ class GroupCoordinator:
             # send-allgather: send only a slice, then do allgather.
             use_all_gather = (all_gather_group is not None
                               and tensor.numel() % all_gather_size == 0)
+            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                if all_gather_tensors else use_all_gather

             if use_all_gather:
                 orig_shape = tensor.shape
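Taken on its own, the per-tensor override used in both directions reduces to a small pure function; a sketch (the helper name is mine, not part of the diff):

from typing import Optional


def should_all_gather(key: str, numel: int, all_gather_size: int,
                      has_all_gather_group: bool,
                      all_gather_tensors: Optional[dict[str, bool]]) -> bool:
    # Default: use the slice + all-gather path whenever a group is provided
    # and the tensor divides evenly across it.
    use_all_gather = has_all_gather_group and numel % all_gather_size == 0
    # Per-tensor override, e.g. {"residual": False} once SP has scattered it.
    if all_gather_tensors:
        use_all_gather = all_gather_tensors.get(key, use_all_gather)
    return use_all_gather


assert should_all_gather("hidden_states", 4096, 4, True, None)
assert not should_all_gather("residual", 4096, 4, True, {"residual": False})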
@@ -19,6 +19,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
+from vllm.v1.worker.utils import is_residual_scattered_for_sp

 logger = init_logger(__name__)
@@ -107,18 +108,29 @@ class CPUWorker(Worker):
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         intermediate_tensors = None
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_input_tokens = self.model_runner._get_num_input_tokens(
+            num_scheduled_tokens)
+        all_gather_tensors = {
+            "residual":
+            not is_residual_scattered_for_sp(self.vllm_config,
+                                             num_input_tokens)
+        }
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group()))
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors))

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)

         if not get_pp_group().is_last_rank:
             assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
+            get_pp_group().send_tensor_dict(
+                output.tensors,
+                all_gather_group=get_tp_group(),
+                all_gather_tensors=all_gather_tensors)
             return None

         assert isinstance(output, ModelRunnerOutput)
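Note that the worker computes `all_gather_tensors` from the padded `num_input_tokens` returned by `_get_num_input_tokens`, not from the raw scheduled-token count: the padding is what guarantees divisibility by the TP size and determines whether the count matches one of the compile sizes that `is_residual_scattered_for_sp` checks.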
@@ -88,6 +88,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.utils import is_residual_scattered_for_sp

 from .utils import (AttentionGroup, MultiModalBudget,
                     add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
@@ -1633,21 +1634,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert self.intermediate_tensors is not None

         tp = self.vllm_config.parallel_config.tensor_parallel_size
-        enabled_sp = self.compilation_config.pass_config. \
-            enable_sequence_parallelism
-        if enabled_sp:
-            # When sequence parallelism is enabled, we always pad num_tokens
-            # to be a multiple of tensor_parallel_size (tp) earlier
-            assert num_tokens % tp == 0
-        is_residual_scattered = tp > 1 and enabled_sp \
-            and num_tokens % tp == 0
+        is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens)

         # When sequence parallelism is enabled, the "residual" tensor is sharded
         # across tensor parallel ranks, so each rank only needs its own slice.
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
-                is_scattered = k == "residual" and is_residual_scattered
+                is_scattered = k == "residual" and is_rs
                 copy_len = num_tokens // tp if is_scattered else \
                     num_tokens
                 self.intermediate_tensors[k][:copy_len].copy_(
@@ -1655,8 +1649,8 @@

         return IntermediateTensors({
             k:
-            v[:num_tokens // tp]
-            if k == "residual" and is_residual_scattered else v[:num_tokens]
+            v[:num_tokens //
+              tp] if k == "residual" and is_rs else v[:num_tokens]
             for k, v in self.intermediate_tensors.items()
         })
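A toy, single-process illustration of the copy lengths used above (the buffer names follow the diff; the sizes are made up): with SP active, each rank's "residual" slice holds only num_tokens // tp rows, while other intermediate tensors hold the full num_tokens rows.

import torch

num_tokens, tp, hidden = 8, 2, 4
# Preallocated staging buffers, always sized for the full token count.
buffers = {
    "hidden_states": torch.zeros(num_tokens, hidden),
    "residual": torch.zeros(num_tokens, hidden),
}
# What arrives from the previous PP stage on one TP rank.
incoming = {
    "hidden_states": torch.randn(num_tokens, hidden),
    "residual": torch.randn(num_tokens // tp, hidden),  # already scattered by SP
}
for k, v in incoming.items():
    copy_len = num_tokens // tp if k == "residual" else num_tokens
    buffers[k][:copy_len].copy_(v[:copy_len])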
@@ -1741,6 +1735,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooler_output=pooler_output,
         )

+    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
+                and hasattr(self, "cudagraph_batch_sizes")
+                and self.cudagraph_batch_sizes
+                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
+            # Use CUDA graphs.
+            # Add padding to the batch size.
+            return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
+
+        # Eager mode.
+        # Pad tokens to multiple of tensor_parallel_size when
+        # enabled collective fusion for SP
+        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+        if (self.compilation_config.pass_config.enable_sequence_parallelism
+                and tp_size > 1):
+            return round_up(num_scheduled_tokens, tp_size)
+        return num_scheduled_tokens
+
     def _preprocess(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1750,24 +1763,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                Optional[IntermediateTensors], dict[str, Any]]:

         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
-                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
-            # Use CUDA graphs.
-            # Add padding to the batch size.
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                num_scheduled_tokens)
-        else:
-            # Eager mode.
-            # Pad tokens to multiple of tensor_parallel_size when
-            # enabled collective fusion for SP
-            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-            if self.compilation_config.pass_config. \
-                enable_sequence_parallelism and tp_size > 1:
-                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
-            else:
-                num_input_tokens = num_scheduled_tokens
-
+        num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
         # Padding for DP
         num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
         num_input_tokens += num_pad
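The padding rule that used to live inline in `_preprocess` is now shared via `_get_num_input_tokens`, so the workers can compute the same padded count. A sketch of the eager-mode branch, with a local `round_up` standing in for vllm's utility of the same name (assumed semantics):

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

tp_size = 4
for num_scheduled_tokens in (1, 4, 5, 9):
    # With the SP pass enabled and tp_size > 1, the input is padded like this.
    print(num_scheduled_tokens, "->", round_up(num_scheduled_tokens, tp_size))
# prints: 1 -> 4, 4 -> 4, 5 -> 8, 9 -> 12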
@@ -2108,8 +2104,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert not self.is_pooling_model

         if not get_pp_group().is_last_rank:
+            all_gather_tensors = {
+                "residual":
+                not is_residual_scattered_for_sp(
+                    self.vllm_config, num_input_tokens)
+            }
             get_pp_group().send_tensor_dict(
-                hidden_states.tensors, all_gather_group=get_tp_group())
+                hidden_states.tensors,
+                all_gather_group=get_tp_group(),
+                all_gather_tensors=all_gather_tensors)
             logits = None
         else:
             sample_hidden_states = hidden_states[logits_indices]
@@ -32,6 +32,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import is_residual_scattered_for_sp
 from vllm.v1.worker.worker_base import WorkerBase

 logger = init_logger(__name__)
@@ -428,10 +429,19 @@ class Worker(WorkerBase):
     ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_input_tokens = self.model_runner._get_num_input_tokens(
+            num_scheduled_tokens)
+        all_gather_tensors = {
+            "residual":
+            not is_residual_scattered_for_sp(self.vllm_config,
+                                             num_input_tokens)
+        }
         if forward_pass and not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group()))
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors))

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
@@ -444,7 +454,8 @@
                 "external_launcher") and not get_pp_group().is_last_rank

             get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
+                                            all_gather_group=get_tp_group(),
+                                            all_gather_tensors=all_gather_tensors)

             kv_connector_output = output.kv_connector_output
             if not kv_connector_output:
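Both directions of the pipeline-parallel exchange use the same `all_gather_tensors` mapping, computed independently on each rank from the same padded token count. This symmetry matters: the receiver only reconstructs a tensor with an all-gather if the sender actually sent a slice, so the sending and receiving stages must agree on which tensors are opted out.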
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
 import torch

 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import ModelConfig, SchedulerConfig
+from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.cache import processor_only_cache_from_config
@@ -288,3 +288,28 @@ def bind_kv_cache(
     for layer_name, kv_cache in kv_caches.items():
         # NOTE: Use list because of v0 PP virtual engine.
         forward_context[layer_name].kv_cache = [kv_cache]
+
+
+def is_residual_scattered_for_sp(vllm_config: VllmConfig,
+                                 num_input_tokens: int) -> bool:
+    """Check if the residual tensor is scattered for sequence parallelism.
+
+    The residual tensor is scattered across tensor parallel ranks when sequence
+    parallelism and tensor parallelism is enabled, and the number of
+    input tokens is one of the compilation sizes.
+    """
+    if not vllm_config.compilation_config.pass_config.\
+        enable_sequence_parallelism:
+        return False
+
+    tp = vllm_config.parallel_config.tensor_parallel_size
+
+    if tp == 1:
+        return False
+
+    # When sequence parallelism is enabled, we always pad num_input_tokens
+    # to be a multiple of tensor_parallel_size (tp) earlier.
+    assert num_input_tokens % tp == 0
+
+    # Currently, SP is only enabled for static size fx graphs.
+    return (num_input_tokens in vllm_config.compilation_config.compile_sizes)
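Restated as a pure function of the relevant config fields (the real helper reads them from `VllmConfig`; this standalone sketch is only for eyeballing cases, e.g. with the test's compile_sizes=[4, 8] and tp=2):

def residual_scattered_for_sp(enable_sp: bool, tp: int,
                              compile_sizes: list[int],
                              num_input_tokens: int) -> bool:
    if not enable_sp or tp == 1:
        return False
    # Token counts were already padded to a multiple of tp upstream.
    assert num_input_tokens % tp == 0
    return num_input_tokens in compile_sizes

assert residual_scattered_for_sp(True, 2, [4, 8], 8) is True
assert residual_scattered_for_sp(True, 2, [4, 8], 16) is False
assert residual_scattered_for_sp(False, 2, [4, 8], 8) is False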