mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	Signed-off-by: Didier Durand <durand.didier@gmail.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
		
			
				
	
	
		
			1066 lines
		
	
	
		
			35 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1066 lines
		
	
	
		
			35 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # SPDX-License-Identifier: Apache-2.0
 | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | |
| 
 | |
| import argparse
 | |
| import copy
 | |
| import json
 | |
| import pickle
 | |
| import time
 | |
| from dataclasses import dataclass
 | |
| from enum import Enum, auto
 | |
| from itertools import product
 | |
| from pathlib import Path
 | |
| from typing import Any, Callable, Optional
 | |
| 
 | |
| import torch
 | |
| import torch.utils.benchmark as TBenchmark
 | |
| from torch.utils.benchmark import Measurement as TMeasurement
 | |
| from utils import ArgPool, Bench, CudaGraphBenchParams
 | |
| from weight_shapes import WEIGHT_SHAPES
 | |
| 
 | |
| from vllm.triton_utils import HAS_TRITON
 | |
| 
 | |
| if HAS_TRITON:
 | |
|     from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
 | |
|     from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
 | |
| 
 | |
| from vllm.utils import FlexibleArgumentParser
 | |
| 
 | |
# Default sweep values used when the corresponding option is not supplied.
DEFAULT_MODELS = [*WEIGHT_SHAPES.keys()]
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1,
    16,
    32,
    *range(64, 512 + 1, 64),  # 64, 128, ..., 512
    *range(640, 1024 + 1, 128),  # 640, 768, 896, 1024
    *range(2048, 8192 + 1, 1024),  # 2048, 3072, ..., 8192
]
# Powers of two from 1024 to 16384.
DEFAULT_HIDDEN_SIZES = [2**exponent for exponent in range(10, 15)]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = list(range(1, 5))
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
 | |
| 
 | |
| 
 | |
| # Utilities
 | |
def dtype_to_str(dtype: torch.dtype):
    """
    Return the short tag ("f16", "bf16" or "f32") for a floating point dtype.

    Raises ValueError for any other dtype.
    """
    short_names = (
        (torch.float16, "f16"),
        (torch.bfloat16, "bf16"),
        (torch.float32, "f32"),
    )
    for known_dtype, tag in short_names:
        if dtype == known_dtype:
            return tag
    raise ValueError(f"Unsupported dtype {dtype}")
 | |
| 
 | |
| 
 | |
def make_rand_lora_weight_tensor(
    k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
) -> torch.Tensor:
    """
    Build a random LoRA weight tensor of shape (num_loras, n, k) and move it
    to `device`.
    """
    # LoRA weights are stored column major, hence n before k.
    weight_shape = (num_loras, n, k)
    host_weights = torch.rand(weight_shape, dtype=dtype)
    return host_weights.to(device)
 | |
| 
 | |
| 
 | |
