Unit Test for run_dp_sharded_vision_model (#19103)

Signed-off-by: Siqi Yan <siqi@meta.com>
Co-authored-by: Siqi Yan <siqi@meta.com>
This commit is contained in:
Siqi Yan
2025-06-06 01:24:02 -07:00
committed by GitHub
parent da511d54d8
commit f168b85725

View File

@ -9,12 +9,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from PIL import Image, ImageChops
from tests.utils import multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata)
merge_and_sort_multimodal_metadata,
run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables
if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
@ -413,3 +422,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
assert modalities == expected_modalities
assert ranges == expected_ranges
assert hashes == expected_hashes
class SimpleLinearModel(torch.nn.Module):
"""A simple linear vision model for testing."""
def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
super().__init__()
self.flatten = torch.nn.Flatten()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x: torch.Tensor):
# Flatten the input and apply linear transformation
x = self.flatten(x)
return self.linear(x)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"batch_size",
[
1, # Single image
4, # Small batch
5, # Odd batch size (for testing padding)
],
)
def test_run_dp_sharded_vision_model(batch_size: int):
world_size = 2
# Launch processes
mp.spawn(
run_dp_sharded_vision_model_vs_direct,
args=(
world_size,
batch_size,
get_open_port(),
),
nprocs=world_size,
)
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
batch_size: int, master_port: int):
"""
Test that run_dp_sharded_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': str(master_port),
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create a test input tensor
image_input = torch.randn(batch_size, 3, 224, 224)
# Create a simple linear model
vision_model = SimpleLinearModel()
# Run the model directly on the full input
with torch.inference_mode():
direct_output = vision_model(image_input)
# Run the model through the sharded function
with torch.inference_mode():
sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
# Check that the world size is setup correctly
assert get_tensor_model_parallel_world_size() == world_size
# Check that the outputs have the same shape
assert direct_output.shape == sharded_output.shape
# Check that the outputs are close (they should be identical)
assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)