mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Unit Test for run_dp_sharded_vision_model (#19103)
Signed-off-by: Siqi Yan <siqi@meta.com> Co-authored-by: Siqi Yan <siqi@meta.com>
This commit is contained in:
@ -9,12 +9,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.utils import (MediaConnector,
|
||||
merge_and_sort_multimodal_metadata)
|
||||
merge_and_sort_multimodal_metadata,
|
||||
run_dp_sharded_vision_model)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_open_port, update_environment_variables
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.hasher import MultiModalHashDict
|
||||
@ -413,3 +422,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
|
||||
assert modalities == expected_modalities
|
||||
assert ranges == expected_ranges
|
||||
assert hashes == expected_hashes
|
||||
|
||||
|
||||
class SimpleLinearModel(torch.nn.Module):
    """A simple linear vision model for testing.

    The model flattens each sample of the input (``torch.nn.Flatten`` with
    its default arguments) and applies a single ``torch.nn.Linear`` layer.

    Args:
        input_dim: Number of features per sample after flattening. The
            default corresponds to a ``3 x 224 x 224`` input.
        output_dim: Number of output features per sample.
    """

    def __init__(self,
                 input_dim: int = 3 * 224 * 224,
                 output_dim: int = 32) -> None:
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the linear projection of the flattened input ``x``."""
        # Flatten the input and apply linear transformation
        x = self.flatten(x)
        return self.linear(x)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        4,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_vision_model(batch_size: int):
    """Run the sharded-vs-direct comparison in one process per rank.

    ``mp.spawn`` passes the process index as the first positional argument
    of the target, so ``args`` supplies only the remaining parameters of
    ``run_dp_sharded_vision_model_vs_direct``: the world size, the batch
    size, and a port (picked once here via ``get_open_port()``) that is
    handed to every worker.
    """
    # Kept in sync with ``num_gpus`` in the decorator above.
    world_size = 2
    # Launch processes
    mp.spawn(
        run_dp_sharded_vision_model_vs_direct,
        args=(
            world_size,
            batch_size,
            get_open_port(),
        ),
        nprocs=world_size,
    )
|
||||
|
||||
|
||||
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """
    Test that run_dp_sharded_vision_model produces the same results as
    calling the model directly.

    Intended to be the target of ``mp.spawn``; each spawned process calls it
    with its own ``local_rank``.

    Args:
        local_rank: Index of this process; also used as the CUDA device
            index and exported as ``RANK`` / ``LOCAL_RANK``.
        world_size: Total number of processes; used as the tensor model
            parallel size.
        batch_size: Number of samples in the random test input.
        master_port: Port exported as ``MASTER_PORT`` for the distributed
            setup; must be the same value in every process.
    """

    # Set random seed for reproducibility
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Check that the world size is setup correctly. This is done before the
    # sharded call so that a misconfigured parallel group is reported
    # directly instead of surfacing as an output mismatch later on.
    tp_world_size = get_tensor_model_parallel_world_size()
    assert tp_world_size == world_size, (
        f"expected tensor parallel world size {world_size}, "
        f"got {tp_world_size}")

    # Create a test input tensor
    image_input = torch.randn(batch_size, 3, 224, 224)

    # Create a simple linear model
    vision_model = SimpleLinearModel()

    with torch.inference_mode():
        # Run the model directly on the full input
        direct_output = vision_model(image_input)

        # Run the model through the sharded function
        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)

    # Check that the outputs have the same shape
    assert direct_output.shape == sharded_output.shape, (
        f"shape mismatch: direct {tuple(direct_output.shape)} vs "
        f"sharded {tuple(sharded_output.shape)}")

    # Check that the outputs are close (they should be identical)
    assert torch.allclose(direct_output, sharded_output, rtol=1e-5,
                          atol=1e-5), (
        "direct and sharded outputs differ; max abs diff = "
        f"{(direct_output - sharded_output).abs().max().item()}")
|
||||
|
Reference in New Issue
Block a user