def make_rand_tensors(
    a_shape: tuple[int],
    b_shape: tuple[int],
    c_shape: tuple[int],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
    """
    Create the matrices for a LoRA matmul A x B = C.

    Returns a random input A, a list of `num_slices` random weight tensors
    (one per slice, all of shape `b_shape`) and a zero-filled output C, each
    moved to `device`.
    """
    input_matrix = torch.rand(a_shape, dtype=a_dtype).to(device)

    # LoRA weights column major
    weight_slices = []
    for _ in range(num_slices):
        weight_slices.append(torch.rand(b_shape, dtype=b_dtype).to(device))

    output_matrix = torch.zeros(c_shape, dtype=c_dtype).to(device)
    return input_matrix, weight_slices, output_matrix
 | |
| 
 | |
| 
 | |
def make_prompt_lora_mapping(
    num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
) -> torch.Tensor:
    """
    All prompts are mapped to a LoRA ID in range [0, num_active_loras).
    where 0 refers to first lora, 1 refers to second lora and so on.

    Args:
        num_prompts: number of prompts, i.e. the length of the result.
        num_active_loras: number of distinct LoRA IDs to draw from; must be > 0.
        sort_by_lora_id: when False, IDs are drawn uniformly at random. When
            True, prompts are split into contiguous, in-order groups of
            ``max(num_prompts // num_active_loras, 1)`` prompts per LoRA, and
            any leftover prompts are assigned to the last LoRA.
        device: device on which the returned tensor is created.

    Returns:
        A 1-D ``torch.long`` tensor of length ``num_prompts``.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        # Honour `device` here too, so both branches place the result on the
        # requested device.
        return torch.randint(
            0, num_active_loras, (num_prompts,), dtype=torch.long, device=device
        )

    # Divide LoRAs equally and in order.
    part_size = max(num_prompts // num_active_loras, 1)
    last_lora_id = num_active_loras - 1
    # Prompt i belongs to group i // part_size; groups beyond the last LoRA
    # (the remainder) are clamped onto the last LoRA.
    prompt_lora_mapping = [
        min(prompt_idx // part_size, last_lora_id) for prompt_idx in range(num_prompts)
    ]
    return torch.tensor(prompt_lora_mapping, dtype=torch.long, device=device)
 | |
| 
 | |
| 
 | |
def make_token_lora_mapping(
    num_tokens: int,
    num_prompts: int,
    prompt_lora_mapping: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    device: str,
):
    """
    Expand the per-prompt LoRA mapping into a per-token mapping.

    Each prompt's LoRA index is repeated for as many tokens as that prompt's
    entry in `seq_len_tensor`; positions not covered by any prompt stay 0.
    Returns a torch.long tensor on `device`.
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    # Token -> LoRA index, filled one prompt span at a time.
    per_token = [0] * num_tokens
    start = 0
    for prompt_idx in range(num_prompts):
        lora_id = prompt_lora_mapping[prompt_idx].item()
        stop = start + seq_len_tensor[prompt_idx].item()
        per_token[start:stop] = [lora_id] * (stop - start)
        start = stop

    return torch.tensor(per_token, dtype=torch.long, device=device)
 | |
| 
 | |
| 
 | |
def ref_group_gemm(
    ref_out: torch.Tensor,
    input: torch.Tensor,
    lora_weights: list[torch.Tensor],
    seq_lens_cpu: torch.Tensor,
    prompt_lora_mapping_cpu: torch.Tensor,
    scaling: float,
    add_inputs: Optional[bool],
):
    """
    Reference grouped GEMM in plain torch, used to check the benchmarked
    operations for correctness.

    The rows of `input` are consumed batch by batch according to
    `seq_lens_cpu`; each batch is multiplied with the weight selected by
    `prompt_lora_mapping_cpu` and scaled by `scaling`. The concatenated result
    is accumulated into `ref_out` when `add_inputs` is truthy and copied into
    it otherwise.
    """
    num_batches = seq_lens_cpu.size(0)
    start = 0
    per_batch_outputs = []
    for batch_idx, batch_len in zip(range(num_batches), seq_lens_cpu):
        stop = batch_len + start
        batch_input = input[start:stop, :]
        start = stop
        batch_weight = lora_weights[prompt_lora_mapping_cpu[batch_idx]]
        scaled = torch.nn.functional.linear(batch_input, batch_weight)
        scaled *= scaling
        per_batch_outputs.append(scaled)

    stacked = torch.cat(per_batch_outputs, dim=0)

    if not add_inputs:
        ref_out.copy_(stacked)
        return
    ref_out.add_(stacked)
 | |
| 
 | |
| 
 | |
class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """

    LORA_SHRINK = auto()
    LORA_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """
        Convert a case-insensitive op name ("lora_shrink" / "lora_expand")
        into an OpType. Raises ValueError for any other string.
        """
        name = s.lower()
        if name == "lora_shrink":
            return OpType.LORA_SHRINK
        if name == "lora_expand":
            return OpType.LORA_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        """Return True when this op is the LoRA shrink op."""
        return self in [OpType.LORA_SHRINK]

    def is_expand_fn(self) -> bool:
        """Return True when this op is the LoRA expand op."""
        return self in [OpType.LORA_EXPAND]

    def num_slices(self) -> list[int]:
        """Return the slice counts that are benchmarked for this op."""
        return [1, 2, 3]

    def mkn(
        self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
    ) -> tuple[int, int, int]:
        """
        Return the (m, k, n) matmul dimensions for this op.

        m is always batch_size * seq_length. Shrink maps hidden_size (k) to
        lora_rank (n); expand maps lora_rank (k) to hidden_size (n).
        """
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            m = num_tokens
            k = hidden_size
            n = lora_rank
        else:
            assert self.is_expand_fn()
            m = num_tokens
            k = lora_rank
            n = hidden_size
        return m, k, n

    def matmul_dtypes(
        self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        else:
            assert self.is_expand_fn()
            return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
        self,
        batch_size: int,
        seq_length: int,
        hidden_size: int,
        lora_rank: int,
        num_loras: int,
        num_slices: Optional[int],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type

        B is always (num_loras, n, k). For shrink, A is (m, k) and C is
        (num_slices, m, n). For expand, A is (num_slices, m, k) and C is
        (m, n * num_slices).
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self in [OpType.LORA_SHRINK]:
            # LoRA shrink kernels support num_slices inherently in the kernel.
            return ((m, k), b_shape, (num_slices, m, n))
        if self in [OpType.LORA_EXPAND]:
            # LoRA expand kernels support num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """Return the kernel entry point that is benchmarked for this op."""
        if self == OpType.LORA_SHRINK:
            return lora_shrink
        if self == OpType.LORA_EXPAND:
            return lora_expand

        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
        lora_weights: list[torch.Tensor],
        **kwargs,
    ) -> None:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.

        The result is written into `output` in place, one slice at a time;
        nothing is returned. `kwargs` are forwarded to ref_group_gemm.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self in [OpType.LORA_SHRINK]:
            # Shrink: every slice reads the same input and owns output[slice].
            for slice_idx in range(num_slices):
                ref_group_gemm(
                    ref_out=output[slice_idx, :],
                    input=input,
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        elif self in [OpType.LORA_EXPAND]:
            # Expand: slice i reads input[i] and owns a hidden_size wide
            # column band of the output.
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset : slice_offset + hidden_size],
                    # The input is cast to the weight dtype for the reference
                    # matmul.
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        else:
            raise ValueError(f"Unrecognized optype {self}")
 | |
| 
 | |
| 
 | |
@dataclass
class BenchmarkContext:
    """
    Parameters describing a single LoRA benchmark configuration.
    """

    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: Optional[int] = None
    num_slices: Optional[int] = None  # num_slices for slice based ops

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        """Return a shallow copy of this context with `seq_length` set."""
        clone = copy.copy(self)
        clone.seq_length = seq_length
        return clone

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        """Return a shallow copy of this context with `num_slices` set."""
        clone = copy.copy(self)
        clone.num_slices = num_slices
        return clone

    def bench_label(self) -> str:
        """Return the benchmark label, derived from the dtype."""
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
        """
        Return a JSON string describing this configuration, including the
        m, k, n matmul dimensions that `op_type` derives from it.
        """
        dims = op_type.mkn(
            self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
        )
        m, k, n = dims
        return json.dumps(
            dict(
                bs=self.batch_size,
                sl=self.seq_length,
                m=m,
                k=k,
                n=n,
                num_loras=self.num_loras,
                sort_by_lora=self.sort_by_lora_id,
                num_slices=self.num_slices,
            )
        )
 | |
| 
 | |
| 
 | |
@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks
    """

    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: list[torch.Tensor]
    output: torch.Tensor
    # LoRA kernel metadata
    lora_kernel_meta: LoRAKernelMeta
    # Metadata tensors used in testing correctness
    seq_lens: torch.Tensor
    prompt_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        """Return "<input>x<weight>=><output>" using the short dtype tags."""
        return (
            f"{dtype_to_str(self.input.dtype)}x"
            f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
            f"{dtype_to_str(self.output.dtype)}"
        )

    @staticmethod
    def make(
        ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
    ) -> "BenchmarkTensors":
        """
        Build the tensors for benchmarking `op_type` under `ctx`.

        The matmul tensors are created on `device`. The metadata tensors
        (sequence lengths, prompt/token LoRA mappings, LoRAKernelMeta) are
        created on the CPU.
        """
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size,
            ctx.seq_length,
            ctx.hidden_size,
            ctx.lora_rank,
            ctx.num_loras,
            ctx.num_slices,
        )
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        input_tensor, lora_weights, output_tensor = make_rand_tensors(
            a_shape,
            b_shape,
            c_shape,
            a_type,
            b_type,
            c_type,
            num_slices=ctx.num_slices,
            # Forward the requested device instead of silently relying on
            # make_rand_tensors' own default.
            device=device,
        )

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Make metadata tensors involved in correctness testing.
        # Prepare seq lens tensor. The [low, high) range holds a single value,
        # so every sequence has exactly ctx.seq_length tokens.
        seq_len_tensor = torch.randint(
            ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
        )
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
        )

        # Make LoRAKernelMeta
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens,
            ctx.batch_size,
            prompt_lora_indices_tensor,
            seq_len_tensor,
            "cpu",
        )
        lora_kernel_meta = LoRAKernelMeta.make(
            max_loras=ctx.num_loras,
            max_num_tokens=token_lora_indices_tensor.size(0),
            device="cpu",
        )
        lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(
            input_tensor,
            lora_weights,
            output_tensor,
            lora_kernel_meta,
            seq_len_tensor,
            prompt_lora_indices_tensor,
        )

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.
        """
        # The token dimension is second to last for both the 2-D (shrink) and
        # the 3-D (expand) input layout.
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str):
        """
        Transfer tensors to device if the tensors aren't already on the device
        """

        def to_device(tensor: torch.Tensor):
            # Tensor.to() hands back the tensor itself when it already lives
            # on `device`, so no explicit device comparison is needed (and
            # comparing a torch.device against a str is not reliable).
            return tensor.to(device=device)

        self.input = to_device(self.input)
        self.output = to_device(self.output)
        self.seq_lens = to_device(self.seq_lens)
        self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
        for i in range(len(self.lora_weights_lst)):
            self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

        # LoRA meta
        for field_name in LoRAKernelMeta.__dataclass_fields__:
            field = getattr(self.lora_kernel_meta, field_name)
            assert isinstance(field, torch.Tensor)
            setattr(self.lora_kernel_meta, field_name, to_device(field))

    def metadata(self) -> tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def _lora_meta_kwargs(self) -> dict[str, Any]:
        """Return the LoRAKernelMeta tensors keyed by kernel argument name."""
        meta = self.lora_kernel_meta
        return {
            "token_lora_mapping": meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": meta.num_tokens_per_lora,
            "lora_token_start_loc": meta.lora_token_start_loc,
            "lora_ids": meta.active_lora_ids,
        }

    def as_lora_shrink_kwargs(self) -> dict[str, Any]:
        """
        Validate the tensor shapes and return the keyword arguments for the
        shrink benchmark function. All tensors are first moved to the input
        tensor's device.
        """
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            "inputs": self.input,
            "lora_a_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            **self._lora_meta_kwargs(),
            "scaling": 1.0,
        }

    def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        """
        Validate the tensor shapes and return the keyword arguments for the
        expand benchmark function. All tensors are first moved to the input
        tensor's device.
        """
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            "inputs": self.input,
            "lora_b_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            **self._lora_meta_kwargs(),
            "offset_start": 0,
            "add_inputs": add_inputs,
        }

    def bench_fn_kwargs(
        self, op_type: OpType, add_inputs: Optional[bool] = None
    ) -> dict[str, Any]:
        """
        Return the keyword arguments for `op_type`'s benchmark function.

        `add_inputs` must be None for the shrink op and must be set for the
        expand op. Raises ValueError for an unknown op type.
        """
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.LORA_SHRINK:
            return self.as_lora_shrink_kwargs()
        if op_type == OpType.LORA_EXPAND:
            return self.as_lora_expand_kwargs(add_inputs)
        # Report the offending op type, not this tensor container.
        raise ValueError(f"Unrecognized optype {op_type}")

    def test_correctness(
        self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
    ) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.

        Returns True when the op's output matches the reference within the
        per-dtype tolerances. Overwrites self.output.
        """
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")

        # Zero the output first and only then snapshot it, so that the op and
        # the reference accumulate onto the same (all-zero) base even if the
        # output tensor held stale results.
        self.output.zero_()
        ref_output = self.output.clone()

        op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs,
        )

        # (rtol, atol) keyed by the output dtype.
        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)
 | |
