Files
vllm/benchmarks/kernels/benchmark_activation.py
Jiangyun Zhu 77aec83b8c [Benchmark] add benchmark for custom activation op (#23908)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-09-06 20:12:05 -07:00

105 lines
3.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# benchmark custom activation op performance
import itertools
import torch
import vllm.model_executor.layers.activation # noqa F401
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
def benchmark_activation(
batch_size: int,
seq_len: int,
intermediate_size: int,
provider: str,
func_name: str,
dtype: torch.dtype,
):
device = "cuda"
num_tokens = batch_size * seq_len
dim = intermediate_size
current_platform.seed_everything(42)
torch.set_default_device(device)
if func_name == "gelu_and_mul":
layer = CustomOp.op_registry[func_name](approximate="none")
elif func_name == "gelu_and_mul_tanh":
layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh")
elif func_name == "fatrelu_and_mul":
threshold = 0.5
layer = CustomOp.op_registry[func_name](threshold)
else:
layer = CustomOp.op_registry[func_name]()
x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
compiled_layer = torch.compile(layer.forward_native)
if provider == "custom":
fn = lambda: layer(x)
elif provider == "compiled":
fn = lambda: compiled_layer(x)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
fn, quantiles=[0.5, 0.2, 0.8]
)
return ms, max_ms, min_ms
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the custom activation op.")
parser.add_argument(
"--func-name",
type=str,
choices=[
"mul_and_silu",
"silu_and_mul",
"gelu_and_mul",
"gelu_and_mul_tanh",
"fatrelu_and_mul",
"swigluoai_and_mul",
"gelu_new",
"gelu_fast",
"quick_gelu",
],
default="silu_and_mul",
)
parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
)
args = parser.parse_args()
assert args
func_name = args.func_name
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
perf_report = triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "intermediate_size"],
x_vals=configs,
line_arg="provider",
line_vals=["custom", "compiled"],
line_names=["Custom OP", "Compiled"],
styles=[("blue", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"{func_name}-op-performance",
args={},
)
)
perf_report(
lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation(
batch_size, seq_len, intermediate_size, provider, func_name, dtype
)
).run(print_data=True)