Enhance SamplingParams (#96)

This commit is contained in:
Woosuk Kwon
2023-05-11 15:45:30 -07:00
committed by GitHub
parent 55f8b0a5de
commit 42f1042e1c
7 changed files with 36 additions and 54 deletions

View File

@ -6,7 +6,7 @@ from tqdm import tqdm
import numpy as np
import torch
from cacheflow.master.server import (
from cacheflow.core.server import (
add_server_arguments, process_server_arguments,
init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams
@ -15,15 +15,14 @@ from cacheflow.sampling_params import SamplingParams
def main(args: argparse.Namespace):
server, frontend = init_local_server_and_frontend_with_arguments(args)
sampling_params_dict = {
'n': args.n,
'temperature': 0.0 if args.use_beam_search else 1.0,
'top_p': 1.0,
'use_beam_search': args.use_beam_search,
'stop_token_ids': set(),
'max_num_steps': args.output_len,
}
sampling_params = SamplingParams.from_dict(sampling_params_dict)
sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
stop_token_ids=set(),
max_tokens=args.output_len,
)
print(sampling_params)
input_token_ids = [0] * args.input_len
@ -31,7 +30,8 @@ def main(args: argparse.Namespace):
if profile:
torch.cuda.cudart().cudaProfilerStart()
for _ in range(args.batch_size):
frontend._add_query(input_token_ids, sampling_params)
dummy_prompt = ""
frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
start_time = time.time()
while True:

View File

@ -316,7 +316,7 @@ class Scheduler:
continue
# Check if the sequence has reached the maximum number of steps.
max_num_steps = self.sampling_params[group_id].max_num_steps
max_num_steps = self.sampling_params[group_id].max_tokens
if self.num_steps[group_id] == max_num_steps:
self._free_seq(seq)
continue

View File

@ -89,8 +89,8 @@ class FastAPIServer:
async def generate(self, request_dict: Dict):
# Preprocess the request.
prompt = request_dict["prompt"]
sampling_params = SamplingParams.from_dict(request_dict)
prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict)
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
token_ids = self.tokenizer.encode(prompt)
seqs: List[Sequence] = []

View File

@ -367,7 +367,7 @@ def _sample(
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.num_logprobs)
logprob, sampling_params.logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
@ -392,7 +392,7 @@ def _sample(
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.num_logprobs)
logprob[i], sampling_params.logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(

View File

@ -5,16 +5,16 @@ class SamplingParams:
def __init__(
self,
n: int,
presence_penalty: float,
frequency_penalty: float,
temperature: float,
top_p: float,
top_k: int,
use_beam_search: bool,
stop_token_ids: Set[int],
max_num_steps: int,
num_logprobs: int,
n: int = 1,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
stop_token_ids: Set[int] = set(),
max_tokens: int = 16,
logprobs: int = 0,
) -> None:
if n < 1:
raise ValueError(f"n must be at least 1, got {n}.")
@ -32,12 +32,12 @@ class SamplingParams:
if top_k < -1 or top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
f"got {top_k}.")
if max_num_steps < 1:
if max_tokens < 1:
raise ValueError(
f"max_num_steps must be at least 1, got {max_num_steps}.")
if num_logprobs < 0:
f"max_tokens must be at least 1, got {max_tokens}.")
if logprobs < 0:
raise ValueError(
f"num_logprobs must be non-negative, got {num_logprobs}.")
f"logprobs must be non-negative, got {logprobs}.")
if use_beam_search:
if n == 1:
@ -72,8 +72,8 @@ class SamplingParams:
self.top_k = top_k
self.use_beam_search = use_beam_search
self.stop_token_ids = stop_token_ids
self.max_num_steps = max_num_steps
self.num_logprobs = num_logprobs
self.max_tokens = max_tokens
self.logprobs = logprobs
def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
@ -84,23 +84,5 @@ class SamplingParams:
f"top_k={self.top_k},"
f"use_beam_search={self.use_beam_search}, "
f"stop_token_ids={self.stop_token_ids}, "
f"max_num_steps={self.max_num_steps}, "
f"num_logprobs={self.num_logprobs}")
@classmethod
def from_dict(cls, d: Dict) -> "SamplingParams":
sampling_params = cls(
n=d.pop("n", 1),
presence_penalty=d.pop("presence_penalty", 0.0),
frequency_penalty=d.pop("frequency_penalty", 0.0),
temperature=d.pop("temperature", 1.0),
top_p=d.pop("top_p", 1.0),
top_k=d.pop("top_k", -1),
use_beam_search=d.pop("use_beam_search", False),
stop_token_ids=set(d.pop("stop_token_ids", set())),
max_num_steps=d.pop("max_num_steps", 16),
num_logprobs=d.pop("num_logprobs", 0),
)
if d:
raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
return sampling_params
f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs}")

View File

@ -10,7 +10,7 @@ def http_bot(prompt):
headers = {"User-Agent": "Cacheflow Client"}
pload = {
"prompt": prompt,
"max_num_steps": 128,
"max_tokens": 128,
}
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)

View File

@ -18,7 +18,7 @@ def main(args: argparse.Namespace):
while True:
if test_inputs:
text, sampling_params_dict = test_inputs.pop(0)
sampling_params = SamplingParams.from_dict(sampling_params_dict)
sampling_params = SamplingParams(**sampling_params_dict)
sampling_params = frontend.add_eos_token(sampling_params)
frontend.query(text, sampling_params)
server.add_sequence_groups(frontend.get_inputs())