# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable

import torch

from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton


# TODO(luka): use standalone_compile utility
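# Wraps a callable so that dimension `dim_index` of positional argument
# `arg_index` is marked dynamic via torch._dynamo.mark_dynamic before every call.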
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
    def inner(*args):
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args)

    return inner


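# dynamo's default recompile limit is far too low for the shape sweep below, so
# raise it. custom_ops=["none"] keeps QuantFP8 on its torch-native implementation,
# so torch.compile traces plain PyTorch ops rather than a vLLM custom kernel.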
torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    torch_per_token_quant_fp8 = torch.compile(
        QuantFP8(False, GroupShape.PER_TOKEN),
        fullgraph=True,
        dynamic=False,  # recompile for different shapes
    )

    # First dim is explicitly dynamic to simulate vLLM usage
    torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)


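# Baseline: vLLM's fused CUDA quantization kernel. With no scale argument,
# scaled_fp8_quant computes the scale dynamically and returns (quantized, scale).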
def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input)


def calculate_diff(batch_size: int, seq_len: int):
    """Calculate the difference between the torch.compile and CUDA implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    if torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


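# Shape sweep: each config quantizes a (batch_size * seq_len, 4096) tensor.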
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range))


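# perf_report runs benchmark_quantization once per (batch_size, seq_len) config and
# provider, timing each call under CUDA graphs; timings are reported in microseconds.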
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "cuda":
        fn = lambda: cuda_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark_quantization.run(print_data=True)