Added dtype arg to benchmarks (#1228)

This commit is contained in:
kg6-sleipnir
2023-10-01 00:04:03 -04:00
committed by GitHub
parent 0967102c6d
commit b5a10eb0ef
2 changed files with 22 additions and 1 deletions

View File

@@ -23,6 +23,7 @@ def main(args: argparse.Namespace):
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
)
sampling_params = SamplingParams(
@@ -87,5 +88,14 @@ if __name__ == '__main__':
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
args = parser.parse_args()
main(args)

View File

@@ -64,6 +64,7 @@ def run_vllm(
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
) -> float:
llm = LLM(
model=model,
@@ -72,6 +73,7 @@ def run_vllm(
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
)
# Add the requests to the engine.
@@ -171,7 +173,7 @@ def main(args: argparse.Namespace):
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code)
args.trust_remote_code, args.dtype)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -219,6 +221,15 @@ if __name__ == "__main__":
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
args = parser.parse_args()
if args.backend == "vllm":