[Feature] Batch Invariant: Support DeepGEMM and Blackwell (#27127)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -10,9 +10,9 @@ import torch
 from vllm import LLM, SamplingParams
 from vllm.platforms import current_platform
 
-hopper_only = pytest.mark.skipif(
-    not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
-    reason="Requires CUDA and Hopper (SM90)",
+skip_unsupported = pytest.mark.skipif(
+    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
+    reason="Requires CUDA and >= Hopper (SM90)",
 )
 
 
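The change above replaces the exact-match gate is_device_capability(90) with the minimum-capability gate has_device_capability(90), so these tests also run on newer GPUs such as Blackwell. A minimal sketch of how the relaxed marker behaves, assuming only the current_platform helpers already imported in this file:

import pytest

from vllm.platforms import current_platform

# is_device_capability(90) matches SM90 only, so Blackwell (SM100) used to be skipped;
# has_device_capability(90) is a ">= 9.0" check, so Hopper and Blackwell both qualify.
skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
    reason="Requires CUDA and >= Hopper (SM90)",
)


@skip_unsupported
def test_capability_gate_sketch():
    # Collected and run only on SM90-or-newer CUDA devices.
    assert current_platform.has_device_capability(90)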
@@ -74,7 +74,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
     return base_prompt
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.timeout(1000)
 def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     """
@@ -219,7 +219,7 @@ def _extract_step_logprobs(request_output):
     return None, None
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
 @pytest.mark.forked
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
@@ -434,7 +434,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
         pytest.fail(msg)
 
 
-@hopper_only
+@skip_unsupported
 def test_simple_generation():
     """
     Simple test that runs the model with a basic prompt and prints the output.
@@ -480,7 +480,7 @@ def test_simple_generation():
         llm.shutdown()
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
 @pytest.mark.forked
 def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
@@ -707,7 +707,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
         os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.forked
 def test_decode_logprobs_match_prefill_logprobs(backend):
@@ -14,13 +14,13 @@ from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.platforms import current_platform
 
-hopper_only = pytest.mark.skipif(
-    not (current_platform.is_cuda() and current_platform.is_device_capability(90)),
-    reason="Requires CUDA and Hopper (SM90)",
+skip_unsupported = pytest.mark.skipif(
+    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
+    reason="Requires CUDA and >= Hopper (SM90)",
 )
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
 @pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -69,7 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
     )
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("batch_size", [1, 16, 128])
 @pytest.mark.parametrize("seq_len", [1, 32, 512])
 @pytest.mark.parametrize("hidden_size", [2048, 4096])
@@ -111,7 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
     )
 
 
-@hopper_only
+@skip_unsupported
 def test_rms_norm_numerical_stability():
     """
     Test RMS norm numerical stability with extreme values.
@@ -171,7 +171,7 @@ def test_rms_norm_numerical_stability():
     )
 
 
-@hopper_only
+@skip_unsupported
 def test_rms_norm_formula():
     """
     Test that RMS norm follows the correct mathematical formula.
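For reference, test_rms_norm_formula above checks the standard RMS norm definition; a plain-torch sketch of that reference formula (the eps value and the shapes are illustrative assumptions, not taken from this diff):

import torch


def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm(x) = weight * x / sqrt(mean(x^2, dim=-1) + eps)
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return weight * (x / rms)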
@@ -204,7 +204,7 @@ def test_rms_norm_formula():
     )
 
 
-@hopper_only
+@skip_unsupported
 @pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
 def test_rms_norm_different_hidden_sizes(hidden_size: int):
     """
@@ -242,7 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
     )
 
 
-@hopper_only
+@skip_unsupported
 def test_rms_norm_determinism():
     """
     Test that batch-invariant RMS norm produces deterministic results.
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
@@ -94,9 +95,11 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import (
+    fp8_gemm_nt,
     get_col_major_tma_aligned_tensor,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
+    should_use_deepgemm_for_fp8_linear,
 )
 from vllm.utils.flashinfer import has_flashinfer_moe
 
@@ -539,8 +542,34 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        # If batch invariant mode is enabled, dequantize and use BF16 compute
+        # If batch invariant mode is enabled, prefer the DeepGEMM FP8 path;
+        # fall back to BF16 dequant when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
+            if self.block_quant and should_use_deepgemm_for_fp8_linear(
+                torch.bfloat16, layer.weight, None
+            ):
+                # Use group quantization consistent with the block size along K.
+                assert self.act_q_group_shape is not None
+                q_input, input_scale = QuantFP8(
+                    False,
+                    self.act_q_group_shape,
+                    column_major_scales=True,
+                )(x)
+
+                output_2d = torch.empty(
+                    (q_input.shape[0], layer.weight.shape[0]),
+                    dtype=torch.bfloat16,
+                    device=q_input.device,
+                )
+                fp8_gemm_nt(
+                    (q_input, input_scale),
+                    (layer.weight, layer.weight_scale),
+                    output_2d,
+                )
+                if bias is not None:
+                    output_2d = output_2d + bias
+                return output_2d
+
             # Dequantize FP8 weights to BF16
             weight_fp8 = layer.weight.to(torch.bfloat16)
             weight_scale = layer.weight_scale.to(torch.bfloat16)
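In outline, the new branch group-quantizes the activation with QuantFP8, runs the DeepGEMM FP8 GEMM into a preallocated BF16 buffer, and adds the bias afterwards. A simplified sketch of that data flow; the callables and tensors here are stand-ins passed in by the caller, not vLLM attributes:

import torch


def batch_invariant_fp8_linear_sketch(x, weight, weight_scale, quant_act, fp8_gemm_nt, bias=None):
    # Group-quantize the activation; quant_act stands in for the QuantFP8 call above.
    q_input, input_scale = quant_act(x)
    # The GEMM writes into a caller-allocated BF16 output of shape [num_tokens, N].
    out = torch.empty(
        (q_input.shape[0], weight.shape[0]), dtype=torch.bfloat16, device=q_input.device
    )
    fp8_gemm_nt((q_input, input_scale), (weight, weight_scale), out)
    # Bias is added after the GEMM, as in the hunk above.
    if bias is not None:
        out = out + bias
    return out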
@@ -555,9 +584,30 @@
 
                 N, K = weight_fp8.shape
 
-                # Scale is stored transposed: [num_blocks_k, num_blocks_n]
-                # We need to transpose it to [num_blocks_n, num_blocks_k] first
-                weight_scale = weight_scale.t()
+                # Determine the expected number of blocks along N and K.
+                num_blocks_n = (N + block_n - 1) // block_n
+                num_blocks_k = (K + block_k - 1) // block_k
+
+                # The scale layout may be [num_blocks_n, num_blocks_k]
+                # or [num_blocks_k, num_blocks_n] depending on the backend.
+                if weight_scale.dim() != 2:
+                    raise RuntimeError(
+                        f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}"
+                    )
+
+                scale_rows, scale_cols = weight_scale.shape
+                if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n):
+                    if num_blocks_n == num_blocks_k:
+                        # Ambiguous square case: warn and skip the transpose.
+                        logger.warning(
+                            "Batch-invariant FP8: square block-scale %dx%d; "
+                            "skipping transpose to avoid misorientation.",
+                            scale_rows,
+                            scale_cols,
+                        )
+                    else:
+                        # Unambiguously [K-blocks, N-blocks]: transpose to [N-blocks, K-blocks].
+                        weight_scale = weight_scale.t()
 
                 # Expand scale to match weight dimensions
                 # scale_expanded should have shape [N, K]
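The scale-orientation handling above boils down to: compute the expected block counts along N and K with ceiling division, require a 2D scale, and transpose only when the shape unambiguously indicates a [K-blocks, N-blocks] layout (the square case is left untouched with a warning). A small illustrative helper under those assumptions; the function name is not a vLLM API:

import torch


def orient_block_scale(weight_scale: torch.Tensor, N: int, K: int, block_n: int, block_k: int) -> torch.Tensor:
    # Ceiling-division block counts, as in the hunk above.
    num_blocks_n = (N + block_n - 1) // block_n
    num_blocks_k = (K + block_k - 1) // block_k
    if weight_scale.dim() != 2:
        raise RuntimeError(f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}")
    if weight_scale.shape == (num_blocks_k, num_blocks_n) and num_blocks_n != num_blocks_k:
        # Unambiguously [K-blocks, N-blocks]: transpose to [N-blocks, K-blocks].
        weight_scale = weight_scale.t()
    return weight_scale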