# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
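
# Benchmark for the PolyNorm kernel: compares a pure-PyTorch reference
# implementation against the fused vLLM CUDA op (vllm_ops.poly_norm).
# Runs a quick correctness check, then a Triton perf report over a sweep of
# (dim, batch_size, seq_len) configurations.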

import itertools

import torch

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


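# Reference PolyNorm in plain PyTorch: a weighted sum of RMS-normalized
# x**3, x**2, and x, plus a learned bias, computed in float32 for accuracy.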
def polynorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])

    def norm(x, eps: float):
        # RMS normalization over the last dimension.
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    x = x.float()
    return (
        (
            weight[0] * norm(x**3, eps)
            + weight[1] * norm(x**2, eps)
            + weight[2] * norm(x, eps)
            + bias
        )
        .to(weight.dtype)
        .view(orig_shape)
    )


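# Fused PolyNorm via the vLLM custom op; writes into a preallocated output tensor.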
def polynorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])

    out = torch.empty_like(x)
    vllm_ops.poly_norm(out, x, weight, bias, eps)
    output = out

    output = output.view(orig_shape)
    return output


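# Correctness check: compare the naive and fused outputs for a single shape.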
def calculate_diff(batch_size, seq_len, hidden_dim):
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
    weight = torch.ones(3, dtype=dtype, device="cuda")
    bias = torch.ones(1, dtype=dtype, device="cuda")

    output_naive = polynorm_naive(x, weight, bias)
    output_vllm = polynorm_vllm(x, weight, bias)

    if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


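# Benchmark sweep: batch sizes {1, 4, 16, 64}, sequence lengths 64-1024, and
# base dims {2048, 4096} (the benchmarked hidden dim is dim * 4).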
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))


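# Build a triton.testing perf report that times both implementations across
# every configuration in `configs`.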
def get_benchmark():
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["dim", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["naive", "vllm"],
            line_names=["Naive", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name="polynorm-perf",
            args={},
        )
    )
    def benchmark(dim, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        hidden_dim = dim * 4

        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
        weight = torch.ones(3, dtype=dtype, device="cuda")
        bias = torch.ones(1, dtype=dtype, device="cuda")

        quantiles = [0.5, 0.2, 0.8]

        if provider == "naive":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: polynorm_naive(x, weight, bias),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: polynorm_vllm(x, weight, bias),
                quantiles=quantiles,
            )

        # do_bench reports milliseconds; convert to microseconds to match the ylabel.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


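# CLI entry point: run the correctness check for one shape, then the full sweep.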
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=8192,
        help="Intermediate size of MLP",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/polnorm/",
        help="Path to save polynorm benchmark results",
    )

    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_dim=args.hidden_dim,
    )

    benchmark = get_benchmark()
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)