[Bugfix] Fix sequence parallelism bug when pipeline parallelism is enabled (#24021)

Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
cascade
2025-09-15 21:32:32 -07:00
committed by GitHub
parent 759ef49b15
commit 17871983a2
6 changed files with 135 additions and 42 deletions

View File

@@ -235,7 +235,6 @@ def _compare_sp(
'level': 3,
'custom_ops': ["+rms_norm"],
'compile_sizes': [4, 8],
'splitting_ops': [],
'pass_config': {
'enable_sequence_parallelism': True,
'enable_fusion': enable_fusion,
@@ -251,6 +250,8 @@ def _compare_sp(
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",

View File

@@ -663,14 +663,29 @@ class GroupCoordinator:
tensor_dict: dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the destination rank.
all_gather_group: The group for the all-gather operation. If provided,
an optimization is enabled where each rank in the group sends a
slice of a tensor and the receiver reconstructs it using an
all-gather, which can improve performance. This is typically the
tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use
the all-gather optimization, which is only effective when
`all_gather_group` is provided. By default, this optimization is
on for any tensor whose size is divisible by the
`all_gather_group`'s world size. However, it should be disabled
for tensors that are not fully replicated across the group (e.g.,
the residual tensor when sequence parallelism is enabled). This
dictionary allows overriding the default behavior on a per-tensor
basis.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
@@ -699,14 +714,23 @@ class GroupCoordinator:
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
tensor_keys = [
k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
]
assert len(tensor_keys) == len(tensor_list)
for key, tensor in zip(tensor_keys, tensor_list):
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
if (all_gather_group is not None
and tensor.numel() % all_gather_size == 0):
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
use_all_gather = all_gather_tensors.get(key, use_all_gather) \
if all_gather_tensors else use_all_gather
if use_all_gather:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
@@ -725,14 +749,29 @@ class GroupCoordinator:
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: Optional[dict[str, bool]] = None,
) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
all_gather_group: The group for the all-gather operation. If provided,
an optimization is enabled where each rank in the group sends a
slice of a tensor and the receiver reconstructs it using an
all-gather, which can improve performance. This is typically the
tensor-parallel group.
all_gather_tensors: A dictionary to specify which tensors should use
the all-gather optimization, which is only effective when
`all_gather_group` is provided. By default, this optimization is
on for any tensor whose size is divisible by the
`all_gather_group`'s world size. However, it should be disabled
for tensors that are not fully replicated across the group (e.g.,
the residual tensor when sequence parallelism is enabled). This
dictionary allows overriding the default behavior on a per-tensor
basis.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
@@ -766,6 +805,8 @@ class GroupCoordinator:
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
use_all_gather = all_gather_tensors.get(key, use_all_gather) \
if all_gather_tensors else use_all_gather
if use_all_gather:
orig_shape = tensor.shape
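
A hedged usage sketch of the new `all_gather_tensors` override, modeled on the worker changes later in this commit; the helper functions and their arguments are illustrative, not new API, and assume vLLM's distributed groups are already initialized.

from vllm.distributed import get_pp_group, get_tp_group

def forward_intermediates(tensors: dict, residual_is_scattered: bool) -> None:
    # Disable the send-allgather optimization for "residual" when it is already
    # sharded across TP ranks (sequence parallelism): the optimization assumes
    # every rank holds a full replica of the tensor it slices.
    all_gather_tensors = {"residual": not residual_is_scattered}
    get_pp_group().send_tensor_dict(
        tensors,
        all_gather_group=get_tp_group(),
        all_gather_tensors=all_gather_tensors,
    )

def receive_intermediates(residual_is_scattered: bool) -> dict:
    # The receiving PP rank passes the same override so it does not try to
    # reconstruct the residual with an extra all-gather.
    return get_pp_group().recv_tensor_dict(
        all_gather_group=get_tp_group(),
        all_gather_tensors={"residual": not residual_is_scattered},
    )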

View File

@@ -19,6 +19,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import (Worker,
init_worker_distributed_environment)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
logger = init_logger(__name__)
@@ -107,18 +108,29 @@ class CPUWorker(Worker):
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self.model_runner._get_num_input_tokens(
num_scheduled_tokens)
all_gather_tensors = {
"residual":
not is_residual_scattered_for_sp(self.vllm_config,
num_input_tokens)
}
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
get_pp_group().send_tensor_dict(
output.tensors,
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors)
return None
assert isinstance(output, ModelRunnerOutput)
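
To make the override concrete, a small worked example of what `all_gather_tensors` evaluates to; this is a pure-Python mirror of the decision in is_residual_scattered_for_sp (added in vllm/v1/worker/utils.py below), and the config values are assumptions.

def residual_scattered(enable_sp: bool, tp: int, compile_sizes: list[int],
                       num_input_tokens: int) -> bool:
    if not enable_sp or tp == 1:
        return False
    assert num_input_tokens % tp == 0  # tokens were padded to a multiple of tp
    return num_input_tokens in compile_sizes

# SP on, tp=2, static compile sizes [4, 8], padded batch of 8 tokens:
scattered = residual_scattered(True, 2, [4, 8], 8)
all_gather_tensors = {"residual": not scattered}
assert all_gather_tensors == {"residual": False}  # opt residual out of the all-gather path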

View File

@@ -88,6 +88,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from .utils import (AttentionGroup, MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
@@ -1633,21 +1634,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert self.intermediate_tensors is not None
tp = self.vllm_config.parallel_config.tensor_parallel_size
enabled_sp = self.compilation_config.pass_config. \
enable_sequence_parallelism
if enabled_sp:
# When sequence parallelism is enabled, we always pad num_tokens
# to be a multiple of tensor_parallel_size (tp) earlier
assert num_tokens % tp == 0
is_residual_scattered = tp > 1 and enabled_sp \
and num_tokens % tp == 0
is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens)
# When sequence parallelism is enabled, the "residual" tensor is sharded
# across tensor parallel ranks, so each rank only needs its own slice.
if sync_self:
assert intermediate_tensors is not None
for k, v in intermediate_tensors.items():
is_scattered = k == "residual" and is_residual_scattered
is_scattered = k == "residual" and is_rs
copy_len = num_tokens // tp if is_scattered else \
num_tokens
self.intermediate_tensors[k][:copy_len].copy_(
@@ -1655,8 +1649,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return IntermediateTensors({
k:
v[:num_tokens // tp]
if k == "residual" and is_residual_scattered else v[:num_tokens]
v[:num_tokens //
tp] if k == "residual" and is_rs else v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
@@ -1741,6 +1735,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pooler_output=pooler_output,
)
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
and hasattr(self, "cudagraph_batch_sizes")
and self.cudagraph_batch_sizes
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use CUDA graphs.
# Add padding to the batch size.
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
# Eager mode.
# Pad tokens to a multiple of tensor_parallel_size when
# collective fusion for SP is enabled.
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if (self.compilation_config.pass_config.enable_sequence_parallelism
and tp_size > 1):
return round_up(num_scheduled_tokens, tp_size)
return num_scheduled_tokens
def _preprocess(
self,
scheduler_output: "SchedulerOutput",
@@ -1750,24 +1763,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Optional[IntermediateTensors], dict[str, Any]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
@@ -2108,8 +2104,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert not self.is_pooling_model
if not get_pp_group().is_last_rank:
all_gather_tensors = {
"residual":
not is_residual_scattered_for_sp(
self.vllm_config, num_input_tokens)
}
get_pp_group().send_tensor_dict(
hidden_states.tensors, all_gather_group=get_tp_group())
hidden_states.tensors,
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors)
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
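
A standalone sketch of the eager-mode padding rule applied by the new `_get_num_input_tokens` when CUDA graphs are not used; `round_up` here mirrors `vllm.utils.round_up`, and the numbers are illustrative.

def round_up(x: int, multiple: int) -> int:
    # Same arithmetic as vllm.utils.round_up.
    return ((x + multiple - 1) // multiple) * multiple

def eager_num_input_tokens(num_scheduled_tokens: int, tp_size: int,
                           enable_sequence_parallelism: bool) -> int:
    # With SP enabled and tp_size > 1, pad to a multiple of tp_size so the
    # residual can be scattered evenly across TP ranks.
    if enable_sequence_parallelism and tp_size > 1:
        return round_up(num_scheduled_tokens, tp_size)
    return num_scheduled_tokens

assert eager_num_input_tokens(7, tp_size=4, enable_sequence_parallelism=True) == 8
assert eager_num_input_tokens(7, tp_size=4, enable_sequence_parallelism=False) == 7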

View File

@@ -32,6 +32,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
@@ -428,10 +429,19 @@ class Worker(WorkerBase):
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self.model_runner._get_num_input_tokens(
num_scheduled_tokens)
all_gather_tensors = {
"residual":
not is_residual_scattered_for_sp(self.vllm_config,
num_input_tokens)
}
if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
@@ -444,7 +454,8 @@ class Worker(WorkerBase):
"external_launcher") and not get_pp_group().is_last_rank
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors)
kv_connector_output = output.kv_connector_output
if not kv_connector_output:

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
@@ -288,3 +288,28 @@ def bind_kv_cache(
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
def is_residual_scattered_for_sp(vllm_config: VllmConfig,
num_input_tokens: int) -> bool:
"""Check if the residual tensor is scattered for sequence parallelism.
The residual tensor is scattered across tensor parallel ranks when both
sequence parallelism and tensor parallelism are enabled and the number of
input tokens is one of the compilation sizes.
"""
if not vllm_config.compilation_config.pass_config.\
enable_sequence_parallelism:
return False
tp = vllm_config.parallel_config.tensor_parallel_size
if tp == 1:
return False
# When sequence parallelism is enabled, we always pad num_input_tokens
# to be a multiple of tensor_parallel_size (tp) earlier.
assert num_input_tokens % tp == 0
# Currently, SP is only enabled for static size fx graphs.
return (num_input_tokens in vllm_config.compilation_config.compile_sizes)
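
To illustrate why the scattered residual must opt out of the send-allgather path: a toy shape check (sizes assumed, not from this PR) showing that under SP each rank's residual is already 1/tp of the full tensor, so slicing it again before the PP send cannot be undone by the receiver's all-gather.

import torch

# Illustrative shapes only (assumed): tp=2, 8 padded tokens, hidden size 16.
tp = 2
hidden_states = torch.zeros(8, 16)       # fully replicated on every TP rank
residual = torch.zeros(8 // tp, 16)      # under SP each rank holds its own 4x16 slice

# The send-allgather path has each TP rank send numel()/tp elements of a
# *replicated* tensor; the receiving TP group all-gathers them back together.
assert hidden_states.numel() % tp == 0

# The scattered residual also satisfies the divisibility check, so without the
# per-tensor override it would be sliced again. But its contents differ per
# rank, so the receiver's all-gather would stitch together pieces of different
# ranks' slices: right shape, wrong values.
assert residual.numel() % tp == 0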