Compare commits

...

21 Commits

SHA1 Message Date
825c8e5695 Initial tool handling 2025-06-24 12:05:46 +02:00
8735fae88b Fix tests 2025-06-24 11:01:05 +02:00
be99c3d651 Fix tests 2025-06-24 11:01:05 +02:00
a729324a73 Add tests 2025-06-24 11:01:05 +02:00
fa30ee577d CI errors 2025-06-24 11:01:05 +02:00
f8d7972e9d CI errors 2025-06-24 11:01:05 +02:00
61d7189903 Better error handling 2025-06-24 11:01:05 +02:00
2a675dc552 Better error handling 2025-06-24 11:01:05 +02:00
40fe9fa4bd Last comments on PR 2025-06-24 11:01:05 +02:00
89e4f912e8 Update 2025-06-24 11:01:05 +02:00
dfd6a7360e Lucain's comments (Co-authored-by: Lucain <lucain@huggingface.co>) 2025-06-24 11:01:05 +02:00
d4ba83051f Update src/transformers/commands/serving.py (Co-authored-by: célina <hanouticelina@gmail.com>) 2025-06-24 11:01:05 +02:00
d152e9e179 Finalize chat.py 2025-06-24 11:01:05 +02:00
154f4c144d Finalize serving.py (Co-authored-by: célina <hanouticelina@gmail.com>) 2025-06-24 11:01:05 +02:00
f043c82818 temp 2025-06-24 11:01:05 +02:00
8219b66d2f temp 2025-06-24 11:01:05 +02:00
1b5172843d Generation Config 2025-06-24 11:01:05 +02:00
070aab3a09 Style 2025-06-24 11:01:05 +02:00
482fed4837 Support both generation methods 2025-06-24 11:01:05 +02:00
4173cfea98 Split chat and serve 2025-06-24 11:00:49 +02:00
f6024b723c Next token 2025-06-24 10:59:57 +02:00
8 changed files with 657 additions and 261 deletions

View File

@ -148,7 +148,7 @@ _deps = [
"protobuf",
"psutil",
"pyyaml>=5.1",
"pydantic",
"pydantic>=2",
"pytest>=7.2.0",
"pytest-asyncio",
"pytest-rerunfailures",

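For context, the pin moves from pydantic to pydantic>=2 to match the v2-style request/response models added in serving.py further down. A minimal sketch of that style (the model names and the model_dump() call are illustrative, not part of this diff):

from typing import Optional

from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    # Nested models and built-in generics such as list[Message] validate natively in pydantic v2.
    messages: list[Message]
    stream: Optional[bool] = False


req = ChatRequest(messages=[{"role": "user", "content": "hi"}], stream=True)
print(req.model_dump())  # v2 API; replaces v1's .dict()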
View File

@ -13,7 +13,7 @@
# limitations under the License.
import copy
import asyncio
import json
import os
import platform
@ -24,22 +24,20 @@ import warnings
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Optional
from typing import AsyncIterator, Optional
import yaml
from huggingface_hub.utils import disable_progress_bars
from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput
from transformers import (
AutoTokenizer,
GenerationConfig,
PreTrainedTokenizer,
TextIteratorStreamer,
logging,
)
from transformers.commands import BaseTransformersCLICommand
from transformers.commands.serving import ServeArguments, ServeCommand
from transformers.utils import is_rich_available, is_torch_available
from . import BaseTransformersCLICommand
if platform.system() != "Windows":
import pwd
@ -52,8 +50,12 @@ if is_rich_available():
if is_torch_available():
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
)
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
ALLOWED_VALUE_CHARS = set(
@ -133,21 +135,21 @@ class RichInterface:
else:
self.user_name = user_name
def stream_output(self, output_stream: TextIteratorStreamer) -> str:
"""Stream output from a role, and return the generated text after it's done steaming."""
# This method is originally from the FastChat CLI:
# https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
# Create a Live context for updating the console output
text = ""
async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, int]:
self._console.print(f"[bold blue]<{self.model_name}>:")
with Live(console=self._console, refresh_per_second=4) as live:
# Read lines from the stream
for i, outputs in enumerate(output_stream):
if not outputs or i == 0:
text = ""
async for token in await stream:
outputs = token.choices[0].delta.content
request_id = token.id
if not outputs:
continue
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
text += outputs
# Render the accumulated text as Markdown
# NOTE: this is a workaround for the rendering "unstandard markdown"
@ -160,6 +162,7 @@ class RichInterface:
# introduce trailing spaces (only) in code block, but it works well
# especially for console output, because in general the console does not
# care about trailing spaces.
lines = []
for line in text.splitlines():
lines.append(line)
@ -169,11 +172,15 @@ class RichInterface:
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
# Update the Live console output
live.update(markdown)
live.update(markdown, refresh=True)
self._console.print()
return text
return text, request_id
def input(self) -> str:
"""Gets user input from the console."""
@ -300,6 +307,10 @@ class ChatArguments:
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
# Serving settings
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to.."})
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
def chat_command_factory(args: Namespace):
"""
@ -322,7 +333,10 @@ class ChatCommand(BaseTransformersCLICommand):
group = chat_parser.add_argument_group("Positional arguments")
group.add_argument(
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
"model_name_or_path_or_address",
type=str,
default=None,
help="Name of the pre-trained model or address to connect to.",
)
group.add_argument(
"generate_flags",
@ -340,6 +354,36 @@ class ChatCommand(BaseTransformersCLICommand):
def __init__(self, args):
args = self._handle_deprecated_args(args)
if args.model_name_or_path_or_address is not None:
name = args.model_name_or_path_or_address
if name.startswith("http") or name.startswith("https") or name.startswith("localhost"):
self.spawn_backend = False
if args.host != "localhost" or args.port != 8000:
raise ValueError(
"Looks like youve set both a server address and a custom host/port. "
"Please pick just one way to specify the server."
)
args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
else:
self.spawn_backend = True
args.model_name_or_path = args.model_name_or_path_or_address
if not is_rich_available() and (not is_torch_available() and self.spawn_backend):
raise ImportError(
"You need to install rich to use the chat interface. Additionally, you have not specified a remote "
"endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`)"
)
elif not is_rich_available():
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
elif not is_torch_available() and self.spawn_backend:
raise ImportError(
"You have not specified a remote endpoint and are therefore spawning a backend. Torch is required "
"for this: (`pip install rich torch`)"
)
self.args = args
def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
@ -349,22 +393,7 @@ class ChatCommand(BaseTransformersCLICommand):
"""
has_warnings = False
# 1. Model as a positional argument
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
if args.model_name_or_path_positional is None:
raise ValueError(
"One of the following must be provided:"
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
)
elif args.model_name_or_path is not None:
has_warnings = True
warnings.warn(
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
"argument instead, e.g. `transformers chat <model_repo>`.",
FutureWarning,
)
# 2. Named generate option args
# Named generate option args
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
value = getattr(args, deprecated_arg)
if value != default_value:
@ -404,7 +433,7 @@ class ChatCommand(BaseTransformersCLICommand):
if filename is None:
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
filename = f"{args.model_name_or_path_or_address}/chat_{time_str}.json"
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
@ -477,40 +506,20 @@ class ChatCommand(BaseTransformersCLICommand):
)
return processed_generate_flags
def get_generation_parameterization(
self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
) -> tuple[GenerationConfig, dict]:
def get_generation_parameterization(self, args: ChatArguments) -> tuple[GenerationConfig, dict]:
"""
Returns a GenerationConfig object holding the generation parameters for the CLI command.
"""
# No generation config arg provided -> use default generation config, apply CLI defaults
if args.generation_config is None:
# We start off from the checkpoint's generation config
generation_config = copy.deepcopy(model.generation_config)
# Apply deprecated CLI args on top of the default generation config
pad_token_id, eos_token_ids = self.parse_eos_tokens(
tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
)
deprecated_kwargs = {
"max_new_tokens": args.max_new_tokens,
"do_sample": args.do_sample,
"num_beams": args.num_beams,
"temperature": args.temperature,
"top_k": args.top_k,
"top_p": args.top_p,
"repetition_penalty": args.repetition_penalty,
"pad_token_id": pad_token_id,
"eos_token_id": eos_token_ids,
}
generation_config.update(**deprecated_kwargs)
# generation config arg provided -> use it as the base parameterization
else:
# No generation config arg provided -> use base generation config, apply CLI defaults
if args.generation_config is not None:
if ".json" in args.generation_config: # is a local file
dirname = os.path.dirname(args.generation_config)
filename = os.path.basename(args.generation_config)
generation_config = GenerationConfig.from_pretrained(dirname, filename)
else:
generation_config = GenerationConfig.from_pretrained(args.generation_config)
else:
generation_config = GenerationConfig()
# Finally: parse and apply `generate_flags`
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
@ -664,7 +673,7 @@ class ChatCommand(BaseTransformersCLICommand):
elif user_input == "!status":
interface.print_status(
model_name=args.model_name_or_path_positional,
model_name=args.model_name_or_path,
generation_config=generation_config,
model_kwargs=model_kwargs,
)
@ -679,10 +688,33 @@ class ChatCommand(BaseTransformersCLICommand):
# -----------------------------------------------------------------------------------------------------------------
# Main logic
def run(self):
if not is_rich_available():
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
if not is_torch_available():
raise ImportError("You need to install torch to use the chat interface. (`pip install torch`)")
asyncio.run(self._inner_run())
async def _inner_run(self):
if self.spawn_backend:
serve_args = ServeArguments(
model_revision=self.args.model_revision,
device=self.args.device,
torch_dtype=self.args.torch_dtype,
trust_remote_code=self.args.trust_remote_code,
attn_implementation=self.args.attn_implementation,
load_in_8bit=self.args.load_in_8bit,
load_in_4bit=self.args.load_in_4bit,
bnb_4bit_quant_type=self.args.bnb_4bit_quant_type,
use_bnb_nested_quant=self.args.use_bnb_nested_quant,
host=self.args.host,
port=self.args.port,
log_level="error",
)
serve_args.model_name_or_path = self.args.model_name_or_path
serve_command = ServeCommand(serve_args)
thread = Thread(target=serve_command.run)
thread.daemon = True
thread.start()
host = "http://localhost" if self.args.host == "localhost" else self.args.host
client = AsyncInferenceClient(f"{host}:{self.args.port}")
args = self.args
if args.examples_path is None:
@ -696,19 +728,14 @@ class ChatCommand(BaseTransformersCLICommand):
else:
user = args.user
model, tokenizer = self.load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
generation_config, model_kwargs = self.get_generation_parameterization(args)
# if not verbose -> disable warnings, progress bars, etc in the chat interface
if not args.verbose:
logging.set_verbosity_error()
disable_progress_bars()
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface.clear()
chat = self.clear_chat_history(args.system_prompt)
request_id = None
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
interface.print_help(minimal=True)
while True:
@ -736,23 +763,25 @@ class ChatCommand(BaseTransformersCLICommand):
else:
chat.append({"role": "user", "content": user_input})
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
stream = client.chat_completion(
chat,
stream=True,
extra_body={"request_id": request_id, "generation_config": {**generation_config.to_dict()}},
)
attention_mask = torch.ones_like(inputs)
generation_kwargs = {
"inputs": inputs,
"attention_mask": attention_mask,
"streamer": generation_streamer,
"generation_config": generation_config,
**model_kwargs,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
model_output = interface.stream_output(generation_streamer)
thread.join()
model_output, request_id = await interface.stream_output(stream)
chat.append({"role": "assistant", "content": model_output})
except KeyboardInterrupt:
break
finally:
await client.close()
if __name__ == "__main__":
args = ChatArguments()
args.model_name_or_path_or_address = "meta-llama/Llama-3.2-3b-Instruct"
args.model_name_or_path_or_address = "http://localhost:8000"
chat = ChatCommand(args)
chat.run()

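The rewritten chat loop no longer calls model.generate directly: it spawns (or connects to) a transformers serve backend and consumes completion deltas over HTTP with huggingface_hub's AsyncInferenceClient, as in the loop above. A minimal standalone sketch of that client-side pattern, assuming a server is already running on localhost:8000 (the messages are illustrative):

import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient("http://localhost:8000")
    messages = [{"role": "user", "content": "Hello!"}]
    try:
        # Mirrors ChatCommand: request a stream, then iterate over the awaited async iterator.
        stream = client.chat_completion(messages, stream=True)
        async for chunk in await stream:
            delta = chunk.choices[0].delta.content
            if delta:
                print(delta, end="", flush=True)
    finally:
        await client.close()


asyncio.run(main())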
View File

@ -12,32 +12,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional
from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from huggingface_hub import ChatCompletionStreamOutputDeltaToolCall, ChatCompletionStreamOutputFunction
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
from .. import PreTrainedTokenizerFast, TextIteratorStreamer
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from ..utils import is_torch_available, logging
from . import BaseTransformersCLICommand
try:
from fastapi import Body, FastAPI, HTTPException
from fastapi.routing import APIRoute
if is_torch_available():
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
PreTrainedModel,
)
if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available():
import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from starlette.responses import JSONResponse
from uvicorn import run
_serve_dependencies_installed = True
except (ImportError, AttributeError):
BaseModel = object
class Message(BaseModel):
role: str
content: str
def Body(*x, **y):
pass
class ChatCompletionInput(BaseModel):
messages: list[Message]
_serve_dependencies_installed = False
stream: Optional[bool] = False
model: Optional[str] = None
request_id: Optional[str] = None
extra_body: Optional[dict] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[list[float]] = None
max_tokens: Optional[int] = None
stop: Optional[list[str]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
seed: Optional[int] = None
# Additional options supported by the HFH InferenceClient
# that aren't yet supported here.
# logprobs: Optional[bool] = None
tools: Any = None
# n: Optional[int] = None
# presence_penalty: Optional[float] = None
# response_format: Optional[ChatCompletionInputGrammarType] = None
# stream_options: Optional[ChatCompletionInputStreamOptions] = None
# tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None
# tool_prompt: Optional[str] = None
# top_logprobs: Optional[int] = None
logger = logging.get_logger("transformers/serving")
logger = logging.get_logger(__name__)
def serve_command_factory(args: Namespace):
@ -46,47 +89,88 @@ def serve_command_factory(args: Namespace):
Returns: ServeCommand
"""
nlp = pipeline(
task=args.task,
model=args.model if args.model else None,
config=args.config,
tokenizer=args.tokenizer,
device=args.device,
return ServeCommand(args)
def create_generation_config_from_req(req: "ChatCompletionInput"):
if req.extra_body is not None and "generation_config" in req.extra_body:
for key in req.extra_body["generation_config"].keys():
if key in ChatCompletionInput.base_field_names.keys():
return {"error": "Duplicated key in the root request and in the passed generation config."}
if req.extra_body is not None and "generation_config" in req.extra_body:
generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
else:
generation_config = GenerationConfig()
if req.frequency_penalty is not None:
generation_config.repetition_penalty = req.frequency_penalty
if req.logit_bias is not None:
generation_config.sequence_bias = req.logit_bias
if req.stop is not None:
generation_config.stop_strings = req.stop
if req.temperature is not None:
generation_config.temperature = req.temperature
if req.top_p is not None:
generation_config.top_p = req.top_p
if req.seed is not None:
torch.manual_seed(req.seed)
return generation_config
@dataclass
class ServeArguments:
r"""
Arguments for the serve CLI.
See the metadata arg for each argument's description -- the metadata will be printed with
`transformers serve --help`
"""
# Model loading
model_revision: str = field(
default="main",
metadata={"help": "Specific model version to use (can be a branch name, tag name or commit id)."},
)
return ServeCommand(nlp, args.host, args.port, args.workers)
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
torch_dtype: Optional[str] = field(
default="auto",
metadata={
"help": "Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, "
"the dtype will be automatically derived from the model's weights.",
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
trust_remote_code: bool = field(
default=False, metadata={"help": "Whether to trust remote code when loading a model."}
)
attn_implementation: Optional[str] = field(
default=None,
metadata={
"help": "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
"which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
},
)
load_in_8bit: bool = field(
default=False,
metadata={"help": "Whether to use 8 bit precision for the base model - works only with LoRA."},
)
load_in_4bit: bool = field(
default=False,
metadata={"help": "Whether to use 4 bit precision for the base model - works only with LoRA."},
)
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
# Serving settings
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to.."})
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
class ServeModelInfoResult(BaseModel):
"""
Expose model information
"""
infos: dict
class ServeTokenizeResult(BaseModel):
"""
Tokenize result model
"""
tokens: list[str]
tokens_ids: Optional[list[int]]
class ServeDeTokenizeResult(BaseModel):
"""
DeTokenize result model
"""
text: str
class ServeForwardResult(BaseModel):
"""
Forward result model
"""
output: Any
# Other settings
log_level: str = field(
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
)
class ServeCommand(BaseTransformersCLICommand):
@ -98,131 +182,248 @@ class ServeCommand(BaseTransformersCLICommand):
Args:
parser: Root parser to register command-specific arguments
"""
serve_parser = parser.add_parser(
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
)
serve_parser.add_argument(
"--task",
type=str,
choices=get_supported_tasks(),
help="The task to run the pipeline on",
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
serve_parser.add_argument(
"--device",
type=int,
default=-1,
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
)
dataclass_types = (ServeArguments,)
serve_parser = parser.add_parser("serve", dataclass_types=dataclass_types)
group = serve_parser.add_argument_group("Positional arguments")
group.add_argument("model_name_or_path", type=str, default=None, help="Name of the pre-trained model.")
serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
self._pipeline = pipeline
self.host = host
self.port = port
self.workers = workers
if not _serve_dependencies_installed:
raise RuntimeError(
"Using serve command requires FastAPI and uvicorn. "
'Please install transformers with [serving]: pip install "transformers[serving]". '
"Or install FastAPI and uvicorn separately."
)
else:
logger.info(f"Serving model over {host}:{port}")
self._app = FastAPI(
routes=[
APIRoute(
"/",
self.model_info,
response_model=ServeModelInfoResult,
response_class=JSONResponse,
methods=["GET"],
),
APIRoute(
"/tokenize",
self.tokenize,
response_model=ServeTokenizeResult,
response_class=JSONResponse,
methods=["POST"],
),
APIRoute(
"/detokenize",
self.detokenize,
response_model=ServeDeTokenizeResult,
response_class=JSONResponse,
methods=["POST"],
),
APIRoute(
"/forward",
self.forward,
response_model=ServeForwardResult,
response_class=JSONResponse,
methods=["POST"],
),
],
timeout=600,
def __init__(self, args: ServeArguments):
if not is_pydantic_available() or not is_fastapi_available() or not is_uvicorn_available():
raise ImportError(
"Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
)
self.args = args
self.model, self.tokenizer = self.load_model_and_tokenizer(args)
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
transformers_logger = logging.get_logger("transformers")
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
cb_logger = logging.get_logger("transformers.generation.continuous_batching")
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
def build_chunk(self, content: str, request_id: str, finish_reason: Optional[str] = None, tool_calls=None) -> str:
print(content)
print(tool_calls)
payload = {
"object": "chat.completion.chunk",
"id": request_id,
"created": int(time.time()),
"model": self.args.model_name_or_path,
"system_fingerprint": "",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": content, "tool_calls": tool_calls or []},
"logprobs": None,
"finish_reason": finish_reason,
}
],
}
print(payload)
print(f"data: {json.dumps(payload)}\n\n")
return f"data: {json.dumps(payload)}\n\n"
def run(self):
run(self._app, host=self.host, port=self.port, workers=self.workers)
app = FastAPI()
def model_info(self):
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
if self.use_continuous_batching:
self.continuous_batching(app)
else:
self.generate(app)
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
"""
Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to
tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer
mapping.
"""
try:
tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
if return_ids:
tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
else:
return ServeTokenizeResult(tokens=tokens_txt)
def continuous_batching(self, app):
generation_config = GenerationConfig(
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=False,
num_blocks=1,
block_size=1024,
do_sample=False,
max_batch_tokens=10,
scheduler="fifo",
)
except Exception as e:
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
manager: ContinuousBatchingManager = self.model.init_continuous_batching(
generation_config=generation_config, streaming=True
)
manager.start()
def detokenize(
self,
tokens_ids: list[int] = Body(None, embed=True),
skip_special_tokens: bool = Body(False, embed=True),
cleanup_tokenization_spaces: bool = Body(True, embed=True),
):
"""
Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
**skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
Flag indicating to remove all leading/trailing spaces and intermediate ones.
"""
try:
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
return ServeDeTokenizeResult(model="", text=decoded_str)
except Exception as e:
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
@app.post("/v1/chat/completions")
def _serve(req: ChatCompletionInput):
if not req.stream:
return {"error": "Only streaming mode is supported."}
async def forward(self, inputs=Body(None, embed=True)):
"""
**inputs**: **attention_mask**: **tokens_type_ids**:
"""
chat = req.messages
# Check we don't have empty string
if len(inputs) == 0:
return ServeForwardResult(output=[], attention=[])
inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
self.model.device
)
try:
# Forward through the model
output = self._pipeline(inputs)
return ServeForwardResult(output=output)
except Exception as e:
raise HTTPException(500, {"error": str(e)})
generation_config = create_generation_config_from_req(req)
def stream_response(_inputs):
try:
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
queue_is_flushed = False
for result in manager:
if req.request_id is not None and not queue_is_flushed:
if result.status == RequestStatus.FINISHED:
continue
else:
queue_is_flushed = True
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)
if result.status == RequestStatus.FINISHED:
break
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(str(e))
yield f'data: {{"error": "{str(e)}"}}'
return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
def generate(self, app):
@app.post("/v1/chat/completions")
def _serve(req: ChatCompletionInput):
if not req.stream:
return {"error": "Only streaming mode is supported."}
text = self.tokenizer.apply_chat_template(
req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
)
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
request_id = req.request_id if req.request_id is not None else "req_0"
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config = create_generation_config_from_req(req)
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
generation_config.max_new_tokens = max_new_tokens
generation_kwargs = {
"inputs": inputs,
"attention_mask": torch.ones_like(inputs),
"streamer": generation_streamer,
"generation_config": generation_config,
}
def stream_response(streamer, _request_id):
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
try:
thread.start()
inside_tool_call = False
tool_call_buffer = ""
for result in streamer:
print("Chunk:", result)
if result.strip() == "<tool_call>":
inside_tool_call = True
tool_call_buffer = ""
continue
if inside_tool_call:
if result.strip() == "</tool_call>":
try:
tool_data = json.loads(tool_call_buffer)
tool = ChatCompletionStreamOutputDeltaToolCall(
function=ChatCompletionStreamOutputFunction(
arguments=json.dumps(tool_data.get("arguments")),
name=tool_data.get("name"),
),
id=tool_data.get("id"),
index=0,
type=tool_data.get("type"),
)
yield self.build_chunk("", _request_id, tool_calls=[tool])
except Exception as e:
logger.error(f"Failed to parse tool call: {e}")
raise e
inside_tool_call = False
tool_call_buffer = ""
else:
tool_call_buffer += result
continue
yield self.build_chunk(result, _request_id)
yield "data: [DONE]\n\n"
thread.join()
except Exception as e:
logger.error(str(e))
yield f'data: {{"error": "{str(e)}"}}'
finally:
thread.join()
return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
@staticmethod
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
# For consistency with model weights, we use the same value as `torch_dtype`
bnb_4bit_compute_dtype=model_args.torch_dtype,
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.torch_dtype,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
quantization_config = None
return quantization_config
def load_model_and_tokenizer(self, args: ServeArguments) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
)
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
quantization_config = self.get_quantization_config(args)
model_kwargs = {
"revision": args.model_revision,
"attn_implementation": args.attn_implementation,
"torch_dtype": torch_dtype,
"device_map": "auto",
"quantization_config": quantization_config,
"trust_remote_code": args.trust_remote_code,
}
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
if model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 256:
model.generation_config.max_new_tokens = 256
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
return model, tokenizer
if __name__ == "__main__":
serve_args = ServeArguments()
serve_args.model_name_or_path = "Menlo/Jan-nano"
serve = ServeCommand(serve_args)
serve.run()

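The /v1/chat/completions route only supports streaming and emits OpenAI-style chat.completion.chunk payloads as server-sent events, one "data: <json>" line per delta and a final "data: [DONE]" marker (see build_chunk above). A hedged sketch of consuming that stream with plain requests; the URL and payload fields are illustrative:

import json

import requests

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        # Each chunk carries a delta with the role, new content, and any tool calls.
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)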
View File

@ -54,7 +54,7 @@ deps = {
"protobuf": "protobuf",
"psutil": "psutil",
"pyyaml": "pyyaml>=5.1",
"pydantic": "pydantic",
"pydantic": "pydantic>=2",
"pytest": "pytest>=7.2.0",
"pytest-asyncio": "pytest-asyncio",
"pytest-rerunfailures": "pytest-rerunfailures",

View File

@ -27,6 +27,8 @@ from typing import Optional, Union
import torch
import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from torch.profiler import profile, schedule, tensorboard_trace_handler
from tqdm import tqdm
@ -72,6 +74,7 @@ class GenerationOutput:
error: Optional[str] = None
status: RequestStatus = RequestStatus.PENDING
created_time: float = field(default_factory=time.time)
next_token: Optional[int] = field(default_factory=int)
@dataclass
@ -96,6 +99,7 @@ class RequestState:
eos_token_id: int = -1
created_time: float = field(default_factory=time.time)
error: Optional[str] = None
next_token: Optional[str] = None
def current_len(self) -> int:
"""Get the current length of the sequence (prompt + generated tokens)."""
@ -139,6 +143,7 @@ class RequestState:
generated_tokens=self.static_outputs,
logprobs=[],
error=self.error,
next_token=self.next_token,
)
@ -764,6 +769,9 @@ class ContinuousBatchProcessor:
self.setup_static_tensors()
self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
self.decode_stream = DecodeStream(skip_special_tokens=True)
@traced(standalone=True)
def setup_static_tensors(self):
T = self.max_batch_tokens
@ -995,7 +1003,7 @@ class ContinuousBatchProcessor:
def _maybe_send_output(self, state: RequestState, token: int):
"""Send output to the queue based on streaming mode and request state."""
if self.streaming:
state.next_token = token
state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())
@ -1102,6 +1110,7 @@ class ContinuousBatchingManager:
self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None
self.decode_stream = DecodeStream(skip_special_tokens=True)
@traced
def start(self):

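With streaming enabled, the batch processor now decodes each newly generated token incrementally through tokenizers.decoders.DecodeStream instead of pushing raw token ids to the output queue. A small sketch of that decoding pattern (the checkpoint is illustrative):

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")  # illustrative checkpoint
stream = DecodeStream(skip_special_tokens=True)

# step() returns a printable piece of text (or None) per token id, so partial
# multi-token characters are only emitted once they are complete.
for token_id in tokenizer.encode("Hello world").ids:
    piece = stream.step(tokenizer, token_id)
    if piece is not None:
        print(piece, end="")
print()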
View File

@ -291,6 +291,30 @@ except importlib.metadata.PackageNotFoundError:
_essentia_version = False
_pydantic_available = importlib.util.find_spec("pydantic") is not None
try:
_pydantic_version = importlib.metadata.version("pydantic")
logger.debug(f"Successfully imported pydantic version {_pydantic_version}")
except importlib.metadata.PackageNotFoundError:
_pydantic_available = False
_fastapi_available = importlib.util.find_spec("fastapi") is not None
try:
_fastapi_version = importlib.metadata.version("fastapi")
logger.debug(f"Successfully imported pydantic version {_fastapi_version}")
except importlib.metadata.PackageNotFoundError:
_fastapi_available = False
_uvicorn_available = importlib.util.find_spec("uvicorn") is not None
try:
_uvicorn_version = importlib.metadata.version("uvicorn")
logger.debug(f"Successfully imported pydantic version {_uvicorn_version}")
except importlib.metadata.PackageNotFoundError:
_uvicorn_available = False
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
_pretty_midi_version = importlib.metadata.version("pretty_midi")
@ -472,6 +496,18 @@ def is_essentia_available():
return _essentia_available
def is_pydantic_available():
return _pydantic_available
def is_fastapi_available():
return _fastapi_available
def is_uvicorn_available():
return _uvicorn_available
def is_pretty_midi_available():
return _pretty_midi_available
@ -1771,6 +1807,23 @@ VISION_IMPORT_ERROR = """
`pip install pillow`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PYDANTIC_IMPORT_ERROR = """
{0} requires the pydantic library but it was not found in your environment. You can install it with pip:
`pip install pydantic`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
FASTAPI_IMPORT_ERROR = """
{0} requires the fastapi library but it was not found in your environment. You can install it with pip:
`pip install fastapi`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
UVICORN_IMPORT_ERROR = """
{0} requires the uvicorn library but it was not found in your environment. You can install it with pip:
`pip install uvicorn`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """
@ -1893,6 +1946,9 @@ BACKENDS_MAPPING = OrderedDict(
("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)),
("rich", (is_rich_available, RICH_IMPORT_ERROR)),
("keras_nlp", (is_keras_nlp_available, KERAS_NLP_IMPORT_ERROR)),
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
]
)

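Besides the standalone is_*_available() helpers, registering pydantic, fastapi and uvicorn in BACKENDS_MAPPING lets callers declare them via requires_backends and reuse the import-error messages above. A short sketch (the class below is purely illustrative):

from transformers.utils.import_utils import requires_backends


class MyServingHelper:
    def __init__(self):
        # Raises ImportError with the PYDANTIC/FASTAPI/UVICORN_IMPORT_ERROR text
        # if any of the three packages is missing.
        requires_backends(self, ["pydantic", "fastapi", "uvicorn"])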
View File

@ -0,0 +1,65 @@
import os
import tempfile
import unittest
from unittest.mock import patch
import transformers.commands.transformers_cli as cli
from transformers.commands.chat import ChatArguments, ChatCommand
from transformers.testing_utils import CaptureStd
class ChatCLITest(unittest.TestCase):
def test_help(self):
with patch("sys.argv", ["transformers", "chat", "--help"]), CaptureStd() as cs:
with self.assertRaises(SystemExit):
cli.main()
self.assertIn("chat interface", cs.out.lower())
@patch.object(ChatCommand, "run")
def test_cli_dispatch(self, run_mock):
args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
with patch("sys.argv", args):
cli.main()
run_mock.assert_called_once()
def test_parsed_args(self):
with (
patch.object(ChatCommand, "__init__", return_value=None) as init_mock,
patch.object(ChatCommand, "run") as run_mock,
patch(
"sys.argv",
[
"transformers",
"chat",
"test-model",
"max_new_tokens=64",
],
),
):
cli.main()
init_mock.assert_called_once()
run_mock.assert_called_once()
parsed_args = init_mock.call_args[0][0]
self.assertEqual(parsed_args.model_name_or_path_or_address, "test-model")
self.assertEqual(parsed_args.generate_flags, ["max_new_tokens=64"])
class ChatUtilitiesTest(unittest.TestCase):
def test_save_and_clear_chat(self):
tmp_path = tempfile.mkdtemp()
args = ChatArguments(save_folder=str(tmp_path))
args.model_name_or_path_or_address = "test-model"
chat_history = [{"role": "user", "content": "hi"}]
filename = ChatCommand.save_chat(chat_history, args)
self.assertTrue(os.path.isfile(filename))
cleared = ChatCommand.clear_chat_history()
self.assertEqual(cleared, [])
def test_parse_generate_flags(self):
dummy = ChatCommand.__new__(ChatCommand)
parsed = ChatCommand.parse_generate_flags(dummy, ["temperature=0.5", "max_new_tokens=10"])
self.assertEqual(parsed["temperature"], 0.5)
self.assertEqual(parsed["max_new_tokens"], 10)

View File

@ -0,0 +1,36 @@
import unittest
from unittest.mock import patch
import transformers.commands.transformers_cli as cli
from transformers.commands.serving import ServeCommand
from transformers.testing_utils import CaptureStd
class ServeCLITest(unittest.TestCase):
def test_help(self):
with patch("sys.argv", ["transformers", "serve", "--help"]), CaptureStd() as cs:
with self.assertRaises(SystemExit):
cli.main()
self.assertIn("serve", cs.out.lower())
def test_parsed_args(self):
with (
patch.object(ServeCommand, "__init__", return_value=None) as init_mock,
patch.object(ServeCommand, "run") as run_mock,
patch("sys.argv", ["transformers", "serve", "the-model", "--host", "0.0.0.0", "--port", "9000"]),
):
cli.main()
init_mock.assert_called_once()
run_mock.assert_called_once()
parsed_args = init_mock.call_args[0][0]
self.assertEqual(parsed_args.model_name_or_path, "the-model")
self.assertEqual(parsed_args.host, "0.0.0.0")
self.assertEqual(parsed_args.port, 9000)
def test_build_chunk(self):
dummy = ServeCommand.__new__(ServeCommand)
dummy.args = type("Args", (), {"model_name_or_path": "test-model"})()
chunk = ServeCommand.build_chunk(dummy, "hello", "req0", finish_reason="stop")
self.assertIn("chat.completion.chunk", chunk)
self.assertIn("data:", chunk)
self.assertIn("test-model", chunk)