mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc] Support routing logic simulation (#21990)
Signed-off-by: Ming Yang <minos.future@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
171
tests/test_routing_simulator.py
Normal file
@ -0,0 +1,171 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test script for the token-to-expert routing simulator.

This script demonstrates how to use the routing simulator to test
different routing strategies and analyze their performance, including
integration tests with the FusedMoE layer.
"""

import pytest
import torch

from vllm.model_executor.layers.fused_moe.routing_simulator import (
    DistributionBasedRouting, RoutingSimulator)


@pytest.fixture
def device():
    """Fixture to provide the appropriate device for testing."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.mark.parametrize("num_tokens", [1, 16, 256])
@pytest.mark.parametrize("hidden_size", [64, 1024])
@pytest.mark.parametrize("num_experts", [16, 128])
@pytest.mark.parametrize("top_k", [1, 4])
def test_basic_functionality(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    top_k: int,
    device,
):
    """Test basic functionality of the routing simulator."""
    # Test each routing strategy
    strategies = RoutingSimulator.get_available_strategies()

    hidden_states = torch.randn(num_tokens, hidden_size, device=device)
    router_logits = torch.randn(num_tokens, num_experts, device=device)

    for strategy in strategies:
        # Simulate routing
        topk_weights, topk_ids = RoutingSimulator.simulate_routing(
            hidden_states=hidden_states,
            router_logits=router_logits,
            strategy_name=strategy,
            top_k=top_k,
        )

        # Check output shapes
        assert topk_weights.shape == (num_tokens, top_k), \
            f"Wrong weights shape for {strategy}"
        assert topk_ids.shape == (num_tokens, top_k), \
            f"Wrong ids shape for {strategy}"

        # Check that expert IDs are valid
        assert topk_ids.min() >= 0, \
            f"Invalid expert ID (negative) for {strategy}"
        assert topk_ids.max() < num_experts, \
            f"Invalid expert ID (too large) for {strategy}"


def test_routing_strategy_integration(monkeypatch, device):
    """Test that the routing strategy environment variable works with
    FusedMoE."""
    pytest.importorskip("vllm.model_executor.layers.fused_moe.layer")

    import vllm.envs as envs
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    # Test parameters
    num_tokens = 32
    hidden_size = 16
    num_experts = 4
    top_k = 2

    # Create test data
    hidden_states = torch.randn(num_tokens, hidden_size, device=device)
    router_logits = torch.randn(num_tokens, num_experts, device=device)

    # Test different routing strategies
    strategies = RoutingSimulator.get_available_strategies()

    for strategy in strategies:
        # Set environment variable
        env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
        monkeypatch.setenv(env_name, strategy)

        # Force reload of environment variable
        envs.environment_variables[env_name] = lambda s=strategy: s

        # Test the select_experts method
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=False,
            renormalize=True,
            indices_type=torch.long)

        # Verify output shapes
        assert topk_weights.shape == (num_tokens, top_k), \
            f"Wrong weights shape for {strategy}"
        assert topk_ids.shape == (num_tokens, top_k), \
            f"Wrong ids shape for {strategy}"

        # Verify expert IDs are valid
        assert topk_ids.min() >= 0, \
            f"Invalid expert ID (negative) for {strategy}"
        assert topk_ids.max() < num_experts, \
            f"Invalid expert ID (too large) for {strategy}"


def test_distribution_based_routing_with_custom_strategy():
    """Test registering and using DistributionBasedRouting with custom
    parameters."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Register custom distribution-based strategy
    custom_strategy = DistributionBasedRouting(distribution="normal",
                                               mean=2.0,
                                               std=0.5)
    RoutingSimulator.register_strategy("custom_normal", custom_strategy)

    # Test data
    num_tokens = 60
    hidden_size = 48
    num_experts = 6
    top_k = 3

    hidden_states = torch.randn(num_tokens, hidden_size, device=device)
    router_logits = torch.randn(num_tokens, num_experts, device=device)

    # Use the custom strategy
    topk_weights, topk_ids = RoutingSimulator.simulate_routing(
        hidden_states=hidden_states,
        router_logits=router_logits,
        strategy_name="custom_normal",
        top_k=top_k)

    # Check output shapes
    assert topk_weights.shape == (num_tokens, top_k)
    assert topk_ids.shape == (num_tokens, top_k)

    # Check that expert IDs are valid
    assert topk_ids.min() >= 0
    assert topk_ids.max() < num_experts


def test_instance_compatibility():
    """Test that static methods work correctly."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Test static method directly
    hidden_states = torch.randn(10, 8, device=device)
    router_logits = torch.randn(10, 4, device=device)

    topk_weights, topk_ids = RoutingSimulator.simulate_routing(
        hidden_states=hidden_states,
        router_logits=router_logits,
        strategy_name="uniform_random",
        top_k=2)

    assert topk_weights.shape == (10, 2)
    assert topk_ids.shape == (10, 2)
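The suite falls back to CPU when CUDA is unavailable (see the `device` fixture above), so it can be run anywhere with `pytest tests/test_routing_simulator.py`.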
@ -989,6 +989,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
    lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),

    # MoE routing strategy selector.
    # See `RoutingSimulator.get_available_strategies()` for available
    # strategies.
    # Custom routing strategies can be registered with
    # `RoutingSimulator.register_strategy()`.
    # Note: custom strategies may not produce correct model outputs.
    "VLLM_MOE_ROUTING_SIMULATION_STRATEGY":
    lambda: os.environ.get("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "").lower(),

    # Regex timeout for use by the vLLM tool parsing plugins.
    "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
    lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),
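A minimal usage sketch (not part of this diff; "uniform_random" is one of the built-in strategies registered in the simulator module below): the variable only needs to be set before vLLM reads it, and the empty-string default keeps simulation off.

import os

# Replace real MoE routing with the uniform simulator. The getter above
# lowercases the value, so case does not matter.
os.environ["VLLM_MOE_ROUTING_SIMULATION_STRATEGY"] = "uniform_random"

# As the comment in the diff notes, simulated routing can change model
# outputs; it is meant for routing and performance analysis, not serving.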
@ -28,6 +28,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
    RoutingSimulator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@ -1362,6 +1364,16 @@ class FusedMoE(torch.nn.Module):
        """
        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

        # Check if we should use a routing simulation strategy
        routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
        if routing_strategy != "":
            return RoutingSimulator.simulate_routing(
                hidden_states=hidden_states,
                router_logits=router_logits,
                strategy_name=routing_strategy,
                top_k=top_k,
                indices_type=indices_type)

        # DeepSeekv2 uses grouped_top_k
        if use_grouped_topk:
            assert topk_group is not None
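To make the short-circuit concrete, a small sketch under stated assumptions (CPU tensors, the built-in "uniform_random" strategy): when the environment variable is set, select_experts returns the simulator's output instead of calling fused_topk, so the router logits influence nothing but the expert count.

import torch

from vllm.model_executor.layers.fused_moe.routing_simulator import (
    RoutingSimulator)

hidden_states = torch.randn(16, 32)
router_logits = torch.randn(16, 4)  # contents ignored by the simulator

weights, ids = RoutingSimulator.simulate_routing(
    hidden_states=hidden_states,
    router_logits=router_logits,
    strategy_name="uniform_random",
    top_k=2)

# The uniform strategy returns all-ones weights and uniformly random ids.
assert torch.all(weights == 1.0)
assert ids.shape == (16, 2)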
289
vllm/model_executor/layers/fused_moe/routing_simulator.py
Normal file
@ -0,0 +1,289 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Token-to-Expert Routing Simulator

This module provides a framework for simulating and testing different
token-to-expert routing strategies for Mixture of Experts (MoE) models.
It supports routing logic customization and includes example implementations
like uniform random routing.
"""

from abc import ABC, abstractmethod
from typing import Optional

import torch


class RoutingStrategy(ABC):
    """Base class for token-to-expert routing strategies."""

    @abstractmethod
    def route_tokens(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        indices_type: Optional[torch.dtype] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route tokens to experts.

        Args:
            hidden_states: Input hidden states [num_tokens, hidden_size]
            router_logits: Router logits [num_tokens, num_experts]
            top_k: Number of experts to select per token
            indices_type: Data type for expert indices

        Returns:
            tuple of (topk_weights, topk_ids)
        """
        pass


class DistributionBasedRouting(RoutingStrategy):
    """
    Distribution-based random routing strategy with configurable
    distributions.

    This routing strategy randomly selects experts for each token based on
    different probability distributions. Currently supports uniform and
    normal distributions for testing different routing patterns.
    """

    def __init__(self, distribution: str = "uniform", **distribution_params):
        """
        Initialize distribution-based routing.

        Args:
            distribution: Type of distribution to use for sampling
                - "uniform": Uniform distribution (default)
                - "normal": Normal/Gaussian distribution
            **distribution_params: Parameters specific to the chosen
                distribution
                For "uniform": no additional parameters needed
                For "normal": mean (default: 0.0), std (default: 1.0)
        """
        self.distribution = distribution.lower()
        self.distribution_params = distribution_params

        # Validate distribution and parameters
        self._validate_distribution_params()

    def _validate_distribution_params(self):
        """Validate distribution type and parameters."""
        valid_distributions = ["uniform", "normal"]

        if self.distribution not in valid_distributions:
            raise ValueError(f"Unsupported distribution: {self.distribution}. "
                             f"Supported distributions: {valid_distributions}")

        # Set default parameters if not provided
        if self.distribution == "normal":
            self.distribution_params.setdefault("mean", 0.0)
            self.distribution_params.setdefault("std", 1.0)

    def route_tokens(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        indices_type: Optional[torch.dtype] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Randomly select experts for each token using the specified
        distribution.

        Args:
            hidden_states: Input hidden states [num_tokens, hidden_size]
            router_logits: Router logits [num_tokens, num_experts]
            top_k: Number of experts to select per token
            indices_type: Data type for expert indices

        Returns:
            tuple of (topk_weights, topk_ids) where:
            - topk_weights: Weights based on distribution sampling
            - topk_ids: Expert indices sampled from the distribution
        """
        num_tokens = hidden_states.shape[0]
        num_experts = router_logits.shape[-1]

        if indices_type is None:
            indices_type = torch.long

        # Generate expert IDs based on the specified distribution
        topk_ids = self._sample_expert_ids(num_tokens, num_experts, top_k,
                                           hidden_states.device, indices_type)

        # Generate weights based on the distribution
        topk_weights = self._generate_weights(num_tokens, top_k,
                                              hidden_states.device)

        return topk_weights, topk_ids

    def _sample_expert_ids(
        self,
        num_tokens: int,
        num_experts: int,
        top_k: int,
        device: torch.device,
        indices_type: torch.dtype,
    ) -> torch.Tensor:
        """Sample expert IDs based on the specified distribution."""
        if self.distribution == "uniform":
            # Uniform random sampling
            return torch.randint(
                low=0,
                high=num_experts,
                size=(num_tokens, top_k),
                dtype=indices_type,
                device=device,
            )
        elif self.distribution == "normal":
            # For normal distribution, sample continuous values and map to
            # expert IDs
            continuous_samples = self._sample_continuous_distribution(
                num_tokens, top_k, device)

            # Map continuous samples to expert indices:
            # normalize to [0, 1] range and scale to [0, num_experts)
            normalized_samples = self._normalize_samples(continuous_samples)
            expert_ids = (normalized_samples * num_experts).long()
            expert_ids = torch.clamp(expert_ids, 0, num_experts - 1)

            return expert_ids.to(dtype=indices_type)
        else:
            raise ValueError(f"Unsupported distribution: {self.distribution}")

    def _sample_continuous_distribution(self, num_tokens: int, top_k: int,
                                        device: torch.device) -> torch.Tensor:
        """Sample from continuous distributions."""
        shape = (num_tokens, top_k)

        if self.distribution == "normal":
            mean = self.distribution_params["mean"]
            std = self.distribution_params["std"]
            return torch.normal(mean, std, size=shape, device=device)
        else:
            raise ValueError(
                f"Unsupported continuous distribution: {self.distribution}")

    def _normalize_samples(self, samples: torch.Tensor) -> torch.Tensor:
        """Normalize samples to [0, 1] range."""
        if self.distribution == "normal":
            # Use sigmoid to map normal distribution to [0, 1]
            return torch.sigmoid(samples)
        else:
            raise ValueError(f"Unsupported distribution for normalization: "
                             f"{self.distribution}")

    def _generate_weights(self, num_tokens: int, top_k: int,
                          device: torch.device) -> torch.Tensor:
        """Generate weights based on the distribution."""
        if self.distribution == "uniform":
            # All-ones weights for uniform distribution
            return torch.ones(
                (num_tokens, top_k),
                dtype=torch.float32,
                device=device,
            )
        elif self.distribution == "normal":
            # For normal distribution, generate weights from the same
            # distribution
            continuous_weights = self._sample_continuous_distribution(
                num_tokens, top_k, device)
            # Normalize to positive values that sum to 1
            weights = torch.abs(continuous_weights)
            weights = weights / weights.sum(dim=-1, keepdim=True)
            return weights
        else:
            raise ValueError(
                f"Unsupported distribution for weight generation: "
                f"{self.distribution}")

    def get_distribution_info(self) -> dict:
        """Get information about the current distribution configuration."""
        return {
            "distribution": self.distribution,
            "parameters": self.distribution_params.copy()
        }


class RoutingSimulator:
    """
    Token-to-Expert Routing Simulator.

    This class provides a framework for testing and comparing different
    routing strategies for MoE models. It can simulate routing behavior
    and collect statistics for analysis.
    """

    # Class-level registry of routing strategies
    _routing_strategies: dict[str, RoutingStrategy] = {
        # Basic routing strategies
        "uniform_random":
        DistributionBasedRouting(distribution="uniform", mean=0.0, std=1.0),
        "normal_routing":
        DistributionBasedRouting(distribution="normal", mean=0.0, std=1.0),
    }

    @classmethod
    def register_strategy(cls, name: str, strategy: RoutingStrategy):
        """
        Register a custom routing strategy.

        Args:
            name: Name of the strategy
            strategy: RoutingStrategy instance
        """
        cls._routing_strategies[name] = strategy

    @classmethod
    def get_available_strategies(cls):
        """
        Get list of available routing strategy names.

        Returns:
            List of available strategy names
        """
        return list(cls._routing_strategies.keys())

    @staticmethod
    def simulate_routing(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        strategy_name: str,
        top_k: int,
        indices_type: Optional[torch.dtype] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Simulate token-to-expert routing using the specified strategy.

        Args:
            hidden_states: Input hidden states [num_tokens, hidden_size]
            router_logits: Router logits [num_tokens, num_experts]
            strategy_name: Name of the routing strategy to use
            top_k: Number of experts to select per token
            indices_type: Data type for expert indices

        Returns:
            tuple of (topk_weights, topk_ids)
        """
        if strategy_name not in RoutingSimulator._routing_strategies:
            raise ValueError(
                f"Unknown routing strategy: {strategy_name}. "
                f"Available strategies: "
                f"{list(RoutingSimulator._routing_strategies.keys())}")

        strategy = RoutingSimulator._routing_strategies[strategy_name]
        return strategy.route_tokens(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            indices_type=indices_type,
        )
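Beyond shape checks, the simulator exists for load analysis; here is a short sketch (illustrative sizes, both built-in strategies from the registry above) of counting tokens per expert from the simulated assignments.

import torch

from vllm.model_executor.layers.fused_moe.routing_simulator import (
    RoutingSimulator)

num_experts = 8
hidden_states = torch.randn(1024, 64)
router_logits = torch.randn(1024, num_experts)

for name in RoutingSimulator.get_available_strategies():
    _, topk_ids = RoutingSimulator.simulate_routing(
        hidden_states=hidden_states,
        router_logits=router_logits,
        strategy_name=name,
        top_k=2)
    # Tokens per expert: "uniform_random" should be roughly flat, while
    # "normal_routing" concentrates on middle experts because sigmoid of
    # N(0, 1) samples clusters around 0.5.
    load = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    print(name, load.tolist())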