| 
 | |
| 
 | |
def bench_optype(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
    expand_fn_add_inputs: Optional[bool] = None,
    test_correctness: bool = False,
) -> TMeasurement:
    """
    Benchmark `op_type` under `ctx` and return the measurement.

    `arg_pool_size` independent tensor sets are created and handed to the
    benchmark as argument pools. `expand_fn_add_inputs` must be None for the
    shrink op and set for the expand op. When `test_correctness` is True,
    every tensor set is first checked against the reference implementation.
    """
    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: list[BenchmarkTensors] = []
    for _ in range(arg_pool_size):
        bench_tensors.append(BenchmarkTensors.make(ctx, op_type))
    for tensors in bench_tensors:
        tensors.sanity_check()

    # Test correctness of our implementation.
    if test_correctness:
        # Build the full list so every tensor set is exercised before the
        # assertion is evaluated.
        outcomes = [
            tensors.test_correctness(op_type, expand_fn_add_inputs)
            for tensors in bench_tensors
        ]
        assert all(outcomes)

    # BenchmarkTensors -> dict (kwargs)
    kwargs_list = []
    for tensors in bench_tensors:
        kwargs_list.append(
            tensors.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        )

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
    for call_kwargs in kwargs_list:
        op_type.bench_fn()(**call_kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    pooled_kwargs = {}
    for arg_name in kwargs_list[0]:
        pooled_kwargs[arg_name] = ArgPool([])
    for call_kwargs in kwargs_list:
        for arg_name, arg_value in call_kwargs.items():
            pooled_kwargs[arg_name].values.append(arg_value)

    if expand_fn_add_inputs is not None:
        describe_args = f"add_inputs={expand_fn_add_inputs}"
    else:
        describe_args = ""
    description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"

    cuda_graph_params = (
        CudaGraphBenchParams(cuda_graph_nops) if cuda_graph_nops else None
    )
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        op_type.bench_fn(),
        **pooled_kwargs,
    ) as bench:
        return bench.run()
 | |
