Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 09:44:02 +08:00)

Compare commits: v4.56.0...tool-handl (21 commits)
Commits:

825c8e5695
8735fae88b
be99c3d651
a729324a73
fa30ee577d
f8d7972e9d
61d7189903
2a675dc552
40fe9fa4bd
89e4f912e8
dfd6a7360e
d4ba83051f
d152e9e179
154f4c144d
f043c82818
8219b66d2f
1b5172843d
070aab3a09
482fed4837
4173cfea98
f6024b723c
setup.py (2 changed lines)

@@ -148,7 +148,7 @@ _deps = [
    "protobuf",
    "psutil",
    "pyyaml>=5.1",
    "pydantic",
    "pydantic>=2",
    "pytest>=7.2.0",
    "pytest-asyncio",
    "pytest-rerunfailures",
src/transformers/commands/chat.py

@@ -13,7 +13,7 @@
# limitations under the License.


import copy
import asyncio
import json
import os
import platform

@@ -24,22 +24,20 @@ import warnings
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Optional
from typing import AsyncIterator, Optional

import yaml
from huggingface_hub.utils import disable_progress_bars
from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput

from transformers import (
    AutoTokenizer,
    GenerationConfig,
    PreTrainedTokenizer,
    TextIteratorStreamer,
    logging,
)
from transformers.commands import BaseTransformersCLICommand
from transformers.commands.serving import ServeArguments, ServeCommand
from transformers.utils import is_rich_available, is_torch_available

from . import BaseTransformersCLICommand


if platform.system() != "Windows":
    import pwd

@@ -52,8 +50,12 @@ if is_rich_available():
if is_torch_available():
    import torch

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        GenerationConfig,
    )


ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
ALLOWED_VALUE_CHARS = set(

@@ -133,21 +135,21 @@ class RichInterface:
        else:
            self.user_name = user_name

    def stream_output(self, output_stream: TextIteratorStreamer) -> str:
        """Stream output from a role, and return the generated text after it's done steaming."""
        # This method is originally from the FastChat CLI:
        # https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
        # Create a Live context for updating the console output
        text = ""
    async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, int]:
        self._console.print(f"[bold blue]<{self.model_name}>:")
        with Live(console=self._console, refresh_per_second=4) as live:
            # Read lines from the stream
            for i, outputs in enumerate(output_stream):
                if not outputs or i == 0:
            text = ""
            async for token in await stream:
                outputs = token.choices[0].delta.content
                request_id = token.id

                if not outputs:
                    continue

                # Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
                # It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
                outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)

                text += outputs
                # Render the accumulated text as Markdown
                # NOTE: this is a workaround for the rendering "unstandard markdown"

@@ -160,6 +162,7 @@ class RichInterface:
                # introduce trailing spaces (only) in code block, but it works well
                # especially for console output, because in general the console does not
                # care about trailing spaces.

                lines = []
                for line in text.splitlines():
                    lines.append(line)

@@ -169,11 +172,15 @@ class RichInterface:
                        lines.append("\n")
                    else:
                        lines.append(" \n")

                markdown = Markdown("".join(lines).strip(), code_theme="github-dark")

                # Update the Live console output
                live.update(markdown)
                live.update(markdown, refresh=True)

        self._console.print()
        return text

        return text, request_id

    def input(self) -> str:
        """Gets user input from the console."""

@@ -300,6 +307,10 @@ class ChatArguments:
    bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})

    # Serving settings
    host: str = field(default="localhost", metadata={"help": "Interface the server will listen to.."})
    port: int = field(default=8000, metadata={"help": "Port the server will listen to."})


def chat_command_factory(args: Namespace):
    """

@@ -322,7 +333,10 @@ class ChatCommand(BaseTransformersCLICommand):

        group = chat_parser.add_argument_group("Positional arguments")
        group.add_argument(
            "model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
            "model_name_or_path_or_address",
            type=str,
            default=None,
            help="Name of the pre-trained model or address to connect to.",
        )
        group.add_argument(
            "generate_flags",

@@ -340,6 +354,36 @@ class ChatCommand(BaseTransformersCLICommand):

    def __init__(self, args):
        args = self._handle_deprecated_args(args)

        if args.model_name_or_path_or_address is not None:
            name = args.model_name_or_path_or_address
            if name.startswith("http") or name.startswith("https") or name.startswith("localhost"):
                self.spawn_backend = False

                if args.host != "localhost" or args.port != 8000:
                    raise ValueError(
                        "Looks like you’ve set both a server address and a custom host/port. "
                        "Please pick just one way to specify the server."
                    )

                args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
            else:
                self.spawn_backend = True
                args.model_name_or_path = args.model_name_or_path_or_address

        if not is_rich_available() and (not is_torch_available() and self.spawn_backend):
            raise ImportError(
                "You need to install rich to use the chat interface. Additionally, you have not specified a remote "
                "endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`)"
            )
        elif not is_rich_available():
            raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
        elif not is_torch_available() and self.spawn_backend:
            raise ImportError(
                "You have not specified a remote endpoint and are therefore spawning a backend. Torch is required "
                "for this: (`pip install rich torch`)"
            )

        self.args = args

    def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:

@@ -349,22 +393,7 @@ class ChatCommand(BaseTransformersCLICommand):
        """
        has_warnings = False

        # 1. Model as a positional argument
        args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
        if args.model_name_or_path_positional is None:
            raise ValueError(
                "One of the following must be provided:"
                "\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
                "\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
            )
        elif args.model_name_or_path is not None:
            has_warnings = True
            warnings.warn(
                "The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
                "argument instead, e.g. `transformers chat <model_repo>`.",
                FutureWarning,
            )
        # 2. Named generate option args
        # Named generate option args
        for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
            value = getattr(args, deprecated_arg)
            if value != default_value:

@@ -404,7 +433,7 @@ class ChatCommand(BaseTransformersCLICommand):

        if filename is None:
            time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
            filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
            filename = f"{args.model_name_or_path_or_address}/chat_{time_str}.json"
        filename = os.path.join(folder, filename)

        os.makedirs(os.path.dirname(filename), exist_ok=True)

@@ -477,40 +506,20 @@ class ChatCommand(BaseTransformersCLICommand):
            )
        return processed_generate_flags

    def get_generation_parameterization(
        self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
    ) -> tuple[GenerationConfig, dict]:
    def get_generation_parameterization(self, args: ChatArguments) -> tuple[GenerationConfig, dict]:
        """
        Returns a GenerationConfig object holding the generation parameters for the CLI command.
        """
        # No generation config arg provided -> use default generation config, apply CLI defaults
        if args.generation_config is None:
            # We start off from the checkpoint's generation config
            generation_config = copy.deepcopy(model.generation_config)
            # Apply deprecated CLI args on top of the default generation config
            pad_token_id, eos_token_ids = self.parse_eos_tokens(
                tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
            )
            deprecated_kwargs = {
                "max_new_tokens": args.max_new_tokens,
                "do_sample": args.do_sample,
                "num_beams": args.num_beams,
                "temperature": args.temperature,
                "top_k": args.top_k,
                "top_p": args.top_p,
                "repetition_penalty": args.repetition_penalty,
                "pad_token_id": pad_token_id,
                "eos_token_id": eos_token_ids,
            }
            generation_config.update(**deprecated_kwargs)
        # generation config arg provided -> use it as the base parameterization
        else:
        # No generation config arg provided -> use base generation config, apply CLI defaults
        if args.generation_config is not None:
            if ".json" in args.generation_config: # is a local file
                dirname = os.path.dirname(args.generation_config)
                filename = os.path.basename(args.generation_config)
                generation_config = GenerationConfig.from_pretrained(dirname, filename)
            else:
                generation_config = GenerationConfig.from_pretrained(args.generation_config)
        else:
            generation_config = GenerationConfig()

        # Finally: parse and apply `generate_flags`
        parsed_generate_flags = self.parse_generate_flags(args.generate_flags)

@@ -664,7 +673,7 @@ class ChatCommand(BaseTransformersCLICommand):

            elif user_input == "!status":
                interface.print_status(
                    model_name=args.model_name_or_path_positional,
                    model_name=args.model_name_or_path,
                    generation_config=generation_config,
                    model_kwargs=model_kwargs,
                )

@@ -679,10 +688,33 @@ class ChatCommand(BaseTransformersCLICommand):
    # -----------------------------------------------------------------------------------------------------------------
    # Main logic
    def run(self):
        if not is_rich_available():
            raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
        if not is_torch_available():
            raise ImportError("You need to install torch to use the chat interface. (`pip install torch`)")
        asyncio.run(self._inner_run())

    async def _inner_run(self):
        if self.spawn_backend:
            serve_args = ServeArguments(
                model_revision=self.args.model_revision,
                device=self.args.device,
                torch_dtype=self.args.torch_dtype,
                trust_remote_code=self.args.trust_remote_code,
                attn_implementation=self.args.attn_implementation,
                load_in_8bit=self.args.load_in_8bit,
                load_in_4bit=self.args.load_in_4bit,
                bnb_4bit_quant_type=self.args.bnb_4bit_quant_type,
                use_bnb_nested_quant=self.args.use_bnb_nested_quant,
                host=self.args.host,
                port=self.args.port,
                log_level="error",
            )
            serve_args.model_name_or_path = self.args.model_name_or_path
            serve_command = ServeCommand(serve_args)

            thread = Thread(target=serve_command.run)
            thread.daemon = True
            thread.start()

        host = "http://localhost" if self.args.host == "localhost" else self.args.host
        client = AsyncInferenceClient(f"{host}:{self.args.port}")

        args = self.args
        if args.examples_path is None:

@@ -696,19 +728,14 @@ class ChatCommand(BaseTransformersCLICommand):
        else:
            user = args.user

        model, tokenizer = self.load_model_and_tokenizer(args)
        generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
        generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
        generation_config, model_kwargs = self.get_generation_parameterization(args)

        # if not verbose -> disable warnings, progress bars, etc in the chat interface
        if not args.verbose:
            logging.set_verbosity_error()
            disable_progress_bars()

        interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
        interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
        interface.clear()
        chat = self.clear_chat_history(args.system_prompt)

        request_id = None

        # Starts the session with a minimal help message at the top, so that a user doesn't get stuck
        interface.print_help(minimal=True)
        while True:

@@ -736,23 +763,25 @@ class ChatCommand(BaseTransformersCLICommand):
                else:
                    chat.append({"role": "user", "content": user_input})

                inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
                    model.device
                stream = client.chat_completion(
                    chat,
                    stream=True,
                    extra_body={"request_id": request_id, "generation_config": {**generation_config.to_dict()}},
                )
                attention_mask = torch.ones_like(inputs)
                generation_kwargs = {
                    "inputs": inputs,
                    "attention_mask": attention_mask,
                    "streamer": generation_streamer,
                    "generation_config": generation_config,
                    **model_kwargs,
                }

                thread = Thread(target=model.generate, kwargs=generation_kwargs)
                thread.start()
                model_output = interface.stream_output(generation_streamer)
                thread.join()
                model_output, request_id = await interface.stream_output(stream)

                chat.append({"role": "assistant", "content": model_output})

            except KeyboardInterrupt:
                break
            finally:
                await client.close()


if __name__ == "__main__":
    args = ChatArguments()
    args.model_name_or_path_or_address = "meta-llama/Llama-3.2-3b-Instruct"
    args.model_name_or_path_or_address = "http://localhost:8000"
    chat = ChatCommand(args)
    chat.run()
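For orientation, here is a minimal, hypothetical sketch (not part of the diff) of the client pattern the reworked chat command relies on: streaming an OpenAI-style chat completion from a running `transformers serve` instance through `AsyncInferenceClient`, with the serve-specific fields passed via `extra_body`, mirroring the calls made in `ChatCommand` above. The address, prompt, and generation values are placeholders.

import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    # Placeholder address; ServeArguments defaults to localhost:8000.
    client = AsyncInferenceClient("http://localhost:8000")
    messages = [{"role": "user", "content": "Hello!"}]
    # `extra_body` carries the serve-specific fields used by ChatCommand above:
    # a `request_id` for continued sessions and a serialized `generation_config`.
    stream = client.chat_completion(
        messages,
        stream=True,
        extra_body={"request_id": None, "generation_config": {"max_new_tokens": 64}},
    )
    text = ""
    async for token in await stream:
        delta = token.choices[0].delta.content
        if delta:
            text += delta
    await client.close()
    print(text)


asyncio.run(main())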
src/transformers/commands/serving.py

@@ -12,32 +12,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import time
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional

from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from huggingface_hub import ChatCompletionStreamOutputDeltaToolCall, ChatCompletionStreamOutputFunction

from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available

from .. import PreTrainedTokenizerFast, TextIteratorStreamer
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from ..utils import is_torch_available, logging
from . import BaseTransformersCLICommand


try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
if is_torch_available():
    import torch

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        GenerationConfig,
        PreTrainedModel,
    )


if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available():
    import uvicorn
    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    BaseModel = object

    class Message(BaseModel):
        role: str
        content: str

    def Body(*x, **y):
        pass

    class ChatCompletionInput(BaseModel):
        messages: list[Message]

    _serve_dependencies_installed = False
        stream: Optional[bool] = False
        model: Optional[str] = None
        request_id: Optional[str] = None
        extra_body: Optional[dict] = None
        frequency_penalty: Optional[float] = None
        logit_bias: Optional[list[float]] = None
        max_tokens: Optional[int] = None
        stop: Optional[list[str]] = None
        temperature: Optional[float] = None
        top_p: Optional[float] = None
        seed: Optional[int] = None

        # Additional options supported by the HFH InferenceClient
        # that aren't yet supported here.

        # logprobs: Optional[bool] = None
        tools: Any = None
        # n: Optional[int] = None
        # presence_penalty: Optional[float] = None
        # response_format: Optional[ChatCompletionInputGrammarType] = None
        # stream_options: Optional[ChatCompletionInputStreamOptions] = None
        # tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None
        # tool_prompt: Optional[str] = None
        # top_logprobs: Optional[int] = None


logger = logging.get_logger("transformers/serving")
logger = logging.get_logger(__name__)

@@ -46,47 +89,88 @@ def serve_command_factory(args: Namespace):

    Returns: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    return ServeCommand(args)


def create_generation_config_from_req(req: "ChatCompletionInput"):
    if req.extra_body is not None and "generation_config" in req.extra_body:
        for key in req.extra_body["generation_config"].keys():
            if key in ChatCompletionInput.base_field_names.keys():
                return {"error": "Duplicated key in the root request and in the passed generation config."}

    if req.extra_body is not None and "generation_config" in req.extra_body:
        generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
    else:
        generation_config = GenerationConfig()

    if req.frequency_penalty is not None:
        generation_config.repetition_penalty = req.frequency_penalty
    if req.logit_bias is not None:
        generation_config.sequence_bias = req.logit_bias
    if req.stop is not None:
        generation_config.stop_strings = req.stop
    if req.temperature is not None:
        generation_config.temperature = req.temperature
    if req.top_p is not None:
        generation_config.top_p = req.top_p
    if req.seed is not None:
        torch.manual_seed(req.seed)

    return generation_config


@dataclass
class ServeArguments:
    r"""
    Arguments for the serve CLI.

    See the metadata arg for each argument's description -- the metadata will be printed with
    `transformers serve --help`
    """

    # Model loading
    model_revision: str = field(
        default="main",
        metadata={"help": "Specific model version to use (can be a branch name, tag name or commit id)."},
    )
    return ServeCommand(nlp, args.host, args.port, args.workers)
    device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
    torch_dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, "
            "the dtype will be automatically derived from the model's weights.",
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Whether to trust remote code when loading a model."}
    )
    attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
            "which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
        },
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Whether to use 8 bit precision for the base model - works only with LoRA."},
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Whether to use 4 bit precision for the base model - works only with LoRA."},
    )
    bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})

    # Serving settings
    host: str = field(default="localhost", metadata={"help": "Interface the server will listen to.."})
    port: int = field(default=8000, metadata={"help": "Port the server will listen to."})

class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    infos: dict


class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    tokens: list[str]
    tokens_ids: Optional[list[int]]


class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    text: str


class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    output: Any
    # Other settings
    log_level: str = field(
        default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
    )


class ServeCommand(BaseTransformersCLICommand):

@@ -98,131 +182,248 @@ class ServeCommand(BaseTransformersCLICommand):
        Args:
            parser: Root parser to register command-specific arguments
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        dataclass_types = (ServeArguments,)
        serve_parser = parser.add_parser("serve", dataclass_types=dataclass_types)

        group = serve_parser.add_argument_group("Positional arguments")
        group.add_argument("model_name_or_path", type=str, default=None, help="Name of the pre-trained model.")
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )
        else:
            logger.info(f"Serving model over {host}:{port}")
            self._app = FastAPI(
                routes=[
                    APIRoute(
                        "/",
                        self.model_info,
                        response_model=ServeModelInfoResult,
                        response_class=JSONResponse,
                        methods=["GET"],
                    ),
                    APIRoute(
                        "/tokenize",
                        self.tokenize,
                        response_model=ServeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/detokenize",
                        self.detokenize,
                        response_model=ServeDeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/forward",
                        self.forward,
                        response_model=ServeForwardResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                ],
                timeout=600,
    def __init__(self, args: ServeArguments):
        if not is_pydantic_available() or not is_fastapi_available() or not is_uvicorn_available():
            raise ImportError(
                "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
            )

        self.args = args
        self.model, self.tokenizer = self.load_model_and_tokenizer(args)
        self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"

        transformers_logger = logging.get_logger("transformers")
        transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])

        cb_logger = logging.get_logger("transformers.generation.continuous_batching")
        cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])

    def build_chunk(self, content: str, request_id: str, finish_reason: Optional[str] = None, tool_calls=None) -> str:
        print(content)
        print(tool_calls)
        payload = {
            "object": "chat.completion.chunk",
            "id": request_id,
            "created": int(time.time()),
            "model": self.args.model_name_or_path,
            "system_fingerprint": "",
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": content, "tool_calls": tool_calls or []},
                    "logprobs": None,
                    "finish_reason": finish_reason,
                }
            ],
        }
        print(payload)
        print(f"data: {json.dumps(payload)}\n\n")
        return f"data: {json.dumps(payload)}\n\n"

    def run(self):
        run(self._app, host=self.host, port=self.port, workers=self.workers)
        app = FastAPI()

    def model_info(self):
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
        if self.use_continuous_batching:
            self.continuous_batching(app)
        else:
            self.generate(app)

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to
        tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer
        mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
        uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(tokens=tokens_txt)
    def continuous_batching(self, app):
        generation_config = GenerationConfig(
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            num_blocks=1,
            block_size=1024,
            do_sample=False,
            max_batch_tokens=10,
            scheduler="fifo",
        )

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
        manager: ContinuousBatchingManager = self.model.init_continuous_batching(
            generation_config=generation_config, streaming=True
        )
        manager.start()

    def detokenize(
        self,
        tokens_ids: list[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
        **skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
        Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            return ServeDeTokenizeResult(model="", text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
        @app.post("/v1/chat/completions")
        def _serve(req: ChatCompletionInput):
            if not req.stream:
                return {"error": "Only streaming mode is supported."}

    async def forward(self, inputs=Body(None, embed=True)):
        """
        **inputs**: **attention_mask**: **tokens_type_ids**:
        """
            chat = req.messages

        # Check we don't have empty string
        if len(inputs) == 0:
            return ServeForwardResult(output=[], attention=[])
            inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
                self.model.device
            )

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
            generation_config = create_generation_config_from_req(req)

            def stream_response(_inputs):
                try:
                    max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
                    request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
                    queue_is_flushed = False

                    for result in manager:
                        if req.request_id is not None and not queue_is_flushed:
                            if result.status == RequestStatus.FINISHED:
                                continue
                            else:
                                queue_is_flushed = True

                        finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
                        yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)

                        if result.status == RequestStatus.FINISHED:
                            break

                    yield "data: [DONE]\n\n"
                except Exception as e:
                    logger.error(str(e))
                    yield f'data: {{"error": "{str(e)}"}}'

            return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")

    def generate(self, app):
        @app.post("/v1/chat/completions")
        def _serve(req: ChatCompletionInput):
            if not req.stream:
                return {"error": "Only streaming mode is supported."}

            text = self.tokenizer.apply_chat_template(
                req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
            )

            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
            request_id = req.request_id if req.request_id is not None else "req_0"

            generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)

            generation_config = create_generation_config_from_req(req)
            max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
            generation_config.max_new_tokens = max_new_tokens

            generation_kwargs = {
                "inputs": inputs,
                "attention_mask": torch.ones_like(inputs),
                "streamer": generation_streamer,
                "generation_config": generation_config,
            }

            def stream_response(streamer, _request_id):
                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)

                try:
                    thread.start()

                    inside_tool_call = False
                    tool_call_buffer = ""

                    for result in streamer:
                        print("Chunk:", result)

                        if result.strip() == "<tool_call>":
                            inside_tool_call = True
                            tool_call_buffer = ""
                            continue

                        if inside_tool_call:
                            if result.strip() == "</tool_call>":
                                try:
                                    tool_data = json.loads(tool_call_buffer)
                                    tool = ChatCompletionStreamOutputDeltaToolCall(
                                        function=ChatCompletionStreamOutputFunction(
                                            arguments=json.dumps(tool_data.get("arguments")),
                                            name=tool_data.get("name"),
                                        ),
                                        id=tool_data.get("id"),
                                        index=0,
                                        type=tool_data.get("type"),
                                    )
                                    yield self.build_chunk("", _request_id, tool_calls=[tool])
                                except Exception as e:
                                    logger.error(f"Failed to parse tool call: {e}")
                                    raise e
                                inside_tool_call = False
                                tool_call_buffer = ""
                            else:
                                tool_call_buffer += result

                            continue

                        yield self.build_chunk(result, _request_id)
                    yield "data: [DONE]\n\n"

                    thread.join()
                except Exception as e:
                    logger.error(str(e))
                    yield f'data: {{"error": "{str(e)}"}}'

                finally:
                    thread.join()

            return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")

    @staticmethod
    def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
        if model_args.load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                # For consistency with model weights, we use the same value as `torch_dtype`
                bnb_4bit_compute_dtype=model_args.torch_dtype,
                bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
                bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
                bnb_4bit_quant_storage=model_args.torch_dtype,
            )
        elif model_args.load_in_8bit:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            quantization_config = None

        return quantization_config

    def load_model_and_tokenizer(self, args: ServeArguments) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            revision=args.model_revision,
            trust_remote_code=args.trust_remote_code,
        )

        torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
        quantization_config = self.get_quantization_config(args)

        model_kwargs = {
            "revision": args.model_revision,
            "attn_implementation": args.attn_implementation,
            "torch_dtype": torch_dtype,
            "device_map": "auto",
            "quantization_config": quantization_config,
            "trust_remote_code": args.trust_remote_code,
        }

        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)

        if model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 256:
            model.generation_config.max_new_tokens = 256

        if getattr(model, "hf_device_map", None) is None:
            model = model.to(args.device)

        return model, tokenizer


if __name__ == "__main__":
    serve = ServeCommand()
    serve.model_name_or_path = "Menlo/Jan-nano"
    serve.run()
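As a complement, a small hypothetical sketch (not part of the diff) of consuming the `/v1/chat/completions` endpoint directly over HTTP: the server only accepts streaming requests and replies with `data: {...}` chunks of object type `chat.completion.chunk`, terminated by `data: [DONE]`, as produced by `build_chunk` and `stream_response` above. The URL and payload values are illustrative, and `requests` is an assumption of this sketch, not a dependency used in the diff.

import json

import requests

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,  # `_serve` rejects non-streaming requests
    "max_tokens": 32,
}

# Default ServeArguments host/port are localhost:8000.
with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8").removeprefix("data: ")
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Each chunk mirrors the payload built in `build_chunk` above.
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)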
src/transformers/dependency_versions_table.py

@@ -54,7 +54,7 @@ deps = {
    "protobuf": "protobuf",
    "psutil": "psutil",
    "pyyaml": "pyyaml>=5.1",
    "pydantic": "pydantic",
    "pydantic": "pydantic>=2",
    "pytest": "pytest>=7.2.0",
    "pytest-asyncio": "pytest-asyncio",
    "pytest-rerunfailures": "pytest-rerunfailures",
src/transformers/generation/continuous_batching.py

@@ -27,6 +27,8 @@ from typing import Optional, Union

import torch
import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from torch.profiler import profile, schedule, tensorboard_trace_handler
from tqdm import tqdm

@@ -72,6 +74,7 @@ class GenerationOutput:
    error: Optional[str] = None
    status: RequestStatus = RequestStatus.PENDING
    created_time: float = field(default_factory=time.time)
    next_token: Optional[int] = field(default_factory=int)


@dataclass

@@ -96,6 +99,7 @@ class RequestState:
    eos_token_id: int = -1
    created_time: float = field(default_factory=time.time)
    error: Optional[str] = None
    next_token: Optional[str] = None

    def current_len(self) -> int:
        """Get the current length of the sequence (prompt + generated tokens)."""

@@ -139,6 +143,7 @@ class RequestState:
            generated_tokens=self.static_outputs,
            logprobs=[],
            error=self.error,
            next_token=self.next_token,
        )


@@ -764,6 +769,9 @@ class ContinuousBatchProcessor:

        self.setup_static_tensors()

        self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
        self.decode_stream = DecodeStream(skip_special_tokens=True)

    @traced(standalone=True)
    def setup_static_tensors(self):
        T = self.max_batch_tokens

@@ -995,7 +1003,7 @@ class ContinuousBatchProcessor:
    def _maybe_send_output(self, state: RequestState, token: int):
        """Send output to the queue based on streaming mode and request state."""
        if self.streaming:
            state.next_token = token
            state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
            self.output_queue.put(state.to_generation_output())
        elif state.status == RequestStatus.FINISHED:
            self.output_queue.put(state.to_generation_output())

@@ -1102,6 +1110,7 @@ class ContinuousBatchingManager:
        self.profile = getattr(generation_config, "profile", False)
        self.manual_eviction = manual_eviction
        self.batch_processor: Optional[ContinuousBatchProcessor] = None
        self.decode_stream = DecodeStream(skip_special_tokens=True)

    @traced
    def start(self):
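The streaming path above now pushes decoded text rather than raw token ids: `DecodeStream.step()` from the `tokenizers` package converts one generated token id at a time into a printable fragment. Below is a minimal standalone sketch (not from the diff) of that pattern; the tokenizer name is an arbitrary example.

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

# Any model with a tokenizer.json on the Hub works here; "gpt2" is just an example.
tokenizer = Tokenizer.from_pretrained("gpt2")
decode_stream = DecodeStream(skip_special_tokens=True)

for token_id in tokenizer.encode("Hello world").ids:
    # `step` may return None until enough ids have accumulated to emit valid text.
    piece = decode_stream.step(tokenizer, token_id)
    if piece is not None:
        print(piece, end="")
print()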
src/transformers/utils/import_utils.py

@@ -291,6 +291,30 @@ except importlib.metadata.PackageNotFoundError:
    _essentia_version = False


_pydantic_available = importlib.util.find_spec("pydantic") is not None
try:
    _pydantic_version = importlib.metadata.version("pydantic")
    logger.debug(f"Successfully imported pydantic version {_pydantic_version}")
except importlib.metadata.PackageNotFoundError:
    _pydantic_available = False


_fastapi_available = importlib.util.find_spec("fastapi") is not None
try:
    _fastapi_version = importlib.metadata.version("fastapi")
    logger.debug(f"Successfully imported pydantic version {_fastapi_version}")
except importlib.metadata.PackageNotFoundError:
    _fastapi_available = False


_uvicorn_available = importlib.util.find_spec("uvicorn") is not None
try:
    _uvicorn_version = importlib.metadata.version("uvicorn")
    logger.debug(f"Successfully imported pydantic version {_uvicorn_version}")
except importlib.metadata.PackageNotFoundError:
    _uvicorn_available = False


_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
    _pretty_midi_version = importlib.metadata.version("pretty_midi")

@@ -472,6 +496,18 @@ def is_essentia_available():
    return _essentia_available


def is_pydantic_available():
    return _pydantic_available


def is_fastapi_available():
    return _fastapi_available


def is_uvicorn_available():
    return _uvicorn_available


def is_pretty_midi_available():
    return _pretty_midi_available

@@ -1771,6 +1807,23 @@ VISION_IMPORT_ERROR = """
`pip install pillow`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYDANTIC_IMPORT_ERROR = """
{0} requires the pydantic library but it was not found in your environment. You can install it with pip:
`pip install pydantic`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
FASTAPI_IMPORT_ERROR = """
{0} requires the fastapi library but it was not found in your environment. You can install it with pip:
`pip install fastapi`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
UVICORN_IMPORT_ERROR = """
{0} requires the uvicorn library but it was not found in your environment. You can install it with pip:
`pip install uvicorn`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """

@@ -1893,6 +1946,9 @@ BACKENDS_MAPPING = OrderedDict(
        ("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)),
        ("rich", (is_rich_available, RICH_IMPORT_ERROR)),
        ("keras_nlp", (is_keras_nlp_available, KERAS_NLP_IMPORT_ERROR)),
        ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
        ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
        ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
    ]
)
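These helpers follow the existing soft-dependency pattern in `import_utils.py`; the serving command gates itself on them. A small usage sketch, mirroring the check performed in `ServeCommand.__init__` above:

from transformers.utils.import_utils import (
    is_fastapi_available,
    is_pydantic_available,
    is_uvicorn_available,
)

# Raise early if any of the optional serving dependencies is missing.
if not (is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()):
    raise ImportError(
        "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
    )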
tests/commands/test_chat.py (new file, 65 lines)

@@ -0,0 +1,65 @@
import os
import tempfile
import unittest
from unittest.mock import patch

import transformers.commands.transformers_cli as cli
from transformers.commands.chat import ChatArguments, ChatCommand
from transformers.testing_utils import CaptureStd


class ChatCLITest(unittest.TestCase):
    def test_help(self):
        with patch("sys.argv", ["transformers", "chat", "--help"]), CaptureStd() as cs:
            with self.assertRaises(SystemExit):
                cli.main()
        self.assertIn("chat interface", cs.out.lower())

    @patch.object(ChatCommand, "run")
    def test_cli_dispatch(self, run_mock):
        args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
        with patch("sys.argv", args):
            cli.main()
        run_mock.assert_called_once()

    def test_parsed_args(self):
        with (
            patch.object(ChatCommand, "__init__", return_value=None) as init_mock,
            patch.object(ChatCommand, "run") as run_mock,
            patch(
                "sys.argv",
                [
                    "transformers",
                    "chat",
                    "test-model",
                    "max_new_tokens=64",
                ],
            ),
        ):
            cli.main()
        init_mock.assert_called_once()
        run_mock.assert_called_once()
        parsed_args = init_mock.call_args[0][0]
        self.assertEqual(parsed_args.model_name_or_path_or_address, "test-model")
        self.assertEqual(parsed_args.generate_flags, ["max_new_tokens=64"])


class ChatUtilitiesTest(unittest.TestCase):
    def test_save_and_clear_chat(self):
        tmp_path = tempfile.mkdtemp()

        args = ChatArguments(save_folder=str(tmp_path))
        args.model_name_or_path_or_address = "test-model"

        chat_history = [{"role": "user", "content": "hi"}]
        filename = ChatCommand.save_chat(chat_history, args)
        self.assertTrue(os.path.isfile(filename))

        cleared = ChatCommand.clear_chat_history()
        self.assertEqual(cleared, [])

    def test_parse_generate_flags(self):
        dummy = ChatCommand.__new__(ChatCommand)
        parsed = ChatCommand.parse_generate_flags(dummy, ["temperature=0.5", "max_new_tokens=10"])
        self.assertEqual(parsed["temperature"], 0.5)
        self.assertEqual(parsed["max_new_tokens"], 10)
tests/commands/test_serving.py (new file, 36 lines)

@@ -0,0 +1,36 @@
import unittest
from unittest.mock import patch

import transformers.commands.transformers_cli as cli
from transformers.commands.serving import ServeCommand
from transformers.testing_utils import CaptureStd


class ServeCLITest(unittest.TestCase):
    def test_help(self):
        with patch("sys.argv", ["transformers", "serve", "--help"]), CaptureStd() as cs:
            with self.assertRaises(SystemExit):
                cli.main()
        self.assertIn("serve", cs.out.lower())

    def test_parsed_args(self):
        with (
            patch.object(ServeCommand, "__init__", return_value=None) as init_mock,
            patch.object(ServeCommand, "run") as run_mock,
            patch("sys.argv", ["transformers", "serve", "the-model", "--host", "0.0.0.0", "--port", "9000"]),
        ):
            cli.main()
        init_mock.assert_called_once()
        run_mock.assert_called_once()
        parsed_args = init_mock.call_args[0][0]
        self.assertEqual(parsed_args.model_name_or_path, "the-model")
        self.assertEqual(parsed_args.host, "0.0.0.0")
        self.assertEqual(parsed_args.port, 9000)

    def test_build_chunk(self):
        dummy = ServeCommand.__new__(ServeCommand)
        dummy.args = type("Args", (), {"model_name_or_path": "test-model"})()
        chunk = ServeCommand.build_chunk(dummy, "hello", "req0", finish_reason="stop")
        self.assertIn("chat.completion.chunk", chunk)
        self.assertIn("data:", chunk)
        self.assertIn("test-model", chunk)