Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[FEAT] [Performance] Enable DP for ViT in Qwen2.5VL (#22742)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import math
import mimetypes
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
@@ -20,6 +21,8 @@ from vllm.distributed.parallel_state import (init_distributed_environment,
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
                                   get_load_balance_assignment,
                                   run_dp_sharded_mrope_vision_model,
                                   run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
@@ -425,8 +428,8 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
    # Set random seed for reproducibility
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
@@ -463,3 +466,322 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,

    # Check that the outputs are close (they should be identical)
    assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize(
    "sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
    "expected_grouped_sizes_per_gpu,test_description",
    [
        # Empty input
        ([], 2, [], [0, 0], [0, 0], "empty input"),

        # Fewer samples than GPUs
        ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0
                                               ], "fewer samples than GPUs"),

        # Single GPU
        ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"),

        # Balanced assignment
        ([100, 100, 100, 100
          ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"),

        # Unbalanced sizes - this one is trickier since the algorithm is greedy
        ([1000, 100, 200, 50], 2, [0, 2, 1, 3
                                   ], [1, 3], [1000, 350], "unbalanced sizes"),
    ],
)
def test_get_load_balance_assignment_cases(sizes, num_gpus,
                                           expected_shuffle_indices,
                                           expected_gpu_sample_counts,
                                           expected_grouped_sizes_per_gpu,
                                           test_description):
    """Test get_load_balance_assignment with various input cases."""
    result = get_load_balance_assignment(sizes, num_gpus=num_gpus)
    (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result

    # Common assertions for all cases
    assert len(shuffle_indices) == len(sizes)
    assert len(gpu_sample_counts) == num_gpus
    assert len(grouped_sizes_per_gpu) == num_gpus
    assert sum(gpu_sample_counts) == len(sizes)

    assert shuffle_indices == expected_shuffle_indices

    assert gpu_sample_counts == expected_gpu_sample_counts
    assert grouped_sizes_per_gpu == expected_grouped_sizes_per_gpu

class SimpleMRopeVisionModel(torch.nn.Module):
    """A simple vision model for testing mrope functionality."""

    def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64):
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.out_hidden_size = out_hidden_size
        self.linear = torch.nn.Linear(768, out_hidden_size)

    def forward(self, pixel_values: torch.Tensor,
                grid_thw_list: list[list[int]]):
        """Simple forward pass that simulates spatial merging."""
        # Apply linear transformation
        embeddings = self.linear(pixel_values)

        # Simulate spatial merging by reducing the number of patches
        merge_factor = self.spatial_merge_size * self.spatial_merge_size

        # Group patches and merge spatially
        merged_embeddings = []
        start_idx = 0

        for grid_thw in grid_thw_list:
            num_patches = math.prod(grid_thw)
            end_idx = start_idx + num_patches

            # Get patches for this image
            image_patches = embeddings[start_idx:end_idx]

            # Simulate spatial merging by averaging groups of patches
            merged_patches = num_patches // merge_factor
            if merged_patches > 0:
                # Reshape and average to simulate merging
                reshaped = image_patches[:merged_patches * merge_factor].view(
                    merged_patches, merge_factor, -1)
                merged = reshaped.mean(dim=1)
                merged_embeddings.append(merged)

            start_idx = end_idx

        if merged_embeddings:
            return torch.cat(merged_embeddings, dim=0)
        else:
            return torch.empty((0, self.out_hidden_size),
                               device=pixel_values.device,
                               dtype=pixel_values.dtype)

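As a quick reference (not part of this diff), the spatial-merge arithmetic of the helper class above can be checked in a single process with no distributed setup; each image contributes math.prod(grid_thw) // spatial_merge_size**2 output rows:

```
import math

import torch

model = SimpleMRopeVisionModel(spatial_merge_size=2, out_hidden_size=64)
grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # 16 + 36 = 52 input patches
pixel_values = torch.randn(sum(math.prod(g) for g in grid_thw_list), 768)
output = model(pixel_values, grid_thw_list)
assert output.shape == (16 // 4 + 36 // 4, 64)  # (13, 64)
```
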
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        3,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_mrope_vision_model(batch_size: int):
    world_size = 2
    # Launch processes
    mp.spawn(
        run_dp_sharded_mrope_vision_model_vs_direct,
        args=(
            world_size,
            batch_size,
            get_open_port(),
        ),
        nprocs=world_size,
    )


def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int,
                                                world_size: int,
                                                batch_size: int,
                                                master_port: int):
    """
    Test that run_dp_sharded_mrope_vision_model produces the same results as
    calling the model directly.
    """
    # Set random seed for reproducibility
    current_platform.seed_everything(0)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create test data
    grid_thw_list = []
    pixel_values_list = []

    for i in range(batch_size):
        # Varying image sizes for better testing
        t, h, w = 1, 4 + i, 4 + i
        grid_thw_list.append([t, h, w])

        num_patches = t * h * w
        # Create random pixel values for this image
        image_pixels = torch.randn(num_patches, 768)
        pixel_values_list.append(image_pixels)

    # Concatenate all pixel values
    pixel_values = torch.cat(pixel_values_list, dim=0)

    # Create a simple mrope vision model
    vision_model = SimpleMRopeVisionModel()

    # Run the model directly on the full input (only on rank 0)
    if local_rank == 0:
        with torch.inference_mode():
            direct_output = vision_model(pixel_values, grid_thw_list)

    # Run the model through the sharded function
    with torch.inference_mode():
        sharded_output = run_dp_sharded_mrope_vision_model(
            vision_model, pixel_values, grid_thw_list)
        sharded_output = torch.cat(sharded_output, dim=0)

    # Check that the world size is setup correctly
    assert get_tensor_model_parallel_world_size() == world_size

    # Compare outputs (only on rank 0)
    if local_rank == 0:
        # Check that the outputs have the same shape
        assert direct_output.shape == sharded_output.shape
        # Check that the outputs are close (they should be identical)
        assert torch.allclose(direct_output,
                              sharded_output,
                              rtol=1e-5,
                              atol=1e-5)

@multi_gpu_test(num_gpus=2)
def test_run_dp_sharded_mrope_vision_model_empty_input():
    world_size = 2
    mp.spawn(
        run_dp_sharded_mrope_vision_model_empty_input_worker,
        args=(world_size, get_open_port()),
        nprocs=world_size,
    )


def run_dp_sharded_mrope_vision_model_empty_input_worker(
        local_rank: int, world_size: int, master_port: int):
    """Test run_dp_sharded_mrope_vision_model with empty input."""
    # Set up distributed environment
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create empty inputs
    pixel_values = torch.empty((0, 768))
    grid_thw_list: list[list[int]] = []

    vision_model = SimpleMRopeVisionModel()

    # Should handle empty input gracefully
    with torch.inference_mode():
        output = run_dp_sharded_mrope_vision_model(vision_model, pixel_values,
                                                   grid_thw_list)

    assert len(output) == 0

@multi_gpu_test(num_gpus=4)
def test_run_dp_sharded_mrope_vision_model_uneven_load():
    world_size = 4
    mp.spawn(
        run_dp_sharded_mrope_vision_model_uneven_load_worker,
        args=(world_size, get_open_port()),
        nprocs=world_size,
    )


def run_dp_sharded_mrope_vision_model_uneven_load_worker(
        local_rank: int, world_size: int, master_port: int):
    """Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
    # Set up distributed environment
    current_platform.seed_everything(123)
    device = f"{current_platform.device_name}:{local_rank}"
    current_platform.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Create images with very different sizes
    grid_thw_list = [
        [1, 2, 2],  # Small: 4 patches
        [1, 8, 8],  # Large: 64 patches
        [1, 3, 3],  # Medium: 9 patches
    ]

    pixel_values_list = []
    for grid_thw in grid_thw_list:
        num_patches = math.prod(grid_thw)
        image_pixels = torch.randn(num_patches, 768)
        pixel_values_list.append(image_pixels)

    pixel_values = torch.cat(pixel_values_list, dim=0)
    vision_model = SimpleMRopeVisionModel()

    # Should handle uneven distribution without errors
    with torch.inference_mode():
        output_tuple = run_dp_sharded_mrope_vision_model(
            vision_model, pixel_values, grid_thw_list)

    # Verify output shape is reasonable
    merge_factor = vision_model.spatial_merge_size**2
    expected_output_patches = list(
        math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list)

    for i, output in enumerate(output_tuple):
        assert output.shape[0] == expected_output_patches[i]
        assert output.shape[1] == vision_model.out_hidden_size

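For reference, the per-image output sizes the worker above checks follow directly from the merge factor (spatial_merge_size**2 == 4 for the default test model):

```
import math

grid_thw_list = [[1, 2, 2], [1, 8, 8], [1, 3, 3]]  # 4, 64 and 9 patches
assert [math.prod(g) // 4 for g in grid_thw_list] == [1, 16, 2]
```
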
@pytest.mark.parametrize("spatial_merge_size", [2, 4])
def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int):
    """Test SimpleMRopeVisionModel with different spatial merge sizes."""
    device = current_platform.device_type

    grid_thw_list = [[1, 4, 4], [1, 6, 6]]  # Two images
    pixel_values_list = []

    for grid_thw in grid_thw_list:
        num_patches = math.prod(grid_thw)
        image_pixels = torch.randn(num_patches, 768, device=device)
        pixel_values_list.append(image_pixels)

    pixel_values = torch.cat(pixel_values_list, dim=0)
    vision_model = SimpleMRopeVisionModel(
        spatial_merge_size=spatial_merge_size).to(device)

    with torch.inference_mode():
        output = vision_model(pixel_values, grid_thw_list)

    # Verify output dimensions based on spatial merging
    total_patches = sum(math.prod(grid_thw) for grid_thw in grid_thw_list)
    merge_factor = spatial_merge_size**2
    expected_output_patches = total_patches // merge_factor

    assert output.shape[0] == expected_output_patches
    assert output.shape[1] == vision_model.out_hidden_size
@@ -437,7 +437,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]

        param[shard_offset:shard_offset + shard_size] = loaded_weight
        param.data[shard_offset:shard_offset + shard_size] = loaded_weight


@CustomOp.register("column_parallel_linear")
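The one-line change above routes the shard copy through param.data. A minimal standalone illustration (plain PyTorch, not vLLM loader code) of one reason such writes typically go through .data: slice-assigning into a leaf Parameter that requires grad is rejected by autograd, while the detached write is always permitted.

```
import torch

param = torch.nn.Parameter(torch.zeros(4, 8))
shard = torch.ones(2, 8)

try:
    param[0:2] = shard  # in-place write on a leaf tensor that requires grad
except RuntimeError as err:
    print(f"direct slice assignment rejected: {err}")

param.data[0:2] = shard  # detached write, bypasses autograd tracking
assert torch.equal(param.data[0:2], shard)
```
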
@@ -45,10 +45,14 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               MergedReplicatedLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
@@ -57,6 +61,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -170,19 +175,25 @@ class Qwen2_5_VisionMLP(nn.Module):
                 bias: bool = False,
                 act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
                 prefix: str = "",
                 use_data_parallel: bool = False):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
                            MergedColumnParallelLinear)
        self.gate_up_proj = cls_gate_up_proj(
            input_size=in_features,
            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(hidden_features,
                                           in_features,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")

        cls_down_proj = (ReplicatedLinear
                         if use_data_parallel else RowParallelLinear)
        self.down_proj = cls_down_proj(hidden_features,
                                       in_features,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj")
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor):
@@ -220,28 +231,42 @@ class Qwen2_5_VisionAttention(nn.Module):
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")
        if use_data_parallel:
            self.qkv = ReplicatedLinear(embed_dim,
                                        self.hidden_size_per_attention_head *
                                        3 * num_heads,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.qkv")

        else:
            self.qkv = QKVParallelLinear(
                hidden_size=embed_dim,
                head_size=self.hidden_size_per_attention_head,
                total_num_heads=num_heads,
                total_num_kv_heads=num_heads,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv")

        cls_proj = (ReplicatedLinear
                    if use_data_parallel else RowParallelLinear)
        self.proj = cls_proj(input_size=projection_size,
                             output_size=embed_dim,
                             quant_config=quant_config,
                             prefix=f"{prefix}.proj")

        # Detect attention implementation.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
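A small arithmetic check for the branch above (the 1280-dim, 16-head figures are illustrative values typical of a Qwen2.5-VL vision tower, not taken from this diff): with data parallelism the effective TP size is 1, so no heads are sharded and the replicated QKV projection carries all of them.

```
embed_dim, num_heads = 1280, 16              # assumed example config
head_dim = embed_dim // num_heads            # dist_utils.divide(projection_size, num_heads)
tp_size = 1                                  # forced when use_data_parallel=True
heads_per_partition = num_heads // tp_size   # every rank keeps all 16 heads
qkv_out_features = head_dim * 3 * num_heads  # width of the replicated QKV projection
assert (head_dim, heads_per_partition, qkv_out_features) == (80, 16, 3840)
```
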
@@ -302,8 +327,6 @@ class Qwen2_5_VisionAttention(nn.Module):
        k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self.is_flash_attn_backend:
            # from vllm_flash_attn.flash_attn_interface import (
            #   flash_attn_varlen_func)
            if self.attn_backend == _Backend.ROCM_AITER_FA:
                from aiter import flash_attn_varlen_func
            else:
@@ -370,23 +393,27 @@ class Qwen2_5_VisionBlock(nn.Module):
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
                                            num_heads=num_heads,
                                            projection_size=dim,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attn")
        self.attn = Qwen2_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel)
        self.mlp = Qwen2_5_VisionMLP(dim,
                                     mlp_hidden_dim,
                                     act_fn=act_fn,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.mlp")
                                     prefix=f"{prefix}.mlp",
                                     use_data_parallel=use_data_parallel)

    def forward(
        self,
@@ -445,24 +472,30 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)

        cls_fc1 = (ReplicatedLinear
                   if use_data_parallel else ColumnParallelLinear)
        cls_fc2 = (ReplicatedLinear
                   if use_data_parallel else RowParallelLinear)
        self.mlp = nn.ModuleList([
            ColumnParallelLinear(self.hidden_size,
                                 self.hidden_size,
                                 bias=True,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp.0"),
            cls_fc1(self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0"),
            nn.GELU(),
            RowParallelLinear(self.hidden_size,
                              d_model,
                              bias=True,
                              quant_config=quant_config,
                              prefix=f"{prefix}.mlp.2"),
            cls_fc2(self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2"),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -514,6 +547,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

@@ -523,6 +557,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.use_data_parallel = use_data_parallel
        self.out_hidden_size = vision_config.out_hidden_size

        # args for get_window_index_thw
        self.window_size = vision_config.window_size
@@ -550,7 +586,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
                                    vision_config.hidden_act),
                                norm_layer=norm_layer,
                                quant_config=quant_config,
                                prefix=f"{prefix}.blocks.{layer_idx}")
                                prefix=f"{prefix}.blocks.{layer_idx}",
                                use_data_parallel=use_data_parallel)
            for layer_idx in range(depth)
        ])
        self.merger = Qwen2_5_VisionPatchMerger(
@@ -560,6 +597,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)

@@ -767,7 +805,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@@ -840,6 +877,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = (vllm_config.parallel_config.
                                  enable_multimodal_encoder_data_parallel)
        self.config = config
        self.multimodal_config = multimodal_config

@@ -851,6 +890,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                quant_config=self._maybe_ignore_quant_config(
                    self.quant_config),
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )
        else:
            self.visual = None
@@ -973,7 +1013,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"]
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values, grid_thw_list)
            else:
                image_embeds = self.visual(pixel_values,
                                           grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -995,8 +1041,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            video_embeds = self.visual(pixel_values_videos,
                                       grid_thw=grid_thw_list)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.visual, pixel_values_videos, grid_thw_list)
            else:
                video_embeds = self.visual(pixel_values_videos,
                                           grid_thw=grid_thw_list)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
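A hedged sketch (simplified, hypothetical helper name, not the literal model code) of the dispatch the two hunks above introduce: in DP mode the sharded runner already returns one embedding tensor per item in the original order, while the direct path still returns a single concatenated tensor that is split afterwards using sizes computed from grid_thw_list.

```
import math

def encode_images(visual, pixel_values, grid_thw_list, use_data_parallel):
    if use_data_parallel:
        # Tuple with one tensor per image, already in input order.
        return run_dp_sharded_mrope_vision_model(visual, pixel_values,
                                                 grid_thw_list)
    merged = visual(pixel_values, grid_thw=grid_thw_list)
    merge_size = visual.spatial_merge_size
    # Plain Python ints from grid_thw_list, so no CUDA sync is needed.
    sizes = [math.prod(thw) // (merge_size * merge_size)
             for thw in grid_thw_list]
    return merged.split(sizes)
```
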
@@ -329,8 +329,6 @@ class Qwen2VisionAttention(nn.Module):
        k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self.is_flash_attn_backend:
            # from vllm_flash_attn.flash_attn_interface import (
            #   flash_attn_varlen_func)
            if self.attn_backend == _Backend.ROCM_AITER_FA:
                from aiter import flash_attn_varlen_func
            else:
@@ -3,6 +3,8 @@

import asyncio
import atexit
import itertools
import math
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
@@ -465,6 +467,219 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
    return vision_embeddings


def get_load_balance_assignment(
    sizes: list[int],
    num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
    """
    Generate load balancing assignment and metadata
    for distributing data across GPUs.
    The load is determined by the total image sizes,
    not the number of images.

    Args:
        sizes: The size of each image
        num_gpus: Number of GPUs to balance across

    Returns:
        shuffle_indices:
            Indices to reorder data for balanced loading
        gpu_sample_counts:
            Number of samples assigned to each GPU
        grouped_sizes_per_gpu:
            Total size assigned to each GPU

    Example:
        ```
        sizes = [1000, 100, 200, 50]
        num_gpus=2
        ```

    """

    n_samples = len(sizes)

    # Handle edge cases
    if n_samples == 0:
        return [], [0] * num_gpus, [0] * num_gpus

    # Use greedy algorithm - balance by total size, not sample count
    gpu_assignments = [list[int]() for _ in range(num_gpus)]
    gpu_loads = [0] * num_gpus  # This tracks total SIZE, not sample count

    # Sort indices by size (largest first for better load balancing)
    # sizes = [1000, 100, 200, 50]
    # large_to_small_indices = [0, 2, 1, 3]
    large_to_small_indices = sorted(range(n_samples),
                                    key=lambda i: sizes[i],
                                    reverse=True)

    for idx in large_to_small_indices:
        # Find GPU with minimum current load (by total size)
        min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
        gpu_assignments[min_gpu].append(idx)
        gpu_loads[min_gpu] += sizes[idx]

    # Create shuffle indices and counts
    shuffle_indices = list[int]()
    gpu_sample_counts = list[int]()
    for gpu_id in range(num_gpus):
        # GPU_0 = [1000] = [0]
        # GPU_1 = [200, 100, 50] = [2, 1, 3]
        # shuffle_indices = [0, 2, 1, 3]
        shuffle_indices.extend(gpu_assignments[gpu_id])
        # GPU_0 = [1]
        # GPU_1 = [3]
        # gpu_sample_counts = [1, 3]
        gpu_sample_counts.append(len(gpu_assignments[gpu_id]))

    return (shuffle_indices, gpu_sample_counts, gpu_loads)

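A worked example for the helper above, using the sizes from its docstring; the expected values match the "unbalanced sizes" case in the new tests:

```
shuffle, counts, loads = get_load_balance_assignment([1000, 100, 200, 50],
                                                     num_gpus=2)
# Greedy by size: rank 0 takes image 0 (1000), rank 1 takes images 2, 1, 3.
assert shuffle == [0, 2, 1, 3]
assert counts == [1, 3]
assert loads == [1000, 350]
```
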
def run_dp_sharded_mrope_vision_model(
    vision_model: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw_list: list[list[int]],
) -> tuple[torch.Tensor, ...]:
    """Run a vision model with data parallelism (DP) sharding.
    The function will shard the input image tensor on the
    first dimension and run the vision model.
    This function is used to run the vision model with mrope.

    Args:
        vision_model (torch.nn.Module): Vision model.
        pixel_values (torch.Tensor): Image/Video input tensor.
        grid_thw_list: List of grid dimensions for each image
    Returns:
        torch.Tensor: Output image embeddings

    Example:
        ```
        vision_model.out_hidden_size = 64
        vision_model.spatial_merge_size = 2
        pixel_values.shape = (1350, channel)
        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
        tp_size=2
        ```

    """
    tp_size = get_tensor_model_parallel_world_size()

    # GPU_0 tp_rank_local = 0
    # GPU_1 tp_rank_local = 1
    tp_rank_local = get_tensor_model_parallel_rank()

    # patches_per_image = [1000, 100, 200, 50]
    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    # cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]

    # Get load balancing assignment with all metadata
    # image_to_tp_rank = [0, 2, 1, 3]
    # gpu_sample_counts = [1, 3]
    # grouped_pixel_values_len = [1000, 350]
    (image_to_tp_rank, gpu_sample_counts,
     grouped_pixel_values_len) = get_load_balance_assignment(
         patches_per_image, tp_size)

    # cu_gpu_sample_counts = [0, 1, 4]
    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]

    # GPU_0 image_idxs_local = [0]
    # GPU_1 image_idxs_local = [2, 1, 3]
    image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]:
                                        cum_gpu_sample_counts[tp_rank_local +
                                                              1]]

    # Get the pixel values for the local images based on the image_idxs_local
    if len(image_idxs_local) > 0:
        pixel_values_local = torch.cat([
            pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]]
            for i in image_idxs_local
        ])
    else:
        # Handle case where this rank has no images
        pixel_values_local = torch.empty((0, pixel_values.shape[1]),
                                         device=pixel_values.device,
                                         dtype=pixel_values.dtype)
    # embed_dim_reduction_factor = 2 * 2
    embed_dim_reduction_factor = (vision_model.spatial_merge_size *
                                  vision_model.spatial_merge_size)

    # Find the max length across all ranks
    # The output embedding of every DP rank has to be
    # padded to this length for tensor_model_parallel_all_gather
    # to work
    max_len_per_rank = max(
        grouped_pixel_values_len) // embed_dim_reduction_factor
    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

    # Run the vision model on the local pixel_values_local
    if pixel_values_local.shape[0] > 0:
        image_embeds_local = vision_model(pixel_values_local,
                                          local_grid_thw_list)
    else:
        # Handle empty case
        image_embeds_local = torch.empty((0, vision_model.out_hidden_size),
                                         device=pixel_values.device,
                                         dtype=pixel_values.dtype)

    # Pad the output based on max_len_per_rank
    # for tensor_model_parallel_all_gather to work
    current_len = image_embeds_local.shape[0]
    if current_len < max_len_per_rank:
        padding_size = max_len_per_rank - current_len
        padding = torch.empty((padding_size, image_embeds_local.shape[1]),
                              dtype=image_embeds_local.dtype,
                              device=image_embeds_local.device)
        image_embeds_local_padded = torch.cat([image_embeds_local, padding],
                                              dim=0)
    else:
        image_embeds_local_padded = image_embeds_local

    # Do all_gather to collect embeddings from all ranks
    gathered_embeds = tensor_model_parallel_all_gather(
        image_embeds_local_padded, dim=0)

    # Remove padding and reconstruct per-rank embeddings
    rank_embeddings = list[torch.Tensor]()
    for rank in range(tp_size):
        start_idx = rank * max_len_per_rank
        end_idx = start_idx + (grouped_pixel_values_len[rank] //
                               embed_dim_reduction_factor)
        rank_embeddings.append(gathered_embeds[start_idx:end_idx])

    patches_per_output_image = [(patch_size // embed_dim_reduction_factor)
                                for patch_size in patches_per_image]

    # Reconstruct embeddings in the original order
    original_order_embeddings = [None] * len(grid_thw_list)
    current_idx = 0
    for rank in range(tp_size):
        count = gpu_sample_counts[rank]
        if count > 0:
            # Get images assigned to this rank in shuffled order
            # GPU_0 = image_idxs_local [0]
            # GPU_1 = image_idxs_local [2, 1, 3]
            rank_images = image_to_tp_rank[current_idx:current_idx + count]

            rank_embed = rank_embeddings[rank]
            # Split rank embeddings back to individual images
            embed_start = 0
            for img_idx in rank_images:
                img_patches = patches_per_output_image[img_idx]
                original_order_embeddings[img_idx] = rank_embed[
                    embed_start:embed_start + img_patches]
                embed_start += img_patches
        current_idx += count

    out_embeddings = tuple(embed for embed in original_order_embeddings
                           if embed is not None)
    assert len(out_embeddings) == len(
        original_order_embeddings), "Found unassigned embeddings"
    return out_embeddings

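For orientation, a small sketch (reusing the docstring's example sizes) of the padding arithmetic above: every rank pads its output to the length dictated by the most loaded rank before the all-gather.

```
patches_per_image = [1000, 100, 200, 50]
_shuffle, _counts, grouped_len = get_load_balance_assignment(
    patches_per_image, num_gpus=2)
merge_factor = 2 * 2                                 # spatial_merge_size ** 2
max_len_per_rank = max(grouped_len) // merge_factor  # 1000 // 4 = 250
assert max_len_per_rank == 250
# Rank 1 only produces 350 // 4 = 87 embedding rows, so it pads 163 rows
# before the all-gather; the padding is sliced away again when the
# per-image embeddings are reassembled.
assert grouped_len[1] // merge_factor == 87
```
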
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,