mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			257 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			257 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # SPDX-License-Identifier: Apache-2.0
 | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | |
| 
 | |
| import itertools
 | |
| from typing import Optional, Union
 | |
| 
 | |
| import torch
 | |
| from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 | |
| from torch import nn
 | |
| 
 | |
| from vllm import _custom_ops as vllm_ops
 | |
| from vllm.triton_utils import triton
 | |
| 
 | |
| 
 | |
| class HuggingFaceRMSNorm(nn.Module):
 | |
|     def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
 | |
|         super().__init__()
 | |
|         self.weight = nn.Parameter(torch.ones(hidden_size))
 | |
|         self.variance_epsilon = eps
 | |
| 
 | |
|     def forward(
 | |
|         self,
 | |
|         x: torch.Tensor,
 | |
|         residual: Optional[torch.Tensor] = None,
 | |
|     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
 | |
|         orig_dtype = x.dtype
 | |
|         x = x.to(torch.float32)
 | |
|         if residual is not None:
 | |
|             x = x + residual.to(torch.float32)
 | |
|             residual = x.to(orig_dtype)
 | |
| 
 | |
|         variance = x.pow(2).mean(dim=-1, keepdim=True)
 | |
|         x = x * torch.rsqrt(variance + self.variance_epsilon)
 | |
|         x = x.to(orig_dtype) * self.weight
 | |
|         if residual is None:
 | |
|             return x
 | |
|         else:
 | |
|             return x, residual
 | |
| 
 | |
| 
 | |
| def rmsnorm_naive(
 | |
|     x: torch.Tensor,
 | |
|     weight: torch.Tensor,
 | |
|     residual: Optional[torch.Tensor] = None,
 | |
|     eps: float = 1e-6,
 | |
| ):
 | |
|     naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
 | |
|     naive_norm.weight = nn.Parameter(weight)
 | |
|     naive_norm = naive_norm.to(x.device)
 | |
| 
 | |
|     orig_shape = x.shape
 | |
|     x = x.view(-1, x.shape[-1])
 | |
|     if residual is not None:
 | |
|         residual = residual.view(-1, residual.shape[-1])
 | |
| 
 | |
|     output = naive_norm(x, residual)
 | |
| 
 | |
|     if isinstance(output, tuple):
 | |
|         output = (output[0].view(orig_shape), output[1].view(orig_shape))
 | |
|     else:
 | |
|         output = output.view(orig_shape)
 | |
|     return output
 | |
| 
 | |
| 
 | |
| def rmsnorm_flashinfer(
 | |
|     x: torch.Tensor,
 | |
|     weight: torch.Tensor,
 | |
|     residual: Optional[torch.Tensor] = None,
 | |
|     eps: float = 1e-6,
 | |
| ):
 | |
|     orig_shape = x.shape
 | |
|     x = x.view(-1, x.shape[-1])
 | |
|     if residual is not None:
 | |
|         residual = residual.view(-1, residual.shape[-1])
 | |
| 
 | |
|     if residual is not None:
 | |
|         fused_add_rmsnorm(x, residual, weight, eps)
 | |
|         output = (x, residual)
 | |
|     else:
 | |
|         output = rmsnorm(x, weight, eps)
 | |
| 
 | |
|     if isinstance(output, tuple):
 | |
|         output = (output[0].view(orig_shape), output[1].view(orig_shape))
 | |
|     else:
 | |
|         output = output.view(orig_shape)
 | |
|     return output
 | |
| 
 | |
| 
 | |
| def rmsnorm_vllm(
 | |
|     x: torch.Tensor,
 | |
|     weight: torch.Tensor,
 | |
|     residual: Optional[torch.Tensor] = None,
 | |
|     eps: float = 1e-6,
 | |
| ):
 | |
|     orig_shape = x.shape
 | |
|     x = x.view(-1, x.shape[-1])
 | |
|     if residual is not None:
 | |
|         residual = residual.view(-1, residual.shape[-1])
 | |
| 
 | |
|     if residual is not None:
 | |
|         vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
 | |
|         output = (x, residual)
 | |
|     else:
 | |
|         out = torch.empty_like(x)
 | |
|         vllm_ops.rms_norm(out, x, weight, eps)
 | |
|         output = out
 | |
| 
 | |
|     if isinstance(output, tuple):
 | |
|         output = (output[0].view(orig_shape), output[1].view(orig_shape))
 | |
|     else:
 | |
|         output = output.view(orig_shape)
 | |
|     return output
 | |
| 
 | |
| 
 | |
| def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
 | |
|     dtype = torch.bfloat16
 | |
|     x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
 | |
|     weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
 | |
|     residual = torch.randn_like(x) if use_residual else None
 | |
| 
 | |
|     output_naive = rmsnorm_naive(
 | |
|         x.clone(), weight, residual.clone() if residual is not None else None
 | |
|     )
 | |
|     output_flashinfer = rmsnorm_flashinfer(
 | |
|         x.clone(), weight, residual.clone() if residual is not None else None
 | |
|     )
 | |
|     output_vllm = rmsnorm_vllm(
 | |
|         x.clone(), weight, residual.clone() if residual is not None else None
 | |
|     )
 | |
| 
 | |
|     if use_residual:
 | |
|         output_naive = output_naive[0]
 | |
|         output_flashinfer = output_flashinfer[0]
 | |
|         output_vllm = output_vllm[0]
 | |
| 
 | |
|     print(f"Naive output={output_naive}")
 | |
|     print(f"FlashInfer output={output_flashinfer}")
 | |
|     print(f"vLLM output={output_vllm}")
 | |
| 
 | |
|     if torch.allclose(
 | |
|         output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
 | |
|     ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
 | |
|         print("✅ All implementations match")
 | |
|     else:
 | |
|         print("❌ Implementations differ")
 | |
| 
 | |
| 
 | |
| batch_size_range = [2**i for i in range(0, 7, 2)]
 | |
| seq_length_range = [2**i for i in range(6, 11, 1)]
 | |
| head_num_range = [32, 48]
 | |
| configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
 | |
| 
 | |
| 
 | |
| def get_benchmark(use_residual):
 | |
|     @triton.testing.perf_report(
 | |
|         triton.testing.Benchmark(
 | |
|             x_names=["head_num", "batch_size", "seq_len"],
 | |
|             x_vals=[list(_) for _ in configs],
 | |
|             line_arg="provider",
 | |
|             line_vals=["huggingface", "flashinfer", "vllm"],
 | |
|             line_names=["HuggingFace", "FlashInfer", "vLLM"],
 | |
|             styles=[("blue", "-"), ("green", "-"), ("red", "-")],
 | |
|             ylabel="us",
 | |
|             plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
 | |
|             args={},
 | |
|         )
 | |
|     )
 | |
|     def benchmark(head_num, batch_size, seq_len, provider):
 | |
|         dtype = torch.bfloat16
 | |
|         hidden_size = head_num * 128  # assuming head_dim = 128
 | |
| 
 | |
|         x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
 | |
|         weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
 | |
|         residual = torch.randn_like(x) if use_residual else None
 | |
| 
 | |
|         quantiles = [0.5, 0.2, 0.8]
 | |
| 
 | |
|         if provider == "huggingface":
 | |
|             ms, min_ms, max_ms = triton.testing.do_bench(
 | |
|                 lambda: rmsnorm_naive(
 | |
|                     x.clone(),
 | |
|                     weight,
 | |
|                     residual.clone() if residual is not None else None,
 | |
|                 ),
 | |
|                 quantiles=quantiles,
 | |
|             )
 | |
|         elif provider == "flashinfer":
 | |
|             ms, min_ms, max_ms = triton.testing.do_bench(
 | |
|                 lambda: rmsnorm_flashinfer(
 | |
|                     x.clone(),
 | |
|                     weight,
 | |
|                     residual.clone() if residual is not None else None,
 | |
|                 ),
 | |
|                 quantiles=quantiles,
 | |
|             )
 | |
|         else:
 | |
|             ms, min_ms, max_ms = triton.testing.do_bench(
 | |
|                 lambda: rmsnorm_vllm(
 | |
|                     x.clone(),
 | |
|                     weight,
 | |
|                     residual.clone() if residual is not None else None,
 | |
|                 ),
 | |
|                 quantiles=quantiles,
 | |
|             )
 | |
| 
 | |
|         return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 | |
| 
 | |
|     return benchmark
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     import argparse
 | |
| 
 | |
|     parser = argparse.ArgumentParser()
 | |
|     parser.add_argument(
 | |
|         "--batch-size",
 | |
|         type=int,
 | |
|         default=4,
 | |
|         help="Batch size",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--seq-len",
 | |
|         type=int,
 | |
|         default=128,
 | |
|         help="Sequence length",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--hidden-size",
 | |
|         type=int,
 | |
|         default=4096,
 | |
|         help="Hidden size (2nd dimension) of the sequence",
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--use-residual", action="store_true", help="Whether to use residual connection"
 | |
|     )
 | |
|     parser.add_argument(
 | |
|         "--save-path",
 | |
|         type=str,
 | |
|         default="./configs/rmsnorm/",
 | |
|         help="Path to save rmsnorm benchmark results",
 | |
|     )
 | |
| 
 | |
|     args = parser.parse_args()
 | |
| 
 | |
|     # Run correctness test
 | |
|     calculate_diff(
 | |
|         batch_size=args.batch_size,
 | |
|         seq_len=args.seq_len,
 | |
|         hidden_size=args.hidden_size,
 | |
|         use_residual=args.use_residual,
 | |
|     )
 | |
| 
 | |
|     # Get the benchmark function with proper use_residual setting
 | |
|     benchmark = get_benchmark(args.use_residual)
 | |
|     # Run performance benchmark
 | |
|     benchmark.run(print_data=True, save_path=args.save_path)
 |