mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-26 10:54:33 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			101 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # SPDX-License-Identifier: Apache-2.0
 | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | |
| 
 | |
| # Cutlass bench utils
 | |
| from collections.abc import Iterable
 | |
| 
 | |
| import torch
 | |
| 
 | |
| import vllm._custom_ops as ops
 | |
| 
 | |
| 
 | |
| def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
 | |
|     finfo = torch.finfo(torch.float8_e4m3fn)
 | |
|     return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
 | |
|         dtype=torch.float8_e4m3fn
 | |
|     )
 | |
| 
 | |
| 
 | |
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Convert ``tensor`` to ``torch.int8``.

    Values are clamped to ``[-128, 127]``, rounded with ``torch.round`` and
    then cast.

    Args:
        tensor: Input tensor to convert.

    Returns:
        A new tensor of dtype ``torch.int8``.
    """
    clipped = torch.clamp(tensor, min=-128, max=127)
    return clipped.round().to(dtype=torch.int8)
 | |
| 
 | |
| 
 | |
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` cast to ``torch.bfloat16``."""
    return tensor.bfloat16()
 | |
| 
 | |
| 
 | |
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` cast to ``torch.float16``."""
    return tensor.half()
 | |
| 
 | |
| 
 | |
def make_rand_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create a random pair of quantized GEMM operands on the CUDA device.

    Args:
        dtype: Target dtype; ``torch.int8`` or ``torch.float8_e4m3fn``.
        m: Number of rows of the first operand.
        n: Number of columns of the second operand.
        k: Shared inner dimension.

    Returns:
        ``(a, b)`` where ``a`` has shape ``(m, k)`` and ``b`` has shape
        ``(k, n)``; ``b`` is produced by transposing an ``(n, k)`` tensor.

    Raises:
        ValueError: If ``dtype`` is not one of the supported dtypes.
    """
    # Scale the standard-normal samples by 5 before quantizing.
    lhs = torch.randn((m, k), device="cuda") * 5
    rhs = torch.randn((n, k), device="cuda").t() * 5

    if dtype == torch.int8:
        convert = to_int8
    elif dtype == torch.float8_e4m3fn:
        convert = to_fp8
    else:
        raise ValueError("unsupported dtype")

    return convert(lhs), convert(rhs)
 | |
| 
 | |
| 
 | |
def prune_to_2_4(tensor):
    """Zero all but the two largest-magnitude values in each group of four.

    The tensor is viewed as consecutive groups of 4 elements
    (``reshape(-1, 4)``); within each group the 2 entries with the largest
    absolute value are kept and the other 2 are set to ``0.0``. The input
    tensor is not modified.

    Args:
        tensor: Tensor whose number of elements can be reshaped to ``(-1, 4)``.

    Returns:
        A new tensor with the same shape as ``tensor`` holding the pruned
        values, with every zero stored as positive ``0.0``.
    """
    shape = tensor.shape
    groups = tensor.reshape(-1, 4)

    # Positions of the two largest magnitudes inside every group.
    keep_idx = torch.abs(groups).topk(k=2, dim=1).indices

    # 0/1 mask with ones at the kept positions.
    keep_mask = torch.zeros_like(groups)
    keep_mask.scatter_(1, keep_idx, 1)

    kept = groups * keep_mask

    # Multiplying a negative value by 0 yields -0.0; the comparison matches
    # both signed zeros, so every zero is rewritten as +0.0.
    kept.masked_fill_(kept == 0, 0.0)

    return kept.reshape(shape)
 | |
| 
 | |
| 
 | |
def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random operands for a 2:4-sparse GEMM on the CUDA device.

    ``b`` is pruned with ``prune_to_2_4`` (applied to its ``(n, k)``
    transpose, so groups of four run along ``k``) before both operands are
    converted to ``dtype``. The transposed ``b`` is then passed to
    ``ops.cutlass_sparse_compress``.

    Args:
        dtype: Target dtype; one of ``torch.int8``, ``torch.float8_e4m3fn``,
            ``torch.float16`` or ``torch.bfloat16``.
        m: Number of rows of ``a``.
        n: Number of columns of ``b``.
        k: Shared inner dimension.

    Returns:
        ``(b_compressed, e, a, b)``: the two outputs of
        ``ops.cutlass_sparse_compress`` (compressed B and its metadata),
        followed by the original ``a`` of shape ``(m, k)`` and the pruned
        ``b`` of shape ``(k, n)``.

    Raises:
        ValueError: If ``dtype`` is not one of the supported dtypes.
    """
    a = torch.randn((m, k), device="cuda") * 5
    b = torch.randn((n, k), device="cuda").t() * 5

    # Prune in (n, k) layout, then restore the (k, n) orientation.
    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b
 | |
| 
 | |
| 
 | |
def make_n_rand_sparse_tensors(
    num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[
    list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]
]:
    """Generate ``num_tensors`` independent sets of sparse GEMM operands.

    Each set is produced by one call to ``make_rand_sparse_tensors``; a set
    whose compressed tensor is ``None`` is skipped.

    Args:
        num_tensors: Number of operand sets to generate.
        dtype: Target dtype forwarded to ``make_rand_sparse_tensors``.
        m: Number of rows of each ``a``.
        n: Number of columns of each ``b``.
        k: Shared inner dimension.

    Returns:
        Four lists ``(b_compressed, metadata, a, b)`` of equal length, one
        entry per retained set. All four lists are empty when no set was
        generated (e.g. ``num_tensors == 0``).
    """
    samples = []
    for _ in range(num_tensors):
        # Generate once and keep exactly the sample that was checked.
        sample = make_rand_sparse_tensors(dtype, m, n, k)
        if sample[0] is not None:
            samples.append(sample)

    # zip(*[]) yields nothing, so unpacking it into four names would raise.
    if not samples:
        return [], [], [], []

    b_comps, metas, a_list, b_list = zip(*samples)
    return list(b_comps), list(metas), list(a_list), list(b_list)
 |