Compare commits

...

3 Commits

Author SHA1 Message Date
8209f9057d i honestly can't believe i spelled it that way
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-07-04 15:14:03 -04:00
19c51c3439 merge main, add environment variable, factor into function
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-07-04 15:11:40 -04:00
14a6efb83e hack for topk ids
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-06-25 16:46:41 -04:00
2 changed files with 52 additions and 0 deletions

View File

@ -128,6 +128,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_UNIFORM_RANDOM_TOPK_IDS: bool = False
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
@ -913,6 +914,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
# Use uniform random topk ids for perfect load balancing in expectation.
# Use it for analyzing performance when using --load-format=dummy.
# MoE layers will not produce the correct answer when it is set.
"VLLM_UNIFORM_RANDOM_TOPK_IDS":
lambda: os.environ.get("VLLM_UNIFORM_RANDOM_TOPK_IDS", "false").lower() in
("1", "true"),
# Regex timeout for use by the vLLM tool parsing plugins.
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),

View File

@ -1154,6 +1154,41 @@ class FusedMoE(torch.nn.Module):
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]
@staticmethod
def uniform_random_select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick experts uniformly at random instead of routing by logits.

    The values of ``router_logits`` are ignored; only its last dimension
    is read, to learn how many experts there are. Likewise only the
    leading dimension and the device of ``hidden_states`` are used.

    Args:
        hidden_states: Tensor whose first dimension is the token count.
        router_logits: Tensor whose last dimension is the expert count.
        top_k: Number of expert slots to fill per token.
        indices_type: dtype of the returned ids; ``torch.long`` is used
            when this is ``None``.

    Returns:
        ``(topk_weights, topk_ids)``, both of shape
        ``(num_tokens, top_k)`` and placed on ``hidden_states.device``.
        ``topk_weights`` is float32 and filled with ones. ``topk_ids``
        holds integers in ``[0, num_experts)``; every entry is drawn
        independently, so a single token may list the same expert
        more than once.
    """
    out_shape = (hidden_states.size(0), top_k)
    expert_count = router_logits.size(-1)
    target_device = hidden_states.device
    id_dtype = torch.long if indices_type is None else indices_type

    # Upper bound is exclusive, so ids fall in [0, expert_count).
    topk_ids = torch.randint(0,
                             expert_count,
                             out_shape,
                             dtype=id_dtype,
                             device=target_device)

    # Every selected slot carries a weight of exactly one.
    topk_weights = torch.ones(out_shape,
                              dtype=torch.float32,
                              device=target_device)
    return topk_weights, topk_ids
@staticmethod
def select_experts(
hidden_states: torch.Tensor,
@ -1187,6 +1222,15 @@ class FusedMoE(torch.nn.Module):
"""
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# Uniform random topk ids for performance experiments,
# especially when using dummy weights.
if envs.VLLM_UNIFORM_RANDOM_TOPK_IDS:
return FusedMoE.uniform_random_select_experts(
hidden_states,
router_logits,
top_k,
indices_type=indices_type)
# DeepSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None