| 
 | |
| 
 | |
def bench_torch_mm(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens have the same LoRA ID, the LoRA kernels are just
    a matmul. This torch.mm benchmark serves as a roofline for that case.

    input op_type is used in determining the m, k, n dimensions for the matmul.
    """
    dtype = ctx.dtype
    m, k, n = op_type.mkn(
        ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank
    )
    # For a fairer comparison.
    n = n * ctx.num_slices

    # Get matmul input and output tensors for A x B = C
    lhs_pool, rhs_pool, out_pool = [], [], []
    for _ in range(arg_pool_size):
        lhs_pool.append(torch.rand((m, k), dtype=dtype).to("cuda"))
        # Allocated as (n, k) and transposed to give the (k, n) operand.
        rhs_pool.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
        out_pool.append(torch.rand((m, n), dtype=dtype).to("cuda"))

    # Make torch.mm kwargs
    mm_kwargs = {
        "input": ArgPool(lhs_pool),
        "mat2": ArgPool(rhs_pool),
        "out": ArgPool(out_pool),
    }

    dtype_tag = dtype_to_str(dtype)
    description = (
        f"single-lora roofline using torch.mm ({dtype_tag}x{dtype_tag}=>{dtype_tag})"
    )

    cuda_graph_params = (
        CudaGraphBenchParams(cuda_graph_nops) if cuda_graph_nops else None
    )
    measurement = None
    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        torch.mm,
        **mm_kwargs,
    ) as bench:
        measurement = bench.run()
    return measurement
 | |
| 
 | |
| 
 | |
| # runner
 | |
def use_cuda_graph_recommendation() -> str:
    """Return the note recommending the `--cuda-graph-nops` option."""
    indent = " " * 12
    # Trailing spaces on some lines are part of the emitted text.
    message_lines = (
        "Triton kernels have a significant launch overhead with",
        "launched directly via python. This overhead is more noticeable",
        "for small the problem sizes. For these cases, it is recommended",
        "to use the script with `--cuda-graph-nops N` to benchmark N",
        "consecutive invocations of the benchmarking operations from ",
        "inside a CUDA Graph. Note that the returned measurement is for N ",
        "invocations of the operation.",
    )
    body = "".join(f"{indent}{line}\n" for line in message_lines)
    return "\n" + body + indent
 | |
| 
 | |
| 
 | |
def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None):
    """
    Print a comparison table of `timers`, followed by explanatory notes.

    When `args` is given and has `cuda_graph_nops` set, an extra note explains
    that the timings cover that many consecutive invocations.
    """
    TBenchmark.Compare(timers).print()

    if args and args.cuda_graph_nops:
        nops = args.cuda_graph_nops
        cuda_graph_note = (
            f"Note : The timings reported above is for {nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {nops} for single invocation "
            "timings."
        )
        print(cuda_graph_note)

    roofline_note = (
        "Note on Comparison with torch.mm : The torch.mm numbers are "
        "benchmark numbers of a simple matmul emulating the single lora "
        "case. It is provided as a roofline for comparing our LoRA Kernel "
        "implementations. It is expected that the LoRA kernels will be "
        "slower than torch.mm in cases where num_loras is big. But for "
        "small num_loras the goal should be to match the torch.mm numbers."
    )
    print(roofline_note)
 | |
| 
 | |
| 
 | |
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    """
    Run every benchmark described by `bench_ctxs` and report the results.

    For each context, sequence length, op type and slice count, a torch.mm
    roofline and the op itself are benchmarked (the expand op once per
    requested `add_inputs` value). Results are printed per sequence length
    and once more in full at the end. When `args.output_directory` is set,
    all measurements are also pickled into that directory, which is created
    (including missing parents) if needed.
    """
    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types
            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices
                    )
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(
                            _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops
                        )
                    )

                    # Benchmark bench_op. The shrink op takes no add_inputs
                    # argument, so it runs exactly once with None.
                    expand_fn_add_inputs = (
                        [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    )
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(
                                _ctx,
                                args.arg_pool_size,
                                bench_op,
                                args.cuda_graph_nops,
                                add_input_arg,
                                args.test_correctness,
                            )
                        )

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        # Create missing parent directories too, and tolerate the directory
        # already existing (no separate exists() check, hence no race).
        od.mkdir(parents=True, exist_ok=True)

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)
 | |
| 
 | |
| 
 | |
def as_benchmark_contexts(
    hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
    """Build one BenchmarkContext per point of the benchmark parameter sweep.

    The sweep is the cartesian product of ``args.batch_sizes``,
    ``hidden_sizes``, ``lora_ranks``, ``args.num_loras`` and
    ``args.sort_by_lora_id``, in that nesting order.

    Args:
        hidden_sizes: Hidden sizes to sweep; materialized with ``list()``.
        lora_ranks: LoRA ranks to sweep.
        args: Parsed command-line arguments. Also supplies ``dtype`` and the
            optional ``num_active_loras`` override.

    Returns:
        Contexts with ``seq_length`` and ``num_slices`` left as ``None``.
    """
    sweep = product(  # noqa
        args.batch_sizes,
        list(hidden_sizes),
        lora_ranks,
        args.num_loras,
        args.sort_by_lora_id,
    )

    return [
        BenchmarkContext(
            batch_size=bs,
            hidden_size=hidden,
            lora_rank=rank,
            num_loras=n_loras,
            # A falsy override (None or 0) means every LoRA is active.
            num_active_loras=args.num_active_loras or n_loras,
            # To be filled based on the OpType to benchmark
            seq_length=None,
            sort_by_lora_id=sorted_by_id,
            dtype=args.dtype,
            # To be filled based on the OpType to benchmark
            num_slices=None,
        )
        for bs, hidden, rank, n_loras, sorted_by_id in sweep
    ]
 | |
| 
 | |
| 
 | |
def run_list_bench(args: argparse.Namespace):
    """Entry point of the ``list_bench`` sub-command.

    Benchmarks the hidden sizes and LoRA ranks given explicitly through
    ``args.hidden_sizes`` and ``args.lora_ranks``.
    """
    print(args)

    banner = (
        "List bench :\n"
        f"  Hidden Sizes {args.hidden_sizes}"
        f"  LoRA Ranks {args.lora_ranks}"
    )
    print(banner)

    # Expand the sweep into contexts and hand them straight to the runner.
    run(
        args,
        as_benchmark_contexts(
            hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
        ),
    )
 | |
| 
 | |
| 
 | |
def run_range_bench(args: argparse.Namespace):
    """Entry point of the ``range_bench`` sub-command.

    Hidden sizes and LoRA ranks are generated from the start / end / increment
    triples in ``args``; the end value is included when the increment lands on
    it.
    """
    print(args)

    def inclusive_range(start: int, end: int, step: int) -> list[int]:
        # range() excludes its stop value, so extend it by one.
        return list(range(start, end + 1, step))

    hidden_sizes = inclusive_range(
        args.hidden_sizes_start,
        args.hidden_sizes_end,
        args.hidden_sizes_increment,
    )
    lora_ranks = inclusive_range(
        args.lora_ranks_start,
        args.lora_ranks_end,
        args.lora_ranks_increment,
    )

    print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")

    # Expand the sweep into contexts and hand them straight to the runner.
    run(
        args,
        as_benchmark_contexts(
            hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
        ),
    )
 | |
| 
 | |
| 
 | |
def run_model_bench(args: argparse.Namespace):
    """Entry point of the ``model_bench`` sub-command.

    Collects the hidden sizes of every (model, tensor-parallel size) pair
    requested through ``args.models`` and ``args.tp_sizes`` from
    ``WEIGHT_SHAPES`` and benchmarks them against ``args.lora_ranks``.
    """
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        """Return the set of second-dimension sizes of ``model`` under ``tp_size``."""
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Divide a private copy: WEIGHT_SHAPES is module-level state that is
            # read again for every tp_size, so dividing it in place would
            # compound the split across iterations.
            shape = list(KN)
            shape[tp_split_dim] = shape[tp_split_dim] // tp_size
            hidden_sizes.add(shape[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size))

    print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: builds a parser with the list_bench,
    # range_bench and model_bench sub-commands and dispatches to the matching
    # run_*_bench function.

    def to_torch_dtype(dt):
        """Map a ``--dtype`` command-line string to the torch dtype it names.

        Raises:
            ValueError: if ``dt`` is not 'torch.float16' or 'torch.bfloat16'.
        """
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        """Parse a boolean option: 'true' or '1' (case-insensitive) is True."""
        return s.lower() in ["true", "1"]

    def add_common_command_args(p: argparse.ArgumentParser):
        """Register the options shared by every sub-command on parser ``p``."""
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']",
        )

        # Adjacent string literals are joined verbatim, so each fragment of a
        # multi-part help text ends with an explicit space.
        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.",
        )

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=(
                "when set profiling is done using cudagraph, "
                "with the given number of operations in a graph. "
                "Note that the measurement returned is the time "
                "taken for N consecutive executions of the benchmarking "
                "functions, where N is the value of this argument."
            ),
        )
        p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
        p.add_argument(
            "--num-active-loras",
            type=int,
            default=None,
            help="Active LoRAs. When None, all LoRAs are active",
        )
        p.add_argument(
            "--sort-by-lora-id",
            nargs="+",
            type=get_bool,
            default=DEFAULT_SORT_BY_LORA_IDS,
        )
        p.add_argument(
            "--op-types", nargs="+", type=OpType.from_str, default=list(OpType)
        )
        p.add_argument(
            "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS
        )
        p.add_argument(
            "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
        )
        p.add_argument(
            "--expand-fn-add-inputs",
            nargs="+",
            type=get_bool,
            default=DEFAULT_EXPAND_FN_ADD_INPUTS,
        )
        p.add_argument(
            "-o",
            "--output-directory",
            type=str,
            help=(
                "Output directory to store the list of benchmarking "
                "TMeasurement objects as a pickle file"
            ),
        )

        p.add_argument(
            "--test-correctness",
            action="store_true",
            help=(
                "When enabled, the benchmarking functions are tested "
                "for correctness before the actual benchmarking"
            ),
        )

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16  --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16   --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # list_bench: explicit lists of hidden sizes and LoRA ranks.
    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument(
        "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES
    )
    list_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    # range_bench: hidden sizes and LoRA ranks from start/end/increment triples.
    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    # model_bench: hidden sizes looked up from WEIGHT_SHAPES per model/TP size.
    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    # Each sub-parser stored its handler in ``func`` via set_defaults.
    args.func(args)
 |