# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
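
"""Batch-invariance tests for the high-level vLLM v1 LLM API.

These tests check that generation and per-step logprobs are bitwise identical
whether a request runs alone (batch size 1) or inside a larger batch, that
decode-time logprobs match what prefill produces for the same prefix, and that
disabling VLLM_BATCH_INVARIANT actually reintroduces batch-size-dependent
behavior.
"""
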
import contextlib
import os
import random

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

hopper_only = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
    reason="Requires CUDA and Hopper (SM90)",
)


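# Every test in this module is marked @hopper_only and relies on the autouse
# fixture below, so the suite only runs on Hopper (SM90) CUDA devices and with
# VLLM_BATCH_INVARIANT=1 (one test deliberately flips it back to "0").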
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
    """Automatically enable batch invariant kernel overrides for all tests."""
    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
    os.environ["VLLM_BATCH_INVARIANT"] = "1"
    yield
    # Restore original value after test
    if old_value is None:
        os.environ.pop("VLLM_BATCH_INVARIANT", None)
    else:
        os.environ["VLLM_BATCH_INVARIANT"] = old_value


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Generate more realistic prompts that will actually produce varied tokens.
    # Use a mix of common English text patterns.

    prompt_templates = [
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    ]

    # Pick a random template.
    base_prompt = random.choice(prompt_templates)

    if max_words < min_words:
        max_words = min_words
    target_words = random.randint(min_words, max_words)

    if target_words > 50:
        # For longer prompts, pad with repeated filler context (the resulting
        # word count is only a rough approximation of target_words).
        padding_text = (
            " This is an interesting topic that deserves more explanation. "
            * (target_words // 50)
        )
        base_prompt = base_prompt + padding_text

    return base_prompt


@hopper_only
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
    """
    Ensure that the same request (the 'needle' prompt) yields identical output
    whether it runs alone (bs=1) or mixed into a larger batch (e.g., bs=64),
    using only the high-level v1 LLM() API (no manual batching).

    Strategy:
    - Create two LLM engines with identical config; the first one only ever
      receives a single prompt, so it provides the bs=1 baseline.
    - Compute a baseline output for the needle prompt with that engine.
    - For many trials, generate a batch (up to max_batch_size prompts) in which
      the needle appears at a random position among random filler prompts,
      using the second engine.
    - Track how many trials match vs. mismatch, and report totals at the end.
      The test fails if any mismatches occur, but pass/fail counts are still
      printed.

    Notes:
    - Sampling uses a fixed seed, so outputs remain deterministic even when a
      higher temperature/top_p is configured via the environment (the default
      is greedy, temperature=0.0).
    - Keep max_tokens and max_model_len bounded for speed and memory use.
    """
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)

    # Allow overrides from the environment (useful for CI tuning).
    # "facebook/opt-125m" is too small and does not reliably exercise determinism.
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
    assert max_batch_size >= 2, "Batch size should be >= 2 to mix the needle in."

    # Keep GPU memory usage low to avoid startup allocation failures.
    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))

    # Sampling parameters: longer outputs with a more random-sounding
    # continuation, but still deterministic due to the fixed seed.
    temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
    top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
    max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))

    sampling = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=20240919,
    )

    needle_prompt = "There once was a "

    llm_bs1 = None
    llm_bsN = None
    try:
        # Engine used for the bs=1 baseline (it only ever receives one prompt).
        llm_bs1 = LLM_with_max_seqs(
            model=model,
            max_num_seqs=max_batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
        )

        # Baseline generation for the needle prompt alone.
        baseline_out = llm_bs1.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        assert len(baseline_out[0].outputs) >= 1
        baseline_text = baseline_out[0].outputs[0].text

        # Second engine, used for the batched (bs=N) runs.
        llm_bsN = LLM_with_max_seqs(
            model=model,
            max_num_seqs=max_batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
        )

        mismatches = 0

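        # Each trial builds a fresh random batch; only the needle's output is
        # compared against the baseline, and the filler outputs are ignored.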
        for trial in range(num_trials):
            # Build a batch of random size (between max_batch_size // 2 and
            # max_batch_size) and insert the needle at a random index.
            prompts: list[str] = []
            batch_size = random.randint(max_batch_size // 2, max_batch_size)
            needle_pos = random.randint(0, batch_size - 1)
            for i in range(batch_size):
                if i == needle_pos:
                    prompts.append(needle_prompt)
                else:
                    prompts.append(_random_prompt(min_random_prompt, max_random_prompt))

            # Generate with the larger-batch engine.
            outputs = llm_bsN.generate(prompts, sampling)
            # Find the needle output by position.
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt
            assert len(needle_output.outputs) >= 1
            text = needle_output.outputs[0].text

            if text != baseline_text:
                print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                mismatches += 1

        passes = num_trials - mismatches
        # Report how many trials passed vs. failed.
        print(
            f"[determinism] total={num_trials}, passed={passes}, "
            f"failed={mismatches}, max_batch_size={max_batch_size}"
        )

        if mismatches > 0:
            pytest.fail(
                f"Nondeterministic outputs detected: {mismatches} failed out "
                f"of {num_trials} trials (max_batch_size={max_batch_size})."
            )

    finally:
        # Ensure engines are shut down to free GPU/VRAM across test sessions.
        if llm_bs1 is not None:
            with contextlib.suppress(Exception):
                llm_bs1.shutdown()
        if llm_bsN is not None:
            with contextlib.suppress(Exception):
                llm_bsN.shutdown()


def _extract_step_logprobs(request_output):
    """Return (per-step logprobs of sampled tokens, token_ids) or (None, None)."""
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
            return t, inner.token_ids

    return None, None


@hopper_only
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
    """Check that per-step logprobs are bitwise identical for BS=1 vs. BS=N."""
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    # For batch invariance, disable custom all-reduce to ensure deterministic
    # all-reduce operations (custom all-reduce may not be deterministic).
    from vllm.model_executor.layers.batch_invariant import (
        vllm_is_batch_invariant,
    )

    disable_custom_ar = vllm_is_batch_invariant()

    if disable_custom_ar:
        print(f"\n{'=' * 80}")
        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
        print(f"{'=' * 80}\n")

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        max_num_seqs=32,
        max_model_len=8192,
        dtype="bfloat16",  # not everything is supported
    )

    # Use more realistic prompts for better token generation.
    prompts = [_random_prompt(10, 50) for _ in range(32)]

    sp = SamplingParams(
        temperature=0.6,
        top_p=1.0,
        max_tokens=8,
        seed=1234,
        logprobs=5,
    )

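    # Seeded sampling at temperature 0.6 exercises the full sampling path rather
    # than greedy argmax, and requesting logprobs here is what lets
    # _extract_step_logprobs recover a per-step value for each sampled token.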
    # BS=1: run prompts individually and collect logprobs per step.
    print("\n" + "=" * 80)
    print("STARTING BS=1 RUNS (each prompt individually)")
    print("=" * 80 + "\n")

    bs1_logprobs_per_prompt = []
    bs1_tokens_per_prompt = []
    for idx, p in enumerate(prompts):
        print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
        outs = llm.generate([p], sp, use_tqdm=False)
        assert len(outs) == 1
        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
        if step_logprobs is None:
            pytest.skip(
                "Logprobs are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )
        bs1_logprobs_per_prompt.append(step_logprobs)
        bs1_tokens_per_prompt.append(token_ids)
        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")

    # BS=N: run prompts in a batch and collect logprobs per step for each
    # prompt.
    print("\n" + "=" * 80)
    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
    print("=" * 80 + "\n")

    outs_batched = llm.generate(prompts, sp, use_tqdm=False)
    assert len(outs_batched) == len(prompts)
    bsN_logprobs_per_prompt = []
    bsN_tokens_per_prompt = []

    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
    for idx, o in enumerate(outs_batched):
        tokens = o.outputs[0].token_ids if o.outputs else "N/A"
        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
        step_logprobs, token_ids = _extract_step_logprobs(o)
        if step_logprobs is None:
            pytest.skip(
                "Logprobs are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )
        bsN_logprobs_per_prompt.append(step_logprobs)
        bsN_tokens_per_prompt.append(token_ids)

    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
    failed_prompts = []
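    # Bitwise equality is required below (torch.equal); any difference, however
    # small, counts as a batch-invariance failure for that prompt.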
    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
        zip(
            bs1_logprobs_per_prompt,
            bsN_logprobs_per_prompt,
            bs1_tokens_per_prompt,
            bsN_tokens_per_prompt,
        )
    ):
        if len(logprobs_bs1) != len(logprobs_bsN):
            reason = (
                f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
                f"vs {len(logprobs_bsN)} (BS=N)"
            )
            failed_prompts.append(
                {
                    "prompt_idx": i,
                    "step": "all",
                    "reason": reason,
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                }
            )
            continue

        # Check if tokens match first
        if tokens_bs1 != tokens_bsN:
            failed_prompts.append(
                {
                    "prompt_idx": i,
                    "step": "sampling",
                    "reason": "Different tokens sampled",
                    "prompt_preview": prompts[i][:100],
                    "bs1_tokens": tokens_bs1,
                    "bsN_tokens": tokens_bsN,
                    "bs1_all_logprobs": [
                        logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
                    ],
                    "bsN_all_logprobs": [
                        logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
                    ],
                }
            )
            continue

        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
            if a.shape != b.shape:
                failed_prompts.append(
                    {
                        "prompt_idx": i,
                        "step": t,
                        "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                    }
                )
                break

            if not torch.equal(a, b):
                max_diff = torch.abs(a - b).max().item()
                # Print which token failed
                print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
                bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                print(f" BS=1 logprob: {a.tolist()}")
                print(f" BS=N logprob: {b.tolist()}")
                failed_prompts.append(
                    {
                        "prompt_idx": i,
                        "step": t,
                        "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                        "bs1_all_logprobs": [
                            logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
                        ],
                        "bsN_all_logprobs": [
                            logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
                        ],
                    }
                )
                break

    # Print summary of all failures
    if failed_prompts:
        print(f"\n{'=' * 80}")
        fail_msg = (
            f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
            f"{len(prompts)} prompts failed"
        )
        print(fail_msg)
        print(f"{'=' * 80}")
        for fail in failed_prompts:
            print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
            print(f" Reason: {fail['reason']}")
            print(f" Preview: {fail['prompt_preview']}...")

            # Always show the tokens
            if "bs1_tokens" in fail:
                print(f" BS=1 tokens: {fail['bs1_tokens']}")
            if "bsN_tokens" in fail:
                print(f" BS=N tokens: {fail['bsN_tokens']}")

            if "bs1_all_logprobs" in fail:
                print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
                for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
                    print(f" Step {step_idx}: {logprobs}")
                print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
                for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
                    print(f" Step {step_idx}: {logprobs}")
        print(f"{'=' * 80}\n")

        # Fail the test with summary
        msg = (
            f"Batch invariance violated in {len(failed_prompts)}/"
            f"{len(prompts)} prompts. See output above for details."
        )
        pytest.fail(msg)


@hopper_only
def test_simple_generation():
    """
    Simple test that runs the model with a basic prompt and prints the output.
    Useful for quick smoke testing and debugging.
    """
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")

    llm = LLM(
        model=model,
        max_num_seqs=1,
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        enforce_eager=True,
        gpu_memory_utilization=0.9,
        max_model_len=2048,
        dtype="bfloat16",
        enable_prefix_caching=False,
    )

    prompt = "the capital of france is"
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=20,
    )

    print(f"\n{'=' * 80}")
    print("Running simple generation test")
    print(f"Prompt: '{prompt}'")
    print(f"{'=' * 80}\n")

    try:
        outputs = llm.generate([prompt], sampling_params)

        assert len(outputs) == 1
        output_text = outputs[0].outputs[0].text

        print(f"Output: '{output_text}'")
        print(f"\n{'=' * 80}")
        print(f"Full completion: '{prompt}{output_text}'")
        print(f"{'=' * 80}\n")

    finally:
        with contextlib.suppress(Exception):
            llm.shutdown()


@hopper_only
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
    """
    The inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN: it
    DISABLES batch invariance mode and expects to see non-deterministic
    behavior between BS=1 and BS=N runs, demonstrating that batch invariance
    is actually doing something useful.

    The test PASSES if we detect differences (proving batch invariance matters).
    The test FAILS if everything matches (suggesting batch invariance isn't needed).
    """
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    # CRITICAL: disable batch invariance for this test.
    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
    os.environ["VLLM_BATCH_INVARIANT"] = "0"

    try:
        seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
        random.seed(seed)
        model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
        tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

        print(f"\n{'=' * 80}")
        print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
        print(f"{'=' * 80}\n")

        llm = LLM(
            model=model_name,
            tensor_parallel_size=tp_size,
            enable_prefix_caching=False,
            max_num_seqs=32,
            max_model_len=8192,
            dtype="bfloat16",
        )

        # Build ragged prompts to change shapes significantly across BS=1 vs BS=N.
        long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
        long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
        prompts: list[str] = []
        options = [
            (max(long_min, 1536), max(long_max, 3072)),  # very long
            (max(1024, long_min), max(2048, long_max)),  # long
            (256, 512),  # mid
            (10, 20),  # short
        ]

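        # Mixing very different prompt lengths changes the shapes the kernels see
        # in the batched run, which is what typically perturbs floating-point
        # reduction order once batch-invariant kernels are disabled.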
        for _ in range(32):
            lo, hi = random.choice(options)
            prompts.append(_random_prompt(lo, hi))

        sp = SamplingParams(
            temperature=0.6,
            top_p=1.0,
            max_tokens=8,
            seed=1234,
            logprobs=5,
        )

        # BS=1: run prompts individually and collect logprobs per step.
        print("\n" + "=" * 80)
        print("STARTING BS=1 RUNS (each prompt individually)")
        print("=" * 80 + "\n")

        bs1_logprobs_per_prompt = []
        bs1_tokens_per_prompt = []
        for idx, p in enumerate(prompts):
            print(
                f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
            )
            outs = llm.generate([p], sp, use_tqdm=False)
            assert len(outs) == 1
            step_logprobs, token_ids = _extract_step_logprobs(outs[0])
            if step_logprobs is None:
                pytest.skip(
                    "Logprobs are not available on RequestOutput; "
                    "enable logprobs return to run this test."
                )
            bs1_logprobs_per_prompt.append(step_logprobs)
            bs1_tokens_per_prompt.append(token_ids)
            print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")

        # BS=N: run prompts in a batch and collect logprobs per step for each prompt.
        print("\n" + "=" * 80)
        print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
        print("=" * 80 + "\n")

        outs_batched = llm.generate(prompts, sp, use_tqdm=False)
        assert len(outs_batched) == len(prompts)
        bsN_logprobs_per_prompt = []
        bsN_tokens_per_prompt = []

        print(f"\n[BS={len(prompts)}] Processing batched outputs...")
        for idx, o in enumerate(outs_batched):
            tokens = o.outputs[0].token_ids if o.outputs else "N/A"
            print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
            step_logprobs, token_ids = _extract_step_logprobs(o)
            if step_logprobs is None:
                pytest.skip(
                    "Logprobs are not available on RequestOutput; "
                    "enable logprobs return to run this test."
                )
            bsN_logprobs_per_prompt.append(step_logprobs)
            bsN_tokens_per_prompt.append(token_ids)

        # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
        differences_found = []
        for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
            zip(
                bs1_logprobs_per_prompt,
                bsN_logprobs_per_prompt,
                bs1_tokens_per_prompt,
                bsN_tokens_per_prompt,
            )
        ):
            if len(logprobs_bs1) != len(logprobs_bsN):
                reason = (
                    f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
                    f"vs {len(logprobs_bsN)} (BS=N)"
                )
                differences_found.append(
                    {
                        "prompt_idx": i,
                        "step": "all",
                        "reason": reason,
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                    }
                )
                continue

            # Check if tokens match first
            if tokens_bs1 != tokens_bsN:
                differences_found.append(
                    {
                        "prompt_idx": i,
                        "step": "sampling",
                        "reason": "Different tokens sampled",
                        "prompt_preview": prompts[i][:100],
                        "bs1_tokens": tokens_bs1,
                        "bsN_tokens": tokens_bsN,
                    }
                )
                continue

            for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
                if a.shape != b.shape:
                    differences_found.append(
                        {
                            "prompt_idx": i,
                            "step": t,
                            "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
                            "prompt_preview": prompts[i][:100],
                            "bs1_tokens": tokens_bs1,
                            "bsN_tokens": tokens_bsN,
                        }
                    )
                    break

                if not torch.equal(a, b):
                    max_diff = torch.abs(a - b).max().item()
                    print(
                        f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
                        f"Token {t}: max_diff={max_diff:.6e}"
                    )
                    bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                    bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                    print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                    print(f" BS=1 logprob: {a.tolist()}")
                    print(f" BS=N logprob: {b.tolist()}")
                    differences_found.append(
                        {
                            "prompt_idx": i,
                            "step": t,
                            "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
                            "prompt_preview": prompts[i][:100],
                            "bs1_tokens": tokens_bs1,
                            "bsN_tokens": tokens_bsN,
                        }
                    )
                    break

        # Print summary
        print(f"\n{'=' * 80}")
        if differences_found:
            success_msg = (
                f"✓ SUCCESS: Batch invariance is doing something! "
                f"Found {len(differences_found)}/{len(prompts)} prompts "
                f"with differences when batch invariance was DISABLED."
            )
            print(success_msg)
            print(f"{'=' * 80}")
            for diff in differences_found:
                print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
                print(f" Reason: {diff['reason']}")
                print(f" Preview: {diff['prompt_preview']}...")
                if "bs1_tokens" in diff:
                    print(f" BS=1 tokens: {diff['bs1_tokens']}")
                if "bsN_tokens" in diff:
                    print(f" BS=N tokens: {diff['bsN_tokens']}")
            print(f"{'=' * 80}\n")
            # Test PASSES because we found differences (batch invariance matters!)
            return
        else:
            # Test FAILS because everything matched even without batch invariance
            fail_msg = (
                f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
                f"between BS=1 and BS=N even with batch invariance DISABLED. "
                f"This suggests batch invariance might not be necessary, "
                f"or the test needs more sensitive prompts."
            )
            print(fail_msg)
            print(f"{'=' * 80}\n")
            pytest.fail(fail_msg)

    finally:
        # Restore original value
        if old_value is None:
            os.environ.pop("VLLM_BATCH_INVARIANT", None)
        else:
            os.environ["VLLM_BATCH_INVARIANT"] = old_value


@hopper_only
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(backend):
    """
    Verify that decode logprobs match prefill logprobs.

    For each decoded token at position i:
    1. Run decode to generate N tokens and collect their logprobs.
    2. For each position i in [0, N):
       - Take prefix = prompt + tokens[0:i]
       - Run prefill on the prefix (generating a single token) to get the
         logprob of tokens[i]
       - Verify the prefill logprob matches the decode logprob bitwise

    This ensures that the logprobs from decode are consistent with what we
    would get if we ran prefill on each prefix.
    """
    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend

    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    from vllm.model_executor.layers.batch_invariant import (
        vllm_is_batch_invariant,
    )

    disable_custom_ar = vllm_is_batch_invariant()

    if disable_custom_ar:
        print(f"\n{'=' * 80}")
        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
        print(f"{'=' * 80}\n")

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        max_num_seqs=32,
        max_model_len=8192,
        dtype="bfloat16",
    )

    # Use a few test prompts.
    num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4"))
    prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)]

    # Generate longer sequences to test multiple decode steps.
    max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16"))

    sp = SamplingParams(
        temperature=0.0,  # Greedy for determinism
        max_tokens=max_tokens,
        logprobs=5,
    )

    print("\n" + "=" * 80)
    print("STEP 1: Running decode to generate tokens and collect logprobs")
    print("=" * 80 + "\n")

    # Step 1: Run decode and collect logprobs.
    decode_outputs = llm.generate(prompts, sp, use_tqdm=False)

    failed_comparisons = []

    for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)):
        print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...")

        # Extract decode logprobs and tokens.
        decode_logprobs, token_ids = _extract_step_logprobs(decode_output)
        if decode_logprobs is None:
            pytest.skip(
                "Logprobs are not available on RequestOutput; "
                "enable logprobs return to run this test."
            )

        print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}")
        print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}")

        # Step 2: For each token position, run prefill and compare.
        print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...")

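        # Greedy decoding (temperature=0.0) makes the prefix re-generation below
        # reproducible, so prompt + regenerated text is a faithful prefix for the
        # prefill comparison.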
        for token_idx in range(len(token_ids)):
            # The token whose logprob we want to verify at this position.
            current_token = token_ids[token_idx]

            # Construct the text prefix for this position. The LLM API does not
            # expose the tokenizer directly, so instead of detokenizing
            # tokens[0:token_idx] we rebuild the prefix from generated text.
            # This is approximate but sufficient for verification.
            if token_idx == 0:
                prefix_prompt = prompt
            else:
                # We don't have per-token text on the output, so re-run a greedy
                # generation with max_tokens=token_idx and append it to the
                # original prompt to recover the prefix text.
                prefix_sp = SamplingParams(
                    temperature=0.0,
                    max_tokens=token_idx,
                    logprobs=1,
                )
                prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0]
                prefix_prompt = prompt + prefix_output.outputs[0].text

            # Now run prefill with max_tokens=1 to get the logprob of the next token
            prefill_sp = SamplingParams(
                temperature=0.0,
                max_tokens=1,
                logprobs=5,
            )

            print(
                f" [Token {token_idx}] Running prefill for prefix "
                f"(len={len(prefix_prompt)})..."
            )
            prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[
                0
            ]
            prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output)

            if prefill_logprobs is None:
                print(f" [Token {token_idx}] Warning: No prefill logprobs available")
                continue

            # The first token from prefill should match the current token.
            prefill_token = prefill_token_ids[0]
            prefill_logprob = prefill_logprobs[0].item()
            decode_logprob = decode_logprobs[token_idx].item()

            print(
                f" [Token {token_idx}] Decode token: {current_token}, "
                f"logprob: {decode_logprob:.8f}"
            )
            print(
                f" [Token {token_idx}] Prefill token: {prefill_token}, "
                f"logprob: {prefill_logprob:.8f}"
            )

            # Check if the tokens match.
            if current_token != prefill_token:
                failed_comparisons.append(
                    {
                        "prompt_idx": prompt_idx,
                        "token_idx": token_idx,
                        "reason": "Token mismatch",
                        "decode_token": current_token,
                        "prefill_token": prefill_token,
                        "decode_logprob": decode_logprob,
                        "prefill_logprob": prefill_logprob,
                        "prompt_text": prompt[:100],
                        "prefix_text": prefix_prompt[:100],
                    }
                )
                print(f" [Token {token_idx}] ✗ TOKEN MISMATCH!")
                continue

            # Check if the logprobs match bitwise.
            if decode_logprob != prefill_logprob:
                diff = abs(decode_logprob - prefill_logprob)
                failed_comparisons.append(
                    {
                        "prompt_idx": prompt_idx,
                        "token_idx": token_idx,
                        "reason": "Logprob mismatch",
                        "decode_token": current_token,
                        "prefill_token": prefill_token,
                        "decode_logprob": decode_logprob,
                        "prefill_logprob": prefill_logprob,
                        "diff": diff,
                        "prompt_text": prompt[:100],
                        "prefix_text": prefix_prompt[:100],
                        "decode_all_tokens": token_ids,
                        "decode_all_logprobs": decode_logprobs.tolist(),
                    }
                )
                print(f" [Token {token_idx}] ✗ LOGPROB MISMATCH! diff={diff:.8e}")
            else:
                print(f" [Token {token_idx}] ✓ Match (bitwise equal)")

    # Print summary
    print(f"\n{'=' * 80}")
    if failed_comparisons:
        print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected")
        print(f"{'=' * 80}")

        # Group failures by prompt for better readability.
        failures_by_prompt: dict[int, list[dict]] = {}
        for fail in failed_comparisons:
            pid = fail["prompt_idx"]
            if pid not in failures_by_prompt:
                failures_by_prompt[pid] = []
            failures_by_prompt[pid].append(fail)

        for prompt_idx, failures in failures_by_prompt.items():
            print(f"\n{'=' * 80}")
            print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...")
            print(f"{'=' * 80}")
            print(f"Total failures for this prompt: {len(failures)}")

            # Show where mismatches occur (which token positions).
            mismatch_positions = [f["token_idx"] for f in failures]
            print(f"Mismatch at token positions: {mismatch_positions}")

            # Show the first few failures in detail.
            for i, fail in enumerate(failures[:5]):  # Show first 5 failures per prompt
                print(f"\n [Failure {i + 1}] Token position {fail['token_idx']}:")
                print(f" Reason: {fail['reason']}")
                print(f" Prefix text: '{fail['prefix_text']}...'")
                print(
                    f" Decode: token={fail['decode_token']}, "
                    f"logprob={fail['decode_logprob']:.10f}"
                )
                print(
                    f" Prefill: token={fail['prefill_token']}, "
                    f"logprob={fail['prefill_logprob']:.10f}"
                )
                if "diff" in fail:
                    print(f" Difference: {fail['diff']:.10e}")
                    # Show in hex to see the bitwise difference.
                    import struct

                    decode_hex = struct.pack("f", fail["decode_logprob"]).hex()
                    prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex()
                    print(f" Decode logprob (hex): 0x{decode_hex}")
                    print(f" Prefill logprob (hex): 0x{prefill_hex}")

                # If we have all tokens/logprobs, show the context.
                if "decode_all_tokens" in fail and "decode_all_logprobs" in fail:
                    token_idx = fail["token_idx"]
                    all_tokens = fail["decode_all_tokens"]
                    all_logprobs = fail["decode_all_logprobs"]

                    # Show context: 2 tokens before and after.
                    start = max(0, token_idx - 2)
                    end = min(len(all_tokens), token_idx + 3)

                    print(f" Context (tokens {start} to {end - 1}):")
                    for j in range(start, end):
                        marker = " <-- MISMATCH" if j == token_idx else ""
                        print(
                            f" [{j}] token={all_tokens[j]}, "
                            f"logprob={all_logprobs[j]:.8f}{marker}"
                        )

            if len(failures) > 5:
                print(f"\n ... and {len(failures) - 5} more failures for this prompt")

        print(f"\n{'=' * 80}\n")

        pytest.fail(
            f"Decode logprobs do not match prefill logprobs: "
            f"{len(failed_comparisons)} mismatches found."
        )
    else:
        print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!")
        print(f"{'=' * 80}\n")


def LLM_with_max_seqs(
    model: str,
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
) -> LLM:
    """
    Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
    using the high-level v1 LLM API, while constraining memory usage.
    """
    return LLM(
        model=model,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
        dtype="bfloat16",
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        enable_prefix_caching=False,
        enforce_eager=True,
        # Enable for MoE models:
        # enable_expert_parallel=True,
    )