Compare commits

..

15 Commits

Author SHA1 Message Date
b734e7c35e style 2025-10-21 16:16:27 +02:00
ccbd1eceb3 fix 2025-10-21 16:16:10 +02:00
b68b48ce88 Merge remote-tracking branch 'upstream/main' into serve-quantization 2025-10-21 16:14:04 +02:00
72d8e7bb3c Merge remote-tracking branch 'origin/main' into serve-quantization 2025-10-15 12:19:48 +00:00
747fcfa227 rm check for now 2025-10-01 16:44:42 +00:00
a6506fa478 style 2025-10-01 16:28:58 +00:00
72ffb3d1d2 fix 2025-10-01 16:28:29 +00:00
f525309408 minor doc fix 2025-10-01 15:54:05 +00:00
ffa68ba7b8 fix args 2025-10-01 15:48:21 +00:00
eab734d23c fix 2025-10-01 15:10:55 +00:00
b604f62b6b Merge remote-tracking branch 'origin/main' into serve-quantization 2025-10-01 13:32:42 +00:00
35fff29efd fix 2025-10-01 13:31:05 +00:00
1cdd0bf0fb fix 2025-10-01 13:30:56 +00:00
907f206a1b fix api 2025-10-01 13:25:59 +00:00
86ba65350b doc 2025-10-01 12:47:38 +00:00
42 changed files with 993 additions and 1478 deletions

View File

@ -88,8 +88,6 @@
title: Tool use
- local: chat_templating_writing
title: Writing a chat template
- local: chat_response_parsing
title: Response parsing
title: Chat with models
- sections:
- local: serving

View File

@ -95,12 +95,9 @@ print(tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):]))
The chat model called the `get_current_temperature` tool with the correct parameters from the docstring. It inferred France as the location based on Paris, and that it should use Celsius for the units of temperature.
A model **cannot actually call the tool itself**. It requests a tool call, and it's your job to handle the call and append it and the result to the chat history. For
models that support [response parsing](./chat_response_parsing), the response parsing will be handled automatically, and you can just use
[`~PreTrainedTokenizer.parse_response`] to extract the tool call. For other models, you'll need to manually translate the output
string into a tool call dict.
A model **cannot actually call the tool itself**. It requests a tool call, and it's your job to handle the call and append it and the result to the chat history.
Regardless of the approach you use, the tool call should go in the `tool_calls` key of an `assistant` message. This is the recommended API, and should be supported by the chat template of most tool-using models.
Hold the call in the `tool_calls` key of an `assistant` message. This is the recommended API, and should be supported by the chat template of most tool-using models.
> [!WARNING]
> Although `tool_calls` is similar to the OpenAI API, the OpenAI API uses a JSON string as its `tool_calls` format. This may cause errors or strange model behavior if used in Transformers, which expects a dict.
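For illustration, a tool call in the dict format described above might be appended like this (a minimal sketch reusing the `get_current_temperature` example from earlier; the exact argument names and values are assumptions based on that example):
```python
# Sketch only: "location" and "unit" are assumed argument names from the example above.
tool_call = {
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "arguments": {"location": "Paris, France", "unit": "celsius"},
    },
}
messages.append({"role": "assistant", "tool_calls": [tool_call]})
# After running the tool yourself, append its result so the model can see it on the next turn.
messages.append({"role": "tool", "name": "get_current_temperature", "content": "22.0"})
```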

View File

@ -1,233 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Response Parsing
It is increasingly common for chat models to generate structured outputs, rather than just a single reply string.
The most common uses for structured outputs are [tool calling](./chat_extras) and [reasoning models](https://huggingface.co/reasoning-course).
Tool calling models can output tool calls, containing the name of the tool to call and any arguments to be passed to it,
while reasoning models often output reasoning steps as a "chain of thought". Some recent models even use both of these,
and may output reasoning and/or one or more tool calls before their final answer.
Models with structured outputs pose a challenge for chat templating, because the output needs to be parsed before it
can be appended to the chat. For a concrete example, let's say we ask [GPT-OSS](https://huggingface.co/openai/gpt-oss-120b)
what the weather is like, and it thinks and decides to call a tool. Here's what the raw model output might look like:
```txt
<|start|>analysis<|message|>The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco).
So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. But we need to call function to get weather data.
So we should call get_current_weather with location "San Francisco, CA". Let's do that.
We will call function get_current_weather.<|end|><|start|>commentary to=functions.get_current_weather<|channel|>commentary <|constrain|>json<|message|>{"location":"San Francisco, CA"}<|call|>
}
```
But if you want to append this to a chat, you'll need to format it as a chat message dict, like this:
```json
{
"role": "assistant",
"thinking": "The user asks: \"What is the weather like in SF?\" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. But we need to call function to get weather data. So we should call get_current_weather with location \"San Francisco, CA\". Let's do that.",
"tool_calls": [
{
"name": "get_current_weather",
"arguments": {
"location": "San Francisco, CA"
}
}
]
}
```
Chat **templates** give us a way to turn messages into formatted input for a model, but we need something else to
parse model output back into a standard message dict. This is what chat **parsing** is for.
## The [`~PreTrainedTokenizerBase.parse_response`] method
Parsing a chat response on a model that supports it is straightforward. Simply take the raw, decoded output from
[`~generation.GenerationMixin.generate`], and pass it to the tokenizer's [`~PreTrainedTokenizerBase.parse_response`] method:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "HuggingFaceTB/SmolLM3-3B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", device_map="auto")
messages = [
{
"role": "user",
"content": "Hey! Can you summarize the end of the Cold War as briefly as possible? Like, comically briefly. It should really leave out almost most of the relevant information."
}
]
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids, max_new_tokens=1024)[0, input_ids.shape[1]:]
out_text = tokenizer.decode(outputs)
parsed = tokenizer.parse_response(out_text)
print(parsed.keys())
```
And you should get:
```text
dict_keys(['thinking', 'content'])
```
And that's all you need to start using response parsing! `parse_response` should return a complete message dict that is ready to be appended to the chat history.
When the tokenizer does not support response parsing, `parse_response` will throw an error. We hope to add support
to more tokenizers over time.
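Continuing the example above, here is a minimal sketch of appending the parsed message to the conversation, with a fallback for tokenizers that don't have a response schema yet (the exact exception type raised in that case is an assumption here):
```python
try:
    parsed = tokenizer.parse_response(out_text)
except Exception:  # assumption: the precise error type depends on the tokenizer/version
    # Fall back to treating the whole decoded string as plain assistant content
    parsed = {"content": out_text}
messages.append({"role": "assistant", **parsed})
```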
## Developers: Understanding a simple response schema
Under the hood, `parse_response` uses a **JSON schema** to parse the model output. A JSON schema represents
the structure of the output message dict. The schema is augmented with additional fields that indicate how the
output message string should be parsed into the expected format. Let's take a look at the schema for a SmolLM response,
excluding tool calls for now:
```python
{
"x-regex": "(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string"},
"thinking": {"type": "string"}
}
}
```
We can see that the schema describes a JSON "object" (a `dict`, in other words) with three keys: `role`, `content`, and `thinking`.
Because all assistant responses have the role "assistant", the `role` key is a `const`(ant). The other two keys are strings, extracted
from the named groups in the regex in the `x-regex` field.
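To see what that regex does on its own, here is a small standalone sketch that applies it to a hypothetical SmolLM-style completion and assembles the message dict from the named groups (illustrative only; `parse_response` handles this for you):
```python
import re

schema_regex = r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)"
# Hypothetical raw completion in the format described above
raw = "<think>\nThe user wants a very short answer.\n</think>\nThe Cold War ended in 1991.<|im_end|>"
match = re.search(schema_regex, raw, re.DOTALL)  # DOTALL is an assumption about the parser's flags
message = {"role": "assistant", **{k: v for k, v in match.groupdict().items() if v is not None}}
print(message)
# {'role': 'assistant', 'thinking': 'The user wants a very short answer.', 'content': 'The Cold War ended in 1991.'}
```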
Like chat templates, response schemas are set as a property of the tokenizer. To enable response parsing, all you need
to do is set `tokenizer.response_schema` to a valid schema dict, and `tokenizer.parse_response()` will work! Again, like
chat templates, this schema will be saved with the processor, so once you set it, you can use `save_pretrained()` or `push_to_hub()` to
save and share the schema.
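For example, attaching the schema above to a tokenizer and sharing it could look like this (a sketch; the save path and repository name are placeholders):
```python
tokenizer.response_schema = {
    "x-regex": r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
    "type": "object",
    "properties": {
        "role": {"const": "assistant"},
        "content": {"type": "string"},
        "thinking": {"type": "string"},
    },
}
parsed = tokenizer.parse_response(out_text)  # now parsed with the schema above
tokenizer.save_pretrained("./model-with-response-schema")  # placeholder path
# tokenizer.push_to_hub("your-username/model-with-response-schema")  # placeholder repo id
```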
## Developers: Complex schemas
Now, let's look at a more complex schema, which includes tool calls, to gain more of an understanding of the parser
internals. For this, we'll use the `GPT-OSS` schema. GPT-OSS emits both tool calls and thinking blocks, and it uses
an unusual format where model responses are tagged with one of three "channels": `commentary` for things like
tool calls, `analysis` for chain of thought blocks, and `final` for messages intended to be sent to the user.
A full message where the model calls a tool named `get_current_weather` might look like this, with some extra linebreaks added for clarity:
```text
<|channel|>analysis<|message|>
The user asks: "What is the weather like in SF?" So we need to get the current weather in San Francisco, CA.
We need to call get_current_weather function. So we should call get_current_weather with location "San Francisco, CA".
<|end|>
<|start|>assistant<|channel|>commentary
to=functions.get_current_weather <|constrain|>json<|message|>
{
"location": "San Francisco, CA"
}
<|call|>
```
Parsing proceeds recursively; the output of a regex (or other parser) at one level becomes the input to the nodes below it.
In other words, don't feel like you have to parse the entire output in one enormous regex! Instead, start with the schema,
and then add regexes to extract the relevant chunks as you go. Here's a schema that will parse it, with some
explanatory comments:
```python
{
"type": "object",
"properties": {
"role": {"const": "assistant"},
# "content" and "thinking" are both similar to the previous example, and just extract a single string
# However, rather than using a single regex with named groups to extract both, we use a regex in each subkey.
# When an object node has no parser/regex, the entire input string is passed to all of its children, so
# parsing can either be done with named groups at the object level, or with separate regexes at the property level.
"content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"},
"thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
"tool_calls": {
# "x-regex-iterator" uses re.findall to find multiple possible manages, and returns them as an
# array/list. You don't need to worry about array handling, though - each item in the array will be
# parsed by the `items` schema, so just write the schema for a single item.
"x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
"type": "array",
"items": {
"type": "object",
"properties": {
# A const property is a fixed value, and the input has no effect on it.
"type": {"const": "function"},
# Here, we wrap the entire tool call dict in a `{"function": ...}` block. The input string is passed through to it unchanged.
"function": {
"type": "object",
"properties": {
"name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
"arguments": {
"type": "object",
"x-regex": "<\|message\|>(.*)",
# The "x-parser" field indicates that the extracted string should be parsed as JSON.
# The output is then passed to the schema nodes below and recursive parsing continues.
"x-parser": "json",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}
```
## Developers: Understanding the parser logic
The parser follows a few simple rules:
1. Each level of the schema receives input from the level above, applies any regex or parser it has, and then passes the output to its children.
2. The root level receives the entire decoded model output string as input.
3. If a node has structured content after parsing (for example, if the regex has named groups and returns a dict, or if the parser returns a dict or list),
then that structured content is mapped to the node's children, and each child node receives its corresponding value as input.
4. If an `object` (dict) node has unstructured (string) output, then the entire string is passed to all of its children. This allows child nodes
to handle parsing individually rather than requiring a single parent regex to extract all keys at once.
5. If an `array` (list) node has unstructured (string) output, then this throws an error.
There is a small set of allowable `x-` keys that indicate how parsing should be done at each node:
- `x-regex`: A regex string to apply to the input. If the regex has named groups, the output is a dict of group names to values. Named groups should only be used in `object` nodes.
Otherwise, the regex must have exactly one unnamed capturing group, and the output is the value of that group as a string.
- `x-regex-iterator`: A regex string to apply to the input using `re.findall()`. The output is a list of all matches.
This should only be used in `array` nodes, and the regex must have exactly one unnamed capturing group. The output is distributed to
the node's `items` schema.
- `x-parser`: Calls a built-in parser to apply to the input. Currently, the only supported parser is `json`, which parses the input string as JSON.
The output is passed to the child nodes for further parsing. Note that the `json` parser can return deeply nested output - in this case, the output
will be progressively unwrapped as it is passed through child nodes. The child nodes do not need additional `x-parser` or `x-regex` fields in this case,
but their structure must match the structure of the parsed JSON.
- `x-parser-args`: Only allowed in conjunction with `x-parser`. This is a dict of additional arguments that control parsing. Right now, the only supported
argument is `transform`, which specifies a `jmespath` transformation to apply to the output. This is useful when the JSON parser returns a structure
that needs to be modified to match the schema.
- `x-regex-key-value`: This is rarely necessary, but it can be useful when parsing key-value pairs in non-JSON format where the names of the keys are not known
in advance, such as when a model emits XML tool calls with arbitrary argument names. The regex must have exactly two named capturing groups,
`key` and `value`, and the output is a dict mapping keys to values. This should only be used in `object` nodes.
In general, multiple regexes/parsers cannot be combined at the same level. The exception is that `x-regex`, returning a single string, can be combined with the other parsers. In this case,
`x-regex` is applied first, and then the output is passed to the other parser, either `x-regex-iterator`, `x-parser`, or `x-regex-key-value`.
Putting these ideas together, you can see that the input flows through the schema, being parsed at each level and then distributed to child nodes. Each level
only needs to extract the input content that is relevant for that part of the schema, and can then let its child nodes handle the rest. Internally, this is handled
with a parser function that receives input, applies any regexes/parsers at the current level, then maps the result to its child nodes before recursively calling itself on each of them.
Recursion terminates when it reaches leaf nodes, usually primitive types like `string` or `number`, which simply return the input they receive.
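To make these rules concrete, here is a deliberately simplified toy parser that implements rules 1-5 for `x-regex`, `x-regex-iterator`, and `x-parser: json` (an illustrative sketch only, not the actual Transformers implementation; it ignores `x-parser-args` and `x-regex-key-value`):
```python
import json
import re

def parse_node(schema: dict, data):
    """Recursively parse `data` (a string or structured value) against `schema`."""
    # Rules 1-2: apply any regex/parser declared at this level to string input.
    if isinstance(data, str) and "x-regex" in schema:
        match = re.search(schema["x-regex"], data, re.DOTALL)
        if match is None:
            return None
        # Named groups yield structured (dict) output; otherwise take the single capturing group.
        data = match.groupdict() if match.groupdict() else match.group(1)
    if isinstance(data, str) and "x-regex-iterator" in schema:
        data = re.findall(schema["x-regex-iterator"], data, re.DOTALL)
    if isinstance(data, str) and schema.get("x-parser") == "json":
        data = json.loads(data)
    # A const node ignores its input entirely.
    if "const" in schema:
        return schema["const"]
    if schema.get("type") == "object" and "properties" in schema:
        out = {}
        for key, subschema in schema["properties"].items():
            # Rule 3: structured output is mapped to children key by key.
            # Rule 4: an unstructured string is passed whole to every child instead.
            child_input = data.get(key) if isinstance(data, dict) else data
            value = parse_node(subschema, child_input)
            if value is not None:
                out[key] = value
        return out or None
    if schema.get("type") == "array":
        # Rule 5: an array node must have structured (list) input by now.
        if not isinstance(data, list):
            raise ValueError("array node received unstructured input")
        return [parse_node(schema["items"], item) for item in data]
    # Leaf nodes (strings, numbers, "any") return whatever they received.
    return data
```
Calling `parse_node(schema, decoded_output)` with either schema shown earlier should produce a message dict similar to what `parse_response` returns.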

View File

@ -147,13 +147,6 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- post_process_keypoint_matching
- visualize_keypoint_matching
## LightGlueImageProcessorFast
[[autodoc]] LightGlueImageProcessorFast
- preprocess
- post_process_keypoint_matching
- visualize_keypoint_matching
## LightGlueForKeypointMatching
[[autodoc]] LightGlueForKeypointMatching

View File

@ -383,6 +383,30 @@ transformers serve \
--attn_implementation "sdpa"
```
### Quantization
`transformers serve` is compatible with all [quantization methods](https://huggingface.co/docs/transformers/main/quantization/overview) supported in Transformers. Quantization can significantly reduce memory usage and improve inference speed, and there are two main workflows: serving pre-quantized models and quantizing on the fly.
#### Pre-quantized Models
For models that are already quantized (e.g., GPTQ, AWQ, bitsandbytes), simply choose a quantized model name for serving.
Make sure to install the required libraries listed in the quantization documentation.
> [!TIP]
> Pre-quantized models generally provide the best balance of performance and accuracy.
#### On-the-fly quantization
If you want to quantize a model at runtime, you can pass the `--quantization` flag on the CLI. Note that not all quantization methods support on-the-fly conversion; the full list of supported methods is available in the quantization [overview](https://huggingface.co/docs/transformers/main/quantization/overview).
Currently, `transformers serve` only supports the following methods: `bitsandbytes-4bit` and `bitsandbytes-8bit`.
For example, to enable 4-bit quantization with bitsandbytes, pass `--quantization bitsandbytes-4bit`:
```sh
transformers serve --quantization bitsandbytes-4bit
```
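Once the server is up, it can be queried like any other `transformers serve` deployment through its OpenAI-compatible API. For instance, with the `openai` Python client (a sketch: the model id is a placeholder, and the default host/port from the options above are assumed):
```python
from openai import OpenAI

# Assumes `transformers serve --quantization bitsandbytes-4bit` is running locally on the default port
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder repo id; loaded and quantized on first use
    messages=[{"role": "user", "content": "In one sentence, what does 4-bit quantization trade off?"}],
)
print(response.choices[0].message.content)
```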
### Performance tips
- Use an efficient attention backend when available:
@ -397,6 +421,4 @@ transformers serve \
- `--dtype {bfloat16|float16}` typically improves throughput and memory use vs. `float32`
- `--load_in_4bit`/`--load_in_8bit` can reduce memory footprint for LoRA setups
- `--force-model <repo_id>` avoids per-request model hints and helps produce stable, repeatable runs

View File

@ -125,23 +125,15 @@ def token_type_ids_mask_function(
# If it's 1 for both query and key/value, we are in an image block
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & same_image_block

View File

@ -117,7 +117,6 @@ _deps = [
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"jinja2>=3.1.0",
"jmespath>=1.0.1",
"kenlm",
"kernels>=0.10.2,<0.11",
"librosa",
@ -295,7 +294,7 @@ extras["num2words"] = deps_list("num2words")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["tiktoken"] = deps_list("tiktoken", "blobfile")
extras["mistral-common"] = deps_list("mistral-common[opencv]")
extras["chat_template"] = deps_list("jinja2", "jmespath")
extras["chat_template"] = deps_list("jinja2")
extras["testing"] = (
deps_list(
"pytest",

View File

@ -377,14 +377,10 @@ class Serve:
help="Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
),
] = None,
load_in_8bit: Annotated[
bool, typer.Option(help="Whether to use 8 bit precision for the base model - works only with LoRA.")
] = False,
load_in_4bit: Annotated[
bool, typer.Option(help="Whether to use 4 bit precision for the base model - works only with LoRA.")
] = False,
bnb_4bit_quant_type: Annotated[str, typer.Option(help="Quantization type.")] = "nf4",
use_bnb_nested_quant: Annotated[bool, typer.Option(help="Whether to use nested quantization.")] = False,
quantization: Annotated[
Optional[str],
typer.Option(help="Which quantization method to use. choices: 'bitsandbytes-4bit', 'bitsandbytes-8bit'"),
] = None,
host: Annotated[str, typer.Option(help="Interface the server will listen to.")] = "localhost",
port: Annotated[int, typer.Option(help="Port the server will listen to.")] = 8000,
model_timeout: Annotated[
@ -424,10 +420,7 @@ class Serve:
self.dtype = dtype
self.trust_remote_code = trust_remote_code
self.attn_implementation = attn_implementation
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.bnb_4bit_quant_type = bnb_4bit_quant_type
self.use_bnb_nested_quant = use_bnb_nested_quant
self.quantization = quantization
self.host = host
self.port = port
self.model_timeout = model_timeout
@ -1688,19 +1681,14 @@ class Serve:
Returns:
`Optional[BitsAndBytesConfig]`: The quantization config.
"""
if self.load_in_4bit:
if self.quantization == "bitsandbytes-4bit":
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
# For consistency with model weights, we use the same value as `dtype`
bnb_4bit_compute_dtype=self.dtype,
bnb_4bit_quant_type=self.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=self.use_bnb_nested_quant,
bnb_4bit_quant_storage=self.dtype,
)
elif self.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
elif self.quantization == "bitsandbytes-8bit":
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
quantization_config = None
@ -1750,7 +1738,6 @@ class Serve:
revision=revision,
trust_remote_code=self.trust_remote_code,
)
dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype)
quantization_config = self.get_quantization_config()
@ -1758,19 +1745,15 @@ class Serve:
"revision": revision,
"attn_implementation": self.attn_implementation,
"dtype": dtype,
"device_map": "auto",
"device_map": self.device,
"trust_remote_code": self.trust_remote_code,
"quantization_config": quantization_config,
}
if quantization_config is not None:
model_kwargs["quantization_config"] = quantization_config
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_kwargs)
if getattr(model, "hf_device_map", None) is None:
model = model.to(self.device)
has_default_max_length = (
model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20
)

View File

@ -27,7 +27,6 @@ deps = {
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"jinja2": "jinja2>=3.1.0",
"jmespath": "jmespath>=1.0.1",
"kenlm": "kenlm",
"kernels": "kernels>=0.10.2,<0.11",
"librosa": "librosa",

View File

@ -27,6 +27,7 @@ from ...utils.metrics import traced
logger = logging.getLogger("ContinuousBatchingLogger")
@staticmethod
def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
if torch.cuda.is_available():
device = torch.device("cuda")

View File

@ -164,7 +164,6 @@ except ImportError:
_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
"causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
"mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "clean-mamba-ssm"},
}
_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}
@ -236,7 +235,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]
if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
return mapping[kernel_name]
if kernel_name not in _HUB_KERNEL_MAPPING:
logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
mapping[kernel_name] = None
return None
if _kernels_available:
@ -244,9 +243,8 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]
try:
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
kernel = get_kernel(repo_id, revision=revision, version=version)
kernel = get_kernel(repo_id, version=version)
mapping[kernel_name] = kernel
except FileNotFoundError:
mapping[kernel_name] = None

View File

View File

@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .selective_scan_with_ln_interface import mamba_inner_fn

View File

@ -0,0 +1,525 @@
# coding=utf-8
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import causal_conv1d_cuda
except ImportError:
causal_conv1d_cuda = None
import mamba_ssm
import selective_scan_cuda
# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
if hasattr(mamba_ssm.ops.triton, "layernorm"):
from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
else:
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
class SelectiveScanFn(torch.autograd.Function):
@staticmethod
def forward(
ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = rearrange(B, "b dstate l -> b 1 dstate l")
ctx.squeeze_B = True
if C.dim() == 3:
C = rearrange(C, "b dstate l -> b 1 dstate l")
ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
ctx.delta_softplus = delta_softplus
ctx.has_z = z is not None
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if not ctx.has_z:
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out if not return_last_state else (out, last_state)
else:
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
@staticmethod
def backward(ctx, dout, *args):
if not ctx.has_z:
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
z = None
out = None
else:
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
if dout.stride(-1) != 1:
dout = dout.contiguous()
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
# Here we just pass in None and dz will be allocated in the C++ code.
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
dout,
x,
out,
None,
ctx.delta_softplus,
False, # option to recompute out_z, not used here
)
dz = rest[0] if ctx.has_z else None
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
return (
du,
ddelta,
dA,
dB,
dC,
dD if D is not None else None,
dz,
ddelta_bias if delta_bias is not None else None,
None,
None,
)
def rms_norm_forward(
x,
weight,
bias,
eps=1e-6,
is_rms_norm=True,
):
# x (b l) d
if x.stride(-1) != 1:
x = x.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0]
# y (b l) d
return y
def selective_scan_fn(
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
not considered in the backward pass.
"""
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
def selective_scan_ref(
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
):
"""
u: r(B D L)
delta: r(B D L)
A: c(D N) or r(D N)
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
out: r(B D L)
last_state (optional): r(B D dstate) or c(B D dstate)
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
if A.is_complex():
if is_variable_B:
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
if is_variable_C:
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
else:
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate))
ys = []
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
if not is_variable_B:
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum("bdn,dn->bd", x, C)
else:
if C.dim() == 3:
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
else:
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
return out if not return_last_state else (out, last_state)
class MambaInnerFn(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(
ctx,
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
checkpoint_lvl=1,
b_rms_weight=None,
c_rms_weight=None,
dt_rms_weight=None,
b_c_dt_rms_eps=1e-6,
):
"""
xz: (batch, dim, seqlen)
"""
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
assert checkpoint_lvl in [0, 1]
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
if torch.is_autocast_enabled():
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
out_proj_bias = (
out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None
)
if xz.stride(-1) != 1:
xz = xz.contiguous()
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
x, z = xz.chunk(2, dim=1)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
# We're being very careful here about the layout, to avoid extra transposes.
# We want delta to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
ctx.is_variable_B = B is None
ctx.is_variable_C = C is None
ctx.B_proj_bias_is_None = B_proj_bias is None
ctx.C_proj_bias_is_None = C_proj_bias is None
if B is None: # variable B
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if B.stride(-1) != 1:
B = B.contiguous()
if C is None: # variable C
C = x_dbl[:, -d_state:] # (bl dstate)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if C.stride(-1) != 1:
C = C.contiguous()
if D is not None:
D = D.contiguous()
if b_rms_weight is not None:
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if dt_rms_weight is not None:
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
ctx.delta_softplus = delta_softplus
ctx.out_proj_bias_is_None = out_proj_bias is None
ctx.checkpoint_lvl = checkpoint_lvl
ctx.b_rms_weight = b_rms_weight
ctx.c_rms_weight = c_rms_weight
ctx.dt_rms_weight = dt_rms_weight
ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
conv1d_out, delta = None, None
ctx.save_for_backward(
xz,
conv1d_weight,
conv1d_bias,
x_dbl,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
conv1d_out,
delta,
A,
B,
C,
D,
delta_bias,
scan_intermediates,
b_rms_weight,
c_rms_weight,
dt_rms_weight,
out,
)
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
@staticmethod
@custom_bwd
def backward(ctx, dout):
# dout: (batch, seqlen, dim)
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
(
xz,
conv1d_weight,
conv1d_bias,
x_dbl,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
conv1d_out,
delta,
A,
B,
C,
D,
delta_bias,
scan_intermediates,
b_rms_weight,
c_rms_weight,
dt_rms_weight,
out,
) = ctx.saved_tensors
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
x, z = xz.chunk(2, dim=1)
if dout.stride(-1) != 1:
dout = dout.contiguous()
if ctx.checkpoint_lvl == 1:
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
if dt_rms_weight is not None:
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
if b_rms_weight is not None:
# Recompute & RMSNorm B
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
# Recompute & RMSNorm C
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
dx, dz = dxz.chunk(2, dim=1)
dout = rearrange(dout, "b l e -> e (b l)")
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
conv1d_out,
delta,
A,
B,
C,
D,
z,
delta_bias,
dout_y,
scan_intermediates,
out,
dz,
ctx.delta_softplus,
True, # option to recompute out_z
)
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
dD = dD if D is not None else None
dx_dbl = torch.empty_like(x_dbl)
dB_proj_bias = None
if ctx.is_variable_B:
if not A.is_complex():
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
dB = None
dC_proj_bias = None
if ctx.is_variable_C:
if not A.is_complex():
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
dx_dbl[:, -d_state:] = dC # (bl d)
dC = None
ddelta = rearrange(ddelta, "b d l -> d (b l)")
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
# backward of conv1d with the backward of chunk).
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
)
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
return (
dxz,
dconv1d_weight,
dconv1d_bias,
dx_proj_weight,
ddelta_proj_weight,
dout_proj_weight,
dout_proj_bias,
dA,
dB,
dC,
dD,
ddelta_bias if delta_bias is not None else None,
# 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
dB_proj_bias,
dC_proj_bias,
None,
None,
None,
None,
None,
None,
)
def mamba_inner_fn(
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
checkpoint_lvl=1,
b_rms_weight=None,
c_rms_weight=None,
dt_rms_weight=None,
b_c_dt_rms_eps=1e-6,
):
return MambaInnerFn.apply(
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B,
C,
D,
delta_bias,
B_proj_bias,
C_proj_bias,
delta_softplus,
checkpoint_lvl,
b_rms_weight,
c_rms_weight,
dt_rms_weight,
b_c_dt_rms_eps,
)

View File

@ -121,7 +121,7 @@ else:
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
("lightglue", ("LightGlueImageProcessor", "LightGlueImageProcessorFast")),
("lightglue", ("LightGlueImageProcessor", None)),
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),

View File

@ -1318,7 +1318,7 @@ class BarkFineModel(BarkPreTrainedModel):
output sound according to specific predefined voice.
"""
)
class BarkModel(BarkPreTrainedModel, GenerationMixin):
class BarkModel(BarkPreTrainedModel):
config: BarkConfig
def __init__(self, config):

View File

@ -34,7 +34,10 @@ from ...integrations.hub_kernels import lazy_load_kernel
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.import_utils import is_mambapy_available
from ...utils.import_utils import (
is_mamba_ssm_available,
is_mambapy_available,
)
from .configuration_falcon_mamba import FalconMambaConfig
@ -43,6 +46,14 @@ if is_mambapy_available():
else:
pscan = None
if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from ...kernels.falcon_mamba import mamba_inner_fn
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
logger = logging.get_logger(__name__)
@ -235,12 +246,6 @@ class FalconMambaMixer(nn.Module):
if causal_conv1d is not None
else (None, None)
)
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update, selective_scan_fn, mamba_inner_fn = (
(mamba_ssm.selective_state_update, mamba_ssm.selective_scan_fn, mamba_ssm.mamba_inner_fn)
if mamba_ssm is not None
else (None, None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
@ -272,12 +277,7 @@ class FalconMambaMixer(nn.Module):
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update, selective_scan_fn, mamba_inner_fn = (
mamba_ssm.selective_state_update,
mamba_ssm.selective_scan_fn,
mamba_ssm.mamba_inner_fn,
)
if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
contextualized_states = mamba_inner_fn(
projected_states,
@ -506,16 +506,6 @@ class FalconMambaMixer(nn.Module):
if causal_conv1d is not None
else (None, None)
)
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update, selective_scan_fn, mamba_inner_fn = (
(
mamba_ssm.selective_state_update,
mamba_ssm.selective_scan_fn,
mamba_ssm.mamba_inner_fn,
)
if mamba_ssm is not None
else (None, None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

View File

@ -22,6 +22,7 @@ from torch import nn
from ...integrations.hub_kernels import lazy_load_kernel
from ...utils import auto_docstring, logging
from ...utils.import_utils import (
is_mamba_ssm_available,
is_mambapy_available,
)
from ..mamba.configuration_mamba import MambaConfig
@ -45,6 +46,14 @@ if is_mambapy_available():
else:
pscan = None
if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from ...kernels.falcon_mamba import mamba_inner_fn
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
class FalconMambaConfig(MambaConfig):
"""
@ -251,12 +260,6 @@ class FalconMambaMixer(MambaMixer):
if causal_conv1d is not None
else (None, None)
)
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update, selective_scan_fn, mamba_inner_fn = (
(mamba_ssm.selective_state_update, mamba_ssm.selective_scan_fn, mamba_ssm.mamba_inner_fn)
if mamba_ssm is not None
else (None, None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
@ -299,12 +302,7 @@ class FalconMambaMixer(MambaMixer):
):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update, selective_scan_fn, mamba_inner_fn = (
mamba_ssm.selective_state_update,
mamba_ssm.selective_scan_fn,
mamba_ssm.mamba_inner_fn,
)
if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
contextualized_states = mamba_inner_fn(
projected_states,
@ -532,16 +530,6 @@ class FalconMambaMixer(MambaMixer):
if causal_conv1d is not None
else (None, None)
)
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update, selective_scan_fn, mamba_inner_fn = (
(
mamba_ssm.selective_state_update,
mamba_ssm.selective_scan_fn,
mamba_ssm.mamba_inner_fn,
)
if mamba_ssm is not None
else (None, None, None)
)
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

View File

@ -447,7 +447,7 @@ def convert_transformer_weights(
return zip([], [])
else:
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
elif f"{_TRANSFORMER_EMBEDDER}/mm_" in path:
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
if not _INCLUDE_VISION_ENCODER.value:
return zip([], [])
@ -553,7 +553,7 @@ def convert(
continue
path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
update_tree(f"model.{path}", weights, config.vision_config.dtype)
update_tree(path, weights, config.vision_config.dtype)
else:
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
if not _INCLUDE_VISION_ENCODER.value:

View File

@ -768,23 +768,15 @@ def token_type_ids_mask_function(
# If it's 1 for both query and key/value, we are in an image block
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & same_image_block

View File

@ -20,7 +20,6 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_lightglue import *
from .image_processing_lightglue import *
from .image_processing_lightglue_fast import *
from .modeling_lightglue import *
else:
import sys

View File

@ -17,6 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Union
import numpy as np
@ -39,28 +40,20 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import TensorType, logging, requires_backends
from ...utils import TensorType, is_matplotlib_available, logging, requires_backends
from ...utils.import_utils import requires
from .modeling_lightglue import LightGlueKeypointMatchingOutput
if is_vision_available():
import PIL
from PIL import Image, ImageDraw
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
class LightGlueImageProcessorKwargs(ImagesKwargs, total=False):
r"""
do_grayscale (`bool`, *optional*, defaults to `True`):
Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
"""
do_grayscale: bool
def is_grayscale(
image: np.ndarray,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
@ -468,5 +461,60 @@ class LightGlueImageProcessor(BaseImageProcessor):
b = 0
return (r, g, b)
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed.
.. deprecated::
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255.
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model.
"""
warnings.warn(
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
"Use `visualize_keypoint_matching` instead.",
FutureWarning,
)
if is_matplotlib_available():
import matplotlib.pyplot as plt
else:
raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
plot_image[:height0, :width0] = image_pair[0] / 255.0
plot_image[:height1, width0:] = image_pair[1] / 255.0
plt.imshow(plot_image)
plt.axis("off")
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
plt.plot(
[keypoint0_x, keypoint1_x + width0],
[keypoint0_y, keypoint1_y],
color=plt.get_cmap("RdYlGn")(matching_score.item()),
alpha=0.9,
linewidth=0.5,
)
plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
plt.show()
__all__ = ["LightGlueImageProcessor"]

View File

@ -1,302 +0,0 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_lightglue.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from torchvision.transforms.v2 import functional as F
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
ImageInput,
ImageType,
PILImageResampling,
SizeDict,
get_image_type,
is_pil_image,
is_valid_image,
is_vision_available,
to_numpy_array,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring
from .image_processing_lightglue import LightGlueImageProcessorKwargs
from .modeling_lightglue import LightGlueKeypointMatchingOutput
if is_vision_available():
from PIL import Image, ImageDraw
def _is_valid_image(image):
return is_pil_image(image) or (
is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
)
def flatten_pair_images(images):
# Handle the pair validation and flattening similar to slow processor
if isinstance(images, list):
if len(images) == 2 and all((_is_valid_image(image) or isinstance(image, torch.Tensor)) for image in images):
# Single pair of images - keep as is, they'll be processed by the base class
return images
elif all(
isinstance(image_pair, list)
and len(image_pair) == 2
and all(_is_valid_image(image) or isinstance(image, torch.Tensor) for image in image_pair)
for image_pair in images
):
# Multiple pairs - flatten them
images = [image for image_pair in images for image in image_pair]
return images
raise ValueError(
"Input images must be a one of the following :",
" - A pair of PIL images.",
" - A pair of 3D arrays.",
" - A list of pairs of PIL images.",
" - A list of pairs of 3D arrays.",
)
def is_grayscale(
image: "torch.Tensor",
):
"""Checks if an image is grayscale (all RGB channels are identical)."""
if image.ndim < 3 or image.shape[0 if image.ndim == 3 else 1] == 1:
return True
return torch.all(image[..., 0, :, :] == image[..., 1, :, :]) and torch.all(
image[..., 1, :, :] == image[..., 2, :, :]
)
def convert_to_grayscale(
image: "torch.Tensor",
) -> "torch.Tensor":
"""
Converts an image to grayscale format using the NTSC formula. Only support torch.Tensor.
This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
channel, because of an issue that is discussed in :
https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
Args:
image (torch.Tensor):
The image to convert.
"""
if is_grayscale(image):
return image
return F.rgb_to_grayscale(image, num_output_channels=3)
@auto_docstring
class LightGlueImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
size = {"height": 480, "width": 640}
default_to_square = False
do_resize = True
do_rescale = True
rescale_factor = 1 / 255
do_normalize = None
valid_kwargs = LightGlueImageProcessorKwargs
def __init__(self, **kwargs: Unpack[LightGlueImageProcessorKwargs]):
super().__init__(**kwargs)
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[LightGlueImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
**kwargs,
) -> ImageInput:
# we need to handle image pairs validation and flattening
return flatten_pair_images(images)
def _preprocess(
self,
images: list["torch.Tensor"],
size: Union[dict[str, int], SizeDict],
rescale_factor: float,
do_rescale: bool,
do_resize: bool,
interpolation: Optional["F.InterpolationMode"],
do_grayscale: bool,
disable_grouping: bool,
return_tensors: Union[str, TensorType],
**kwargs,
) -> BatchFeature:
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation)
processed_images_grouped[shape] = stacked_images
resized_images = reorder_images(processed_images_grouped, grouped_images_index)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_rescale:
stacked_images = self.rescale(stacked_images, rescale_factor)
if do_grayscale:
stacked_images = convert_to_grayscale(stacked_images)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
# Convert back to pairs format
image_pairs = [processed_images[i : i + 2] for i in range(0, len(processed_images), 2)]
# Stack each pair into a single tensor to match slow processor format
stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]
# Return in same format as slow processor
image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs
return BatchFeature(data={"pixel_values": image_pairs})
def post_process_keypoint_matching(
self,
outputs: LightGlueKeypointMatchingOutput,
target_sizes: Union[TensorType, list[tuple]],
threshold: float = 0.0,
) -> list[dict[str, torch.Tensor]]:
"""
Converts the raw output of [`KeypointMatchingOutput`] into lists of matched keypoints and matching scores,
with coordinates in absolute pixel values relative to the original image sizes.
Args:
outputs ([`KeypointMatchingOutput`]):
Raw outputs of the model.
target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
target size `(height, width)` of each image in the batch. This must be the original image size (before
any processing).
threshold (`float`, *optional*, defaults to 0.0):
Threshold to filter out the matches with low scores.
Returns:
`List[Dict]`: A list of dictionaries, each containing the matched keypoints in the first and second image
of the pair and the corresponding matching scores.
"""
if outputs.matches.shape[0] != len(target_sizes):
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
if not all(len(target_size) == 2 for target_size in target_sizes):
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
if isinstance(target_sizes, list):
image_pair_sizes = torch.tensor(target_sizes, device=outputs.matches.device)
else:
if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
raise ValueError(
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
)
image_pair_sizes = target_sizes
keypoints = outputs.keypoints.clone()
keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
keypoints = keypoints.to(torch.int32)
results = []
for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
# Filter out matches with low scores
valid_matches = torch.logical_and(scores > threshold, matches > -1)
matched_keypoints0 = keypoints_pair[0][valid_matches[0]]
matched_keypoints1 = keypoints_pair[1][valid_matches[1]]
matching_scores = scores[0][valid_matches[0]]
results.append(
{
"keypoints0": matched_keypoints0,
"keypoints1": matched_keypoints1,
"matching_scores": matching_scores,
}
)
return results
def visualize_keypoint_matching(
self,
images,
keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
Args:
images:
Image pairs to plot. Same input format as `LightGlueImageProcessor.preprocess`. Expects either a single
pair of images or a list of pairs of images, with pixel values ranging from 0 to 255.
keypoint_matching_output (`List[Dict[str, torch.Tensor]]`):
A post-processed keypoint matching output.
Returns:
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
keypoints as well as the matching between them.
"""
from .image_processing_lightglue import validate_and_format_image_pairs
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
results = []
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = torch.zeros((max(height0, height1), width0 + width1, 3), dtype=torch.uint8)
plot_image[:height0, :width0] = torch.from_numpy(image_pair[0])
plot_image[:height1, width0:] = torch.from_numpy(image_pair[1])
plot_image_pil = Image.fromarray(plot_image.numpy())
draw = ImageDraw.Draw(plot_image_pil)
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
color = self._get_color(matching_score)
draw.line(
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
fill=color,
width=3,
)
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
draw.ellipse(
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
fill="black",
)
results.append(plot_image_pil)
return results
def _get_color(self, score):
"""Maps a score to a color."""
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return r, g, b
__all__ = ["LightGlueImageProcessorFast"]

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional, Union
@ -21,24 +22,25 @@ from torch import nn
from torch.nn.utils.rnn import pad_sequence
from ...configuration_utils import PreTrainedConfig
from ...image_utils import ImageInput, is_vision_available, to_numpy_array
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TensorType, auto_docstring, logging
from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging
from ...utils.generic import can_return_tuple
from ..auto import CONFIG_MAPPING, AutoConfig
from ..auto.modeling_auto import AutoModelForKeypointDetection
from ..clip.modeling_clip import CLIPMLP
from ..cohere.modeling_cohere import apply_rotary_pos_emb
from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
from ..superglue.image_processing_superglue import (
SuperGlueImageProcessor,
SuperGlueImageProcessorKwargs,
)
from ..superglue.image_processing_superglue_fast import SuperGlueImageProcessorFast
from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs
from ..superpoint import SuperPointConfig
if is_vision_available():
from PIL import Image, ImageDraw
logger = logging.get_logger(__name__)
@ -215,10 +217,6 @@ class LightGlueKeypointMatchingOutput(ModelOutput):
attentions: Optional[tuple[torch.FloatTensor]] = None
class LightGlueImageProcessorKwargs(SuperGlueImageProcessorKwargs):
pass
class LightGlueImageProcessor(SuperGlueImageProcessor):
def post_process_keypoint_matching(
self,
@ -228,15 +226,123 @@ class LightGlueImageProcessor(SuperGlueImageProcessor):
) -> list[dict[str, torch.Tensor]]:
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
class LightGlueImageProcessorFast(SuperGlueImageProcessorFast):
def post_process_keypoint_matching(
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
def visualize_keypoint_matching(
self,
outputs: LightGlueKeypointMatchingOutput,
target_sizes: Union[TensorType, list[tuple]],
threshold: float = 0.0,
) -> list[dict[str, torch.Tensor]]:
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
images: ImageInput,
keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
Args:
images (`ImageInput`):
Image pairs to plot. Same input format as `LightGlueImageProcessor.preprocess`. Expects either a single
pair of images or a list of pairs of images, with pixel values ranging from 0 to 255.
keypoint_matching_output (`List[Dict[str, torch.Tensor]]`):
A post-processed keypoint matching output.
Returns:
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
keypoints as well as the matching between them.
"""
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
results = []
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
plot_image[:height0, :width0] = image_pair[0]
plot_image[:height1, width0:] = image_pair[1]
plot_image_pil = Image.fromarray(plot_image)
draw = ImageDraw.Draw(plot_image_pil)
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
color = self._get_color(matching_score)
draw.line(
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
fill=color,
width=3,
)
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
draw.ellipse(
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
fill="black",
)
results.append(plot_image_pil)
return results
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
def _get_color(self, score):
"""Maps a score to a color."""
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return (r, g, b)
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed.
.. deprecated::
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
Args:
images (`ImageInput`):
Image pairs to plot. Same input format as `LightGlueImageProcessor.preprocess`. Expects either a single
pair of images or a list of pairs of images, with pixel values ranging from 0 to 255.
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model.
"""
warnings.warn(
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
"Use `visualize_keypoint_matching` instead.",
FutureWarning,
)
if is_matplotlib_available():
import matplotlib.pyplot as plt
else:
raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
plot_image[:height0, :width0] = image_pair[0] / 255.0
plot_image[:height1, width0:] = image_pair[1] / 255.0
plt.imshow(plot_image)
plt.axis("off")
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
plt.plot(
[keypoint0_x, keypoint1_x + width0],
[keypoint0_y, keypoint1_y],
color=plt.get_cmap("RdYlGn")(matching_score.item()),
alpha=0.9,
linewidth=0.5,
)
plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
plt.show()
class LightGluePositionalEncoder(nn.Module):
@ -975,10 +1081,4 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
)
__all__ = [
"LightGluePreTrainedModel",
"LightGlueForKeypointMatching",
"LightGlueConfig",
"LightGlueImageProcessor",
"LightGlueImageProcessorFast",
]
__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"]

View File

@ -116,23 +116,15 @@ def token_type_ids_mask_function(
# If it's 1 for both query and key/value, we are in an image block
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & same_image_block
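# A minimal sketch of the indexing pattern above: clamp the index with torch.where, gather, then
# mask the gathered value with a second torch.where. No Python `if` is used, so the mask function
# stays compatible with torch.vmap when the static cache extends past the input sequence length.
import torch

token_type_ids = torch.tensor([[0, 1, 1, 0]])  # (batch, seq_len)
batch_idx, kv_idx = torch.tensor(0), torch.tensor(7)  # kv_idx beyond seq_len, e.g. a static-cache slot

safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
value = token_type_ids[batch_idx, safe_idx]
value = torch.where(kv_idx < token_type_ids.shape[1], value, 0)
print(value)  # tensor(0): out-of-range positions are treated as non-image tokens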

View File

@ -78,7 +78,7 @@ def _pad(items, key, padding_value, padding_side):
if isinstance(items[0][key], torch.Tensor):
# Others include `attention_mask` etc...
shape = items[0][key].shape
dim = items[0][key].ndim
dim = len(shape)
if dim == 1:
# We have a list of 1-dim torch tensors, which can be stacked without padding
return torch.cat([item[key] for item in items], dim=0)
@ -93,18 +93,37 @@ def _pad(items, key, padding_value, padding_side):
min_length = min(item[key].shape[1] for item in items)
dtype = items[0][key].dtype
if dim == 2 and max_length == min_length:
tensor = None
if dim == 2:
if max_length == min_length:
# Bypass for `ImageGPT` which doesn't provide a padding value, yet
# we can consistently pad since the size should be matching
return torch.cat([item[key] for item in items], dim=0)
else:
tensor = torch.full([batch_size, max_length] + list(shape[2:]), fill_value=padding_value, dtype=dtype)
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
elif dim == 3:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
elif dim == 4:
tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value
if tensor is None:
raise ValueError(f"Unable to create tensor for padding from {key} with dimension {dim}")
for i, item in enumerate(items):
if dim == 2:
if padding_side == "left":
tensor[i, -len(item[key][0]) :] = item[key][0]
tensor[i, -len(item[key][0]) :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0])] = item[key][0]
tensor[i, : len(item[key][0])] = item[key][0].clone()
elif dim == 3:
if padding_side == "left":
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
elif dim == 4:
if padding_side == "left":
tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()
return tensor
else:
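# A minimal sketch of the dim == 2 branch above, assuming two 1-row attention masks of different
# lengths that are left-padded into a single (batch_size, max_length) tensor.
import torch

items = [{"attention_mask": torch.ones(1, 3, dtype=torch.long)},
         {"attention_mask": torch.ones(1, 5, dtype=torch.long)}]
padding_value, max_length, batch_size = 0, 5, 2

tensor = torch.zeros((batch_size, max_length), dtype=torch.long) + padding_value
for i, item in enumerate(items):
    tensor[i, -len(item["attention_mask"][0]):] = item["attention_mask"][0].clone()
print(tensor)  # tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])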

View File

@ -152,8 +152,6 @@ class TextGenerationPipeline(Pipeline):
continue_final_message=None,
skip_special_tokens=None,
tokenizer_encode_kwargs=None,
tools=None,
documents=None,
**generate_kwargs,
):
# preprocess kwargs
@ -172,11 +170,6 @@ class TextGenerationPipeline(Pipeline):
preprocess_params["max_length"] = max_length
generate_kwargs["max_length"] = max_length
if tools is not None:
preprocess_params["tools"] = tools
if documents is not None:
preprocess_params["documents"] = documents
if prefix is not None:
preprocess_params["prefix"] = prefix
if prefix:
@ -342,8 +335,6 @@ class TextGenerationPipeline(Pipeline):
max_length=None,
continue_final_message=None,
tokenizer_encode_kwargs=None,
tools=None,
documents=None,
**generate_kwargs,
):
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
@ -368,8 +359,6 @@ class TextGenerationPipeline(Pipeline):
continue_final_message=continue_final_message,
return_dict=True,
return_tensors="pt",
tools=tools,
documents=documents,
**tokenizer_kwargs,
)
else:
@ -525,12 +514,7 @@ class TextGenerationPipeline(Pipeline):
]
else:
# When we're not starting from a prefill, the output is a new assistant message
if self.tokenizer.response_schema:
assistant_message = self.tokenizer.parse_response(all_text)
else:
# If there's no schema, then we have to assume it's all content
assistant_message = {"role": "assistant", "content": all_text}
all_text = list(prompt_text.messages) + [assistant_message]
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
record = {"generated_text": all_text}
for key, values in split_keys.items():
record[key] = values[idx]
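# A minimal sketch of the remaining chat path above, assuming the tiny chat-template checkpoint
# used by the tests elsewhere in this diff: the generated text is appended to the conversation as
# a plain assistant message.
from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template")
chat = [{"role": "user", "content": "Hi!"}]
out = generator(chat, do_sample=False, max_new_tokens=5)
print(out[0]["generated_text"][-1]["role"])  # "assistant"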

View File

@ -772,6 +772,8 @@ class ProcessorMixin(PushToHubMixin):
kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
save_jinja_files = kwargs.pop("save_jinja_files", True)
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
@ -794,7 +796,8 @@ class ProcessorMixin(PushToHubMixin):
# Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
if attribute_name == "tokenizer":
attribute.save_pretrained(save_directory)
# Propagate save_jinja_files to tokenizer to ensure we don't get conflicts
attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files)
elif attribute._auto_class is not None:
custom_object_save(attribute, save_directory, config=attribute)
@ -809,16 +812,19 @@ class ProcessorMixin(PushToHubMixin):
# plus we save chat_template in its own file
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE)
output_chat_template_file_legacy = os.path.join(save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE)
chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
# to avoid serializing chat template in json config file. So let's get it from `self` directly
if isinstance(self.chat_template, str):
if self.chat_template is not None:
is_single_template = isinstance(self.chat_template, str)
if save_jinja_files and is_single_template:
# New format for single templates is to save them as chat_template.jinja
with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
f.write(self.chat_template)
logger.info(f"chat template saved in {output_chat_template_file_jinja}")
elif isinstance(self.chat_template, dict):
elif save_jinja_files and not is_single_template:
# New format for multiple templates is to save the default as chat_template.jinja
# and the other templates in the chat_templates/ directory
for template_name, template in self.chat_template.items():
@ -832,6 +838,21 @@ class ProcessorMixin(PushToHubMixin):
with open(template_filepath, "w", encoding="utf-8") as f:
f.write(template)
logger.info(f"chat template saved in {template_filepath}")
elif is_single_template:
# Legacy format for single templates: Put them in chat_template.json
chat_template_json_string = (
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
)
with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer:
writer.write(chat_template_json_string)
logger.info(f"chat template saved in {output_chat_template_file_legacy}")
elif self.chat_template is not None:
# At this point we have multiple templates in the legacy format, which is not supported
# chat template dicts are saved to chat_template.json as lists of dicts with fixed key names.
raise ValueError(
"Multiple chat templates are not supported in the legacy format. Please save them as "
"separate files using the `save_jinja_files` argument."
)
# Create a unified `preprocessor_config.json` and save all attributes as a composite config, except for tokenizers
self.to_json_file(output_processor_file)
@ -1045,6 +1066,9 @@ class ProcessorMixin(PushToHubMixin):
if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1:
chat_templates = chat_templates["default"] # Flatten when we just have a single template/file
if chat_templates:
kwargs["chat_template"] = chat_templates
# Existing processors on the Hub created before #27761 was merged don't have `processor_config.json` (if not
# updated afterward), and we need to keep `from_pretrained` working. So here it falls back to the empty dict.
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
@ -1069,13 +1093,14 @@ class ProcessorMixin(PushToHubMixin):
else:
logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}")
if processor_dict.get("chat_template") is not None:
if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
logger.warning_once(
"Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' "
"in the processor's config. Make sure to move your template to its own file."
)
elif chat_templates:
processor_dict["chat_template"] = chat_templates
if "chat_template" in kwargs:
processor_dict["chat_template"] = kwargs.pop("chat_template")
# Audio tokenizer needs to load the model checkpoint first, because the saved
# json file contains only references to the model path and repo id
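# A minimal sketch of the save path changed above, assuming the tiny processor checkpoint used by
# the tests in this diff: with save_jinja_files=False a single chat template is written to the
# legacy chat_template.json instead of chat_template.jinja, and the flag is propagated to the tokenizer.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
processor.chat_template = "{{ messages[-1]['content'] }}"
processor.save_pretrained("saved_processor", save_jinja_files=False)
# -> saved_processor/chat_template.json now holds {"chat_template": "..."}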

View File

@ -102,7 +102,6 @@ from .utils import (
is_huggingface_hub_greater_or_equal,
is_ipex_available,
is_jinja_available,
is_jmespath_available,
is_jumanpp_available,
is_kernels_available,
is_levenshtein_available,
@ -509,13 +508,6 @@ def require_jinja(test_case):
return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
def require_jmespath(test_case):
"""
Decorator marking a test that requires jmespath. These tests are skipped when jmespath isn't installed.
"""
return unittest.skipUnless(is_jmespath_available(), "test requires jmespath")(test_case)
def require_onnx(test_case):
return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)

View File

@ -1864,6 +1864,8 @@ class MistralCommonTokenizer(PushToHubMixin):
Returns:
A tuple of `str`: The files saved.
"""
# `save_jinja_files` must be skipped to be able to save from a processor
kwargs.pop("save_jinja_files", None)
if kwargs:
raise ValueError(
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.save_pretrained`."

View File

@ -61,7 +61,6 @@ from .utils import (
requires_backends,
to_py_obj,
)
from .utils.chat_parsing_utils import recursive_parse
from .utils.chat_template_utils import render_jinja_template
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
@ -1430,8 +1429,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# we reconstruct that into a single dict while loading them.
self.chat_template = {template["name"]: template["template"] for template in self.chat_template}
self.response_schema = kwargs.pop("response_schema", None)
super().__init__(**kwargs)
self.extra_special_tokens = kwargs.pop("extra_special_tokens", {})
@ -1858,28 +1855,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return chat_template
def parse_response(self, response: str, schema: Optional[Union[list, dict]] = None):
"""
Converts an output string created by generating text from a model into a parsed message dictionary.
This method is intended for use with chat models, and will read the tokenizer's `response_schema` attribute to
control parsing, although this can be overridden by passing a `response_schema` argument directly.
For more information, see the
[response parsing](https://huggingface.co/docs/transformers/main/en/chat_response_parsing) documentation.
Args:
response (`str`):
The output string generated by the model. This should be the decoded string, not raw tokens.
schema (`Union[list, dict]`, *optional*):
A response schema that indicates the expected output format and how parsing should be performed.
If not provided, the tokenizer's `response_schema` attribute will be used.
"""
if schema is None:
if getattr(self, "response_schema", None) is None:
raise AttributeError("This tokenizer does not have a `response_schema` for parsing chat responses!")
schema = self.response_schema
return recursive_parse(response, schema)
@classmethod
def from_pretrained(
cls,
@ -2431,6 +2406,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
save_directory: Union[str, os.PathLike],
tokenizer_config: dict,
filename_prefix: Optional[str],
save_jinja_files: bool,
):
"""
Writes chat templates out to the save directory if we're using the new format, and removes them from
@ -2445,7 +2421,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
)
saved_raw_chat_template_files = []
if isinstance(self.chat_template, str):
if save_jinja_files and isinstance(self.chat_template, str):
# New format for single templates is to save them as chat_template.jinja
with open(chat_template_file, "w", encoding="utf-8") as f:
f.write(self.chat_template)
@ -2453,7 +2429,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
saved_raw_chat_template_files.append(chat_template_file)
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
elif isinstance(self.chat_template, dict):
elif save_jinja_files and isinstance(self.chat_template, dict):
# New format for multiple templates is to save the default as chat_template.jinja
# and the other templates in the chat_templates/ directory
for template_name, template in self.chat_template.items():
@ -2471,6 +2447,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
saved_raw_chat_template_files.append(template_filepath)
if "chat_template" in tokenizer_config:
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
elif isinstance(self.chat_template, dict):
# Legacy format for multiple templates:
# chat template dicts are saved to the config as lists of dicts with fixed key names.
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
elif self.chat_template is not None:
# Legacy format for single templates: Just make them a key in tokenizer_config.json
tokenizer_config["chat_template"] = self.chat_template
return tokenizer_config, saved_raw_chat_template_files
def save_pretrained(
@ -2516,6 +2499,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
Returns:
A tuple of `str`: The files saved.
"""
save_jinja_files = kwargs.pop("save_jinja_files", True)
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@ -2553,10 +2538,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
tokenizer_config.update(self.extra_special_tokens)
tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates(
save_directory, tokenizer_config, filename_prefix
save_directory, tokenizer_config, filename_prefix, save_jinja_files
)
if getattr(self, "response_schema", None) is not None:
tokenizer_config["response_schema"] = self.response_schema
if len(self.init_inputs) > 0:
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)

View File

@ -160,7 +160,6 @@ from .import_utils import (
is_in_notebook,
is_ipex_available,
is_jinja_available,
is_jmespath_available,
is_jumanpp_available,
is_kenlm_available,
is_kernels_available,

View File

@ -1,236 +0,0 @@
from __future__ import annotations
import json
import re
from transformers.utils import is_jmespath_available
if is_jmespath_available():
import jmespath
else:
jmespath = None
def _parse_re_match(node_match: re.Match) -> dict | str:
# If the regex has named groups, return a dict of those groups
if node_match.groupdict():
return {key: val for key, val in node_match.groupdict().items() if val is not None}
# Otherwise the regex must have exactly one unnamed group, and we return that
else:
groups = list(node_match.groups())
if len(groups) > 1:
raise ValueError(f"Regex has multiple unnamed groups!\nGroups: {groups}\n")
elif len(groups) == 0:
raise ValueError(f"Regex has no capture groups:\n\n{node_match.group(0)}")
return groups[0]
def recursive_parse(
node_content: str | list | dict,
node_schema: dict,
):
"""
This function takes content and a JSON schema which includes
regex extractors, and recursively parses the content. The output
should be a data structure matching the schema.
Args:
node_content: The content corresponding to this node. Usually a string, but can be something else
if the parent node has multiple capture groups or named groups. In that case,
we generally pass the capture groups straight through to the children of this node
and don't do any parsing at this level.
node_schema: The schema node controlling the parsing.
Returns:
The parsed data structure for the current node.
"""
# If the schema has a const, we just return that value and do absolutely nothing else
if "const" in node_schema:
return node_schema["const"]
# If the node content is None, we return None. EZ.
if node_content is None:
return None
# If not, we have to do a little parsing. First, set some vars and do basic validation
node_type = node_schema["type"]
has_regex = "x-regex" in node_schema or "x-regex-iterator" in node_schema or "x-regex-key-value" in node_schema
if has_regex and not isinstance(node_content, str):
raise TypeError(
"Schema node got a non-string input, but has a regex for parsing.\n"
f"Input: {node_content}\n"
f"Schema: {node_schema}"
)
node_regex = node_schema.get("x-regex")
node_regex_iterator = node_schema.get("x-regex-iterator")
node_regex_to_dict = node_schema.get("x-regex-key-value")
if node_regex is not None:
node_match = re.search(node_regex, node_content, flags=re.DOTALL)
if not node_match:
return None
node_content = _parse_re_match(node_match)
if node_regex_iterator is not None:
if node_type != "array":
raise TypeError(f"Schema node with type {node_type} cannot use x-regex-iterator.\nSchema: {node_schema}")
# Note that this can be applied after a standard node-regex search
node_content = [
_parse_re_match(node_match)
for node_match in re.finditer(node_regex_iterator, node_content, flags=re.DOTALL)
]
if not node_content:
return None
if node_regex_to_dict is not None:
if node_type != "object":
raise TypeError(f"Schema node with type {node_type} cannot use x-regex-key-value.\nSchema: {node_schema}")
# Note that this can be applied after a standard node-regex search
output_content = {}
for node_match in re.finditer(node_regex_to_dict, node_content, flags=re.DOTALL):
match_groups = _parse_re_match(node_match)
if not isinstance(match_groups, dict) or "key" not in match_groups or "value" not in match_groups:
raise ValueError(
f"Regex for x-regex-key-value must have named groups 'key' and 'value'.\n"
f"Match groups: {match_groups}\n"
f"Schema: {node_schema}"
)
output_content[match_groups["key"]] = match_groups["value"]
node_content = output_content
if not node_content:
return None
# Next, if the node has a parser, apply it. We do this after regexes so that the regex can extract
# a substring to parse, if needed.
if "x-parser" in node_schema:
parser = node_schema["x-parser"]
if parser == "json":
if not isinstance(node_content, str):
raise TypeError(
f"Node has JSON parser but got non-string input: {node_content}\nSchema: {node_schema}"
)
parser_args = node_schema.get("x-parser-args", {})
transform = parser_args.get("transform")
allow_non_json = parser_args.get("allow_non_json", False)
try:
parsed_json = json.loads(node_content)
except json.JSONDecodeError as e:
if allow_non_json:
parsed_json = node_content
else:
raise ValueError(
f"Node has JSON parser but could not parse its contents as JSON. You can use the `allow_non_json` parser arg for nodes which may contain JSON or string content.\n\nContent: {node_content}\n\nError: {e}"
)
if transform is not None:
if jmespath is None:
raise ImportError(
"Chat response schema includes a jmespath transformation, but jmespath is not installed. You can install it with `pip install jmespath`."
)
parsed_json = jmespath.search(parser_args["transform"], parsed_json)
node_content = parsed_json
else:
raise ValueError(f"Unknown parser {parser} for schema node: {node_schema}")
# If there's a mapping, apply it now
if "x-mapping" in node_schema:
if not isinstance(node_content, str):
raise TypeError(
f"Schema node with type {node_type} cannot use x-mapping on non-string content.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
mapping = node_schema["x-mapping"]
if node_content in mapping:
node_content = mapping[node_content]
if "x-mapping-regex" in node_schema:
if not isinstance(node_content, str):
raise TypeError(
f"Schema node with type {node_type} cannot use x-mapping-regex on non-string content.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
mapping_regex = node_schema["x-mapping-regex"]
for pattern, replacement in mapping_regex.items():
node_content = re.sub(pattern, replacement, node_content, flags=re.DOTALL)
# Finally, handle parsed content based on schema type and recurse if required
if node_type == "object":
parsed_schema = {}
if isinstance(node_content, str):
# This means we don't have a regex at this level, so all of our child nodes need to parse the whole
# string themselves to extract their value.
if "properties" not in node_schema:
raise ValueError(
f"Object node received string content but has no regex or parser to handle it.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
for key, child_node in node_schema["properties"].items():
child_node_content = recursive_parse(node_content, node_schema["properties"][key])
if child_node_content is not None:
parsed_schema[key] = child_node_content
return parsed_schema
elif isinstance(node_content, dict):
for key, child_node in node_schema.get("properties", {}).items():
if key in node_content:
parsed_schema[key] = recursive_parse(node_content[key], child_node)
elif "default" in child_node:
parsed_schema[key] = child_node["default"]
else:
pass
if "additionalProperties" in node_schema:
for key, value in node_content.items():
if key not in node_schema.get("properties", {}):
parsed_schema[key] = recursive_parse(value, node_schema["additionalProperties"])
return parsed_schema
else:
raise TypeError(f"Expected a dict or str for schema node with type object, got {node_content}")
elif node_type == "array":
if not node_content:
return []
parsed_schema = []
if "items" in node_schema:
if not isinstance(node_content, list):
raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
for item in node_content:
parsed_schema.append(recursive_parse(item, node_schema["items"]))
return parsed_schema
elif "prefixItems" in node_schema:
if not isinstance(node_content, list):
if len(node_schema["prefixItems"]) == 1:
# If there's only one prefix item, this is a single-item array, so we can just wrap the string
node_content = [node_content]
else:
raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
if len(node_content) != len(node_schema["prefixItems"]):
raise ValueError(
f"Array node has {len(node_content)} items, but schema only has "
f"{len(node_schema['prefixItems'])} prefixItems defined.\n"
f"Content: {node_content}\n"
f"Schema: {node_schema}"
)
for item, item_schema in zip(node_content, node_schema["prefixItems"]):
parsed_schema.append(recursive_parse(item, item_schema))
return parsed_schema
else:
raise ValueError(f"Array node has no items or prefixItems schema defined.\nSchema: {node_schema}")
elif node_type in ("string", "integer", "number", "boolean"):
if not isinstance(node_content, str):
raise TypeError(f"Expected a string for schema node with type {node_type}, got {node_content}")
if node_type == "integer":
return int(node_content)
elif node_type == "number":
return float(node_content)
elif node_type == "boolean":
if node_content.lower() in ("true", "1"):
return True
elif node_content.lower() in ("false", "0"):
return False
else:
raise ValueError(f"Invalid boolean value: {node_content}")
else:
# String type
return node_content
else:
raise TypeError(f"Unsupported schema type {node_type} for node: {node_content}")

View File

@ -882,6 +882,10 @@ class PushToHubMixin:
```
"""
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
save_jinja_files = deprecated_kwargs.pop(
"save_jinja_files", None
) # TODO: This is only used for testing and should be removed once save_jinja_files becomes the default
repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
if repo_path_or_name is not None:
# Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
@ -927,11 +931,15 @@ class PushToHubMixin:
files_timestamps = self._get_files_timestamps(work_dir)
# Save all files.
if save_jinja_files:
self.save_pretrained(
work_dir,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
save_jinja_files=True,
)
else:
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))

View File

@ -1134,11 +1134,6 @@ def is_jinja_available() -> bool:
return _is_package_available("jinja2")
@lru_cache
def is_jmespath_available() -> bool:
return _is_package_available("jmespath")
@lru_cache
def is_mlx_available() -> bool:
return _is_package_available("mlx")

View File

@ -476,6 +476,20 @@ class ProcessorPushToHubTester(unittest.TestCase):
tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template
)
self.assertEqual(processor.chat_template, chat_template)
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
with TemporaryHubRepo(token=self._token) as tmp_repo:
processor.save_pretrained(
tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
)
reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we don't use single-file chat template saving, processor and tokenizer chat templates
# should remain separate
self.assertEqual(
getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template
)
with TemporaryHubRepo(token=self._token) as tmp_repo:
processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True)
reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)

View File

@ -18,7 +18,7 @@ from tests.models.superglue.test_image_processing_superglue import (
SuperGlueImageProcessingTester,
)
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from transformers.utils import is_torch_available, is_vision_available
if is_torch_available():
@ -30,9 +30,6 @@ if is_torch_available():
if is_vision_available():
from transformers import LightGlueImageProcessor
if is_torchvision_available():
from transformers import LightGlueImageProcessorFast
def random_array(size):
return np.random.randint(255, size=size)
@ -93,7 +90,7 @@ class LightGlueImageProcessingTester(SuperGlueImageProcessingTester):
@require_vision
class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase):
image_processing_class = LightGlueImageProcessor if is_vision_available() else None
fast_image_processing_class = LightGlueImageProcessorFast if is_torchvision_available() else None
fast_image_processing_class = None
def setUp(self) -> None:
super().setUp()

View File

@ -263,31 +263,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
],
)
@require_torch
def test_small_chat_model_with_response_parsing(self):
text_generator = pipeline(
task="text-generation",
model="hf-internal-testing/tiny-gpt2-with-chatml-template",
)
# Using `do_sample=False` to force deterministic output
chat = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
]
text_generator.tokenizer.response_schema = {
# A real response schema should probably have things like "role" and "content"
# and "reasoning_content" but it's unlikely we'd get a tiny model to reliably
# output anything like that, so let's keep it simple.
"type": "object",
"properties": {
"first_word": {"type": "string", "x-regex": r"^\s*([a-zA-Z]+)"},
"last_word": {"type": "string", "x-regex": r"([a-zA-Z]+)\s*$"},
},
}
outputs = text_generator(chat, do_sample=False, max_new_tokens=10)
parsed_message = outputs[0]["generated_text"][-1]
self.assertEqual(parsed_message, {"first_word": "factors", "last_word": "factors"})
def get_test_pipeline(
self,
model,

View File

@ -15,7 +15,6 @@
import inspect
import json
import os
import random
import sys
import tempfile
@ -943,15 +942,17 @@ class ProcessorTesterMixin:
if "chat_template" not in {*signature.parameters.keys()}:
self.skipTest("Processor doesn't accept chat templates at input")
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
processor.chat_template = "test template"
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname)
with open(Path(tmpdirname, "chat_template.json"), "w") as fp:
json.dump({"chat_template": processor.chat_template}, fp)
os.remove(Path(tmpdirname, "chat_template.jinja"))
processor.save_pretrained(tmpdirname, save_jinja_files=False)
self.assertTrue(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we don't use single-file chat template saving, processor and tokenizer chat templates
# should remain separate
self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template)
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname)
@ -976,6 +977,12 @@ class ProcessorTesterMixin:
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
with self.assertRaises(ValueError):
# Saving multiple templates in the legacy format is not permitted
with tempfile.TemporaryDirectory() as tmpdirname:
processor.chat_template = {"default": "a", "secondary": "b"}
processor.save_pretrained(tmpdirname, save_jinja_files=False)
@require_torch
def _test_apply_chat_template(
self,

View File

@ -1093,13 +1093,9 @@ class TokenizerTesterMixin:
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
save_files = tokenizer.save_pretrained(tmp_dir_name)
with open(Path(tmp_dir_name, "tokenizer_config.json"), "r") as fp:
tokenizer_config = json.load(fp)
tokenizer_config["chat_template"] = tokenizer.chat_template
with open(Path(tmp_dir_name, "tokenizer_config.json"), "w") as fp:
json.dump(tokenizer_config, fp)
os.remove(Path(tmp_dir_name, "chat_template.jinja"))
save_files = tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False)
# Check we aren't saving a chat_template.jinja file
self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
@ -1159,16 +1155,10 @@ class TokenizerTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
tokenizer.chat_template = {"default": "a", "secondary": "b"}
tokenizer.save_pretrained(tmpdirname)
with open(Path(tmpdirname, "tokenizer_config.json"), "r") as fp:
tokenizer_config = json.load(fp)
tokenizer_config["chat_template"] = [
{"name": k, "template": v} for k, v in tokenizer.chat_template
]
with open(Path(tmpdirname, "tokenizer_config.json"), "w") as fp:
json.dump(tokenizer_config, fp)
os.remove(Path(tmpdirname, "chat_template.jinja"))
os.remove(Path(tmpdirname, "additional_chat_templates"))
tokenizer.save_pretrained(tmpdirname, save_jinja_files=False)
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir())
reloaded_tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
# When we save as single files, tokenizers and processors share a chat template, which means
@ -1662,16 +1652,29 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
for save_jinja_files in (True, False):
tokenizer.chat_template = {"default": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
# Test that a dict of multiple templates can be serialized and loaded back
tokenizer.save_pretrained(tmp_dir_name)
# Test that save_jinja_files is ignored when there's a dict of multiple templates
tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=save_jinja_files)
if save_jinja_files:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
self.assertNotIn("chat_template", config_dict)
self.assertTrue(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir_name, "additional_chat_templates/template2.jinja"))
)
else:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "default", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
@ -1685,14 +1688,7 @@ class TokenizerTesterMixin:
with self.subTest(f"{tokenizer.__class__.__name__}"):
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.chat_template = dummy_template1
tokenizer.save_pretrained(tmp_dir_name)
# Save first template in tokenizer config and second template in jinja file
# Priority should be given to jinja when loading
with open(Path(tmp_dir_name, "tokenizer_config.json"), "r") as fp:
tokenizer_config = json.load(fp)
tokenizer_config["chat_template"] = tokenizer.chat_template
with open(Path(tmp_dir_name, "tokenizer_config.json"), "w") as fp:
json.dump(tokenizer_config, fp)
tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False)
with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
f.write(dummy_template2)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)

View File

@ -1,349 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AutoProcessor, AutoTokenizer
from transformers.testing_utils import require_jmespath
from transformers.utils.chat_parsing_utils import recursive_parse
cohere_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"},
"thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"},
"tool_calls": {
"x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)",
"x-parser": "json",
"x-parser-args": {
"transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}"
},
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}
ernie_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": "<response>\n(.*?)\n?</response>"},
"thinking": {"type": "string", "x-regex": r"(?:^|<think>\s*)(.*?)\s*<\/think>"},
"tool_calls": {
"x-regex-iterator": "<tool_call>(.*?)</tool_call>",
"type": "array",
"items": {
"type": "object",
"x-parser": "json",
"x-parser-args": {"transform": "{type: 'function', function: @}"},
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}
gpt_oss_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"},
"thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
"tool_calls": {
"x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
"arguments": {
"type": "object",
"x-regex": r"<\|message\|>(.*)",
"x-parser": "json",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}
smollm_schema = {
"x-regex": r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?:<tool_call>(?P<tool_calls>.+?)</tool_call>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string"},
"thinking": {"type": "string"},
"tool_calls": {
"x-parser": "json",
"x-parser-args": {"transform": "[{type: 'function', function: @}]"},
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}
qwen3_schema = {
"x-regex": r"^(?:(?:<think>)?\s*(?P<thinking>.+?)\s*</think>)?\s*(?:<tool_call>(?P<tool_calls>.*?)\s*</tool_call>)?\s*(?P<content>.+?)?\s*$",
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string"},
"thinking": {"type": "string"},
"tool_calls": {
"x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string", "x-regex": r"<function=(\w+)>"},
"arguments": {
"type": "object",
"x-regex-key-value": r"<parameter=(?P<key>\w+)>\n(?P<value>.*?)\n</parameter>",
"additionalProperties": {
"x-parser": "json",
"x-parser-args": {"allow_non_json": True},
},
},
},
},
},
},
},
},
}
@require_jmespath
class ChatSchemaParserTest(unittest.TestCase):
def test_schema_save_load(self):
# Has no schema by default
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.response_schema = ernie_schema
with tempfile.TemporaryDirectory() as tmpdir:
tokenizer.save_pretrained(tmpdir)
reloaded_tokenizer = AutoTokenizer.from_pretrained(tmpdir)
self.assertEqual(reloaded_tokenizer.response_schema, ernie_schema)
# Has no schema by default
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
processor.response_schema = ernie_schema
with tempfile.TemporaryDirectory() as tmpdir:
processor.save_pretrained(tmpdir)
reloaded_processor = AutoProcessor.from_pretrained(tmpdir)
self.assertEqual(reloaded_processor.response_schema, ernie_schema)
def test_tokenizer_method(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
parsed_chat = recursive_parse(model_out, cohere_schema)
tokenizer.response_schema = cohere_schema
tokenizer_parsed_chat = tokenizer.parse_response(model_out)
self.assertEqual(tokenizer_parsed_chat, parsed_chat)
def test_cohere_template(self):
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
parsed_chat = recursive_parse(model_out, cohere_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": "I should call a tool.",
"tool_calls": [
{
"type": "function",
"function": {"name": "simple_tool", "arguments": {"temperature_format": "Celsius"}},
}
],
},
)
def test_ernie_template_with_tools(self):
model_out = 'The user is asking about the weather in Paris today. Let me check the available tools. There\'s a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to "Paris". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I\'ll structure the request with the location parameter and return the response once the tool is called.\n</think>\n\n<tool_call>\n{"name": "get_current_temperature", "arguments": {"location": "Paris"}}\n</tool_call>\n</s>'
parsed_chat = recursive_parse(model_out, ernie_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": "The user is asking about the weather in Paris today. Let me check the available tools. There's a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to \"Paris\". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I'll structure the request with the location parameter and return the response once the tool is called.",
"tool_calls": [
{
"type": "function",
"function": {"name": "get_current_temperature", "arguments": {"location": "Paris"}},
}
],
},
)
def test_ernie_template_no_tools(self):
model_out = "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.\n</think>\n\n<response>\nHello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!\n</response>\n</s>"
parsed_chat = recursive_parse(model_out, ernie_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"content": "Hello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!",
"thinking": "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.",
},
)
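
# GPT-OSS-style output uses channel markers: the analysis channel maps to `thinking`, a
# "commentary to=functions.NAME" channel maps to a tool call, and the final channel (see the
# next test) maps to `content`.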
def test_gpt_oss_template_with_tool_call(self):
model_out = '<|channel|>analysis<|message|>We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\n "location": "San Francisco, CA"\n}'
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": 'We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.',
"tool_calls": [
{
"type": "function",
"function": {"name": "get_current_weather", "arguments": {"location": "San Francisco, CA"}},
}
],
},
)
def test_gpt_oss_template_no_tool_call(self):
model_out = "<|channel|>analysis<|message|>User asks a simple math question: 2+2 = 4. Provide answer.<|end|><|start|>assistant<|channel|>final<|message|>2"
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"content": "2",
"thinking": "User asks a simple math question: 2+2 = 4. Provide answer.",
},
)
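
# SmolLM-style output: an optional <think> block followed by either a JSON <tool_call> or
# plain content; the three tests below cover each combination.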
def test_smollm_template_thinking_and_tool_call(self):
model_out = '<think>\nOkay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.\n</think>\n\n<tool_call>{"name": "greet_user", "arguments": {"greeting": "Hello! I\'m doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"}}</tool_call>'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"thinking": 'Okay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.',
"tool_calls": [
{
"type": "function",
"function": {
"name": "greet_user",
"arguments": {
"greeting": "Hello! I'm doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"
},
},
}
],
},
)
def test_smollm_template_tool_call_no_thinking(self):
model_out = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"tool_calls": [
{"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
]
},
)
def test_smollm_template_thinking_no_tool_call(self):
model_out = '<think>\nOkay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.</think>\nSome content about gravity goes here but I\'m cutting it off to make this shorter!'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"content": "Some content about gravity goes here but I'm cutting it off to make this shorter!",
"thinking": 'Okay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.',
},
)
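
# Qwen3 XML-style tool call: <function=NAME> with one <parameter=KEY> block per argument.
# The list-valued parameter is JSON-decoded, while "celsius" is kept as a plain string.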
def test_qwen3_tool_calls(self):
model_out = '<tool_call>\n<function=get_weather>\n<parameter=locations>\n[{"country": "France", "city": "Paris"}]\n</parameter>\n<parameter=temp_units>\ncelsius\n</parameter>\n</function>\n</tool_call>'
parsed_chat = recursive_parse(model_out, qwen3_schema)
self.assertEqual(
parsed_chat,
{
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": {
"locations": [{"country": "France", "city": "Paris"}],
"temp_units": "celsius",
},
},
}
]
},
)

View File

@ -138,6 +138,14 @@ class TokenizerPushToHubTester(unittest.TestCase):
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.chat_template = "test template"
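# The chat template should survive a push-to-hub round trip, both with save_jinja_files=False
# and with the default saving behavior.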
with TemporaryHubRepo(token=self._token) as tmp_repo:
tokenizer.save_pretrained(
tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
)
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
with TemporaryHubRepo(token=self._token) as tmp_repo:
tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)