mirror of https://github.com/huggingface/transformers.git (synced 2025-10-22 18:34:37 +08:00)
Compare commits: rm_last_ke ... serve-quan (15 commits)
Commit SHA1s: b734e7c35e, ccbd1eceb3, b68b48ce88, 72d8e7bb3c, 747fcfa227, a6506fa478, 72ffb3d1d2, f525309408, ffa68ba7b8, eab734d23c, b604f62b6b, 35fff29efd, 1cdd0bf0fb, 907f206a1b, 86ba65350b
@ -88,8 +88,6 @@
|
||||
title: Tool use
|
||||
- local: chat_templating_writing
|
||||
title: Writing a chat template
|
||||
- local: chat_response_parsing
|
||||
title: Response parsing
|
||||
title: Chat with models
|
||||
- sections:
|
||||
- local: serving
|
||||
|
@ -95,12 +95,9 @@ print(tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):]))
|
||||
|
||||
The chat model called the `get_current_temperature` tool with the correct parameters from the docstring. From "Paris", it inferred that the location is Paris, France, and that the temperature should be reported in Celsius.
|
||||
|
||||
A model **cannot actually call the tool itself**. It requests a tool call, and it's your job to handle the call and append it and the result to the chat history. For
|
||||
models that support [response parsing](./chat_response_parsing), the response parsing will be handled automatically, and you can just use
|
||||
[`~PreTrainedTokenizer.parse_response`] to extract the tool call. For other models, you'll need to manually translate the output
|
||||
string into a tool call dict.
|
||||
A model **cannot actually call the tool itself**. It requests a tool call, and it's your job to handle the call and append it and the result to the chat history.
|
||||
|
||||
Regardless of the approach you use, the tool call should go in the `tool_calls` key of an `assistant` message. This is the recommended API, and should be supported by the chat template of most tool-using models.
|
||||
Hold the call in the `tool_calls` key of an `assistant` message. This is the recommended API, and should be supported by the chat template of most tool-using models.
|
||||
|
||||
> [!WARNING]
|
||||
> Although `tool_calls` is similar to the OpenAI API, the OpenAI API uses a JSON string as its `tool_calls` format. This may cause errors or strange model behavior if used in Transformers, which expects a dict.
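For illustration, here is the difference in miniature (a sketch; the exact nesting of a `tool_calls` entry can vary slightly between chat templates, and `messages` is the chat history from the example above):

```python
# Transformers chat templates expect the tool arguments as a real dict...
transformers_style_call = {
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "arguments": {"location": "Paris, France", "unit": "celsius"},  # dict
    },
}

# ...whereas the OpenAI API serializes the arguments as a JSON string.
# Passing this form through apply_chat_template can confuse the template or the model.
openai_style_call = {
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "arguments": '{"location": "Paris, France", "unit": "celsius"}',  # JSON string
    },
}

messages.append({"role": "assistant", "tool_calls": [transformers_style_call]})
```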
|
||||
|
@ -1,233 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Response Parsing
|
||||
|
||||
It is increasingly common for chat models to generate structured outputs, rather than just a single reply string.
|
||||
The most common uses for structured outputs are [tool calling](./chat_extras) and [reasoning models](https://huggingface.co/reasoning-course).
|
||||
Tool calling models can output tool calls, containing the name of the tool to call and any arguments to be passed to it,
|
||||
while reasoning models often output reasoning steps as a "chain of thought". Some recent models even use both of these,
|
||||
and may output reasoning and/or one or more tool calls before their final answer.
|
||||
|
||||
Models with structured outputs pose a challenge for chat templating, because the output needs to be parsed before it
|
||||
can be appended to the chat. For a concrete example, let's say we ask [GPT-OSS](https://huggingface.co/openai/gpt-oss-120b)
|
||||
what the weather is like, and it thinks and decides to call a tool. Here's what the raw model output might look like:
|
||||
|
||||
```txt
|
||||
<|start|>analysis<|message|>The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco).
|
||||
So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. But we need to call function to get weather data.
|
||||
So we should call get_current_weather with location "San Francisco, CA". Let's do that.
|
||||
We will call function get_current_weather.<|end|><|start|>commentary to=functions.get_current_weather<|channel|>commentary <|constrain|>json<|message|>{"location":"San Francisco, CA"}<|call|>
|
||||
|
||||
```
|
||||
|
||||
But if you want to append this to a chat, you'll need to format it as a chat message dict, like this:
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": "The user asks: \"What is the weather like in SF?\" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. But we need to call function to get weather data. So we should call get_current_weather with location \"San Francisco, CA\". Let's do that.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"arguments": {
|
||||
"location": "San Francisco, CA"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Chat **templates** give us a way to turn messages into formatted input for a model, but we need something else to
|
||||
parse model output back into a standard message dict. This is what chat **parsing** is for.
|
||||
|
||||
## The [`~PreTrainedTokenizerBase.parse_response`] method
|
||||
|
||||
Parsing a chat response on a model that supports it is straightforward. Simply take the raw, decoded output from
|
||||
[`~generation.GenerationMixin.generate`], and pass it to the tokenizer's [`~PreTrainedTokenizerBase.parse_response`] method:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
checkpoint = "HuggingFaceTB/SmolLM3-3B"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", device_map="auto")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey! Can you summarize the end of the Cold War as briefly as possible? Like, comically briefly. It should really leave out almost most of the relevant information."
|
||||
}
|
||||
]
|
||||
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(input_ids, max_new_tokens=1024)[0, input_ids.shape[1]:]
|
||||
out_text = tokenizer.decode(outputs)
|
||||
parsed = tokenizer.parse_response(out_text)
|
||||
print(parsed.keys())
|
||||
```
|
||||
|
||||
And you should get:
|
||||
|
||||
```text
|
||||
dict_keys(['thinking', 'content'])
|
||||
```
|
||||
|
||||
And that's all you need to start using response parsing! `parse_response` should return a complete message dict that is ready to be appended to the chat history.
|
||||
When the tokenizer does not support response parsing, `parse_response` will throw an error. We hope to add support
|
||||
to more tokenizers over time.
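Since the result is a message dict, continuing the conversation is just a matter of appending it. The snippet below is a sketch that builds on the example above; the explicit `"role"` key is added defensively in case a parser returns only the parsed fields:

```python
# Append the parsed reply and ask a follow-up question.
messages.append({"role": "assistant", **parsed})
messages.append({"role": "user", "content": "Great, now do the Cuban Missile Crisis."})

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids, max_new_tokens=1024)
```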
|
||||
|
||||
## Developers: Understanding a simple response schema
|
||||
|
||||
Under the hood, `parse_response` uses a **JSON schema** to parse the model output. A JSON schema represents
|
||||
the structure of the output message dict. The schema is augmented with additional fields that indicate how the
|
||||
output message string should be parsed into the expected format. Let's take a look at the schema for a SmolLM response,
|
||||
excluding tool calls for now:
|
||||
|
||||
```python
|
||||
{
|
||||
"x-regex": "(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string"},
|
||||
"thinking": {"type": "string"}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
We can see that the schema describes a JSON "object" (a `dict`, in other words) with three keys: `role`, `content`, and `thinking`.
|
||||
Because all assistant responses have the role "assistant", the `role` key is a `const`(ant). The other two keys are strings, extracted
|
||||
from the named groups in the regex in the `x-regex` field.
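To make the mapping concrete, here is the same regex applied by hand with Python's `re` module. This is an illustrative sketch only: the raw output string is made up, and the real parser may apply different match flags.

```python
import re

schema_regex = r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)"
raw_output = "<think>\nThe user wants one short sentence.\n</think>\nThe USSR dissolved in 1991.<|im_end|>"

match = re.match(schema_regex, raw_output)
# The group names become the message keys; "role" comes from the `const` in the schema.
parsed = {"role": "assistant", **{k: v for k, v in match.groupdict().items() if v is not None}}
print(parsed["thinking"])  # The user wants one short sentence.
print(parsed["content"])   # The USSR dissolved in 1991.
```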
|
||||
|
||||
Like chat templates, response schemas are set as a property of the tokenizer. To enable response parsing, all you need
|
||||
to do is set `tokenizer.response_schema` to a valid schema dict, and `tokenizer.parse_response()` will work! Again, like
|
||||
chat templates, this schema will be saved with the processor, so once you set it, you can use `save_pretrained()` or `push_to_hub()` to
|
||||
save and share the schema.
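In code, enabling parsing is as simple as the sketch below, where `schema` stands for the dict shown above and the checkpoint and save directory are placeholders:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

# Attach the schema: parse_response() now works for this tokenizer.
tokenizer.response_schema = schema

# The schema is stored alongside the chat template, so saving or pushing shares it too.
tokenizer.save_pretrained("smollm3-with-response-schema")
# tokenizer.push_to_hub("my-username/smollm3-with-response-schema")
```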
|
||||
|
||||
## Developers: Complex schemas
|
||||
|
||||
Now, let's look at a more complex schema, which includes tool calls, to gain more of an understanding of the parser
|
||||
internals. For this, we'll use the `GPT-OSS` schema. GPT-OSS emits both tool calls and thinking blocks, and it uses
|
||||
an unusual format where model responses are tagged with one of three "channels": `commentary` for things like
|
||||
tool calls, `analysis` for chain of thought blocks, and `final` for messages intended to be sent to the user.
|
||||
A full message where the model calls a tool named `get_current_weather` might look like this, with some extra linebreaks added for clarity:
|
||||
|
||||
```text
|
||||
<|channel|>analysis<|message|>
|
||||
The user asks: "What is the weather like in SF?" So we need to get the current weather in San Francisco, CA.
|
||||
We need to call get_current_weather function. So we should call get_current_weather with location "San Francisco, CA".
|
||||
<|end|>
|
||||
<|start|>assistant<|channel|>commentary
|
||||
to=functions.get_current_weather <|constrain|>json<|message|>
|
||||
{
|
||||
"location": "San Francisco, CA"
|
||||
}
|
||||
<|call|>
|
||||
```
|
||||
|
||||
Parsing proceeds recursively; the output of a regex (or other parser) at one level becomes the input to the nodes below it.
|
||||
In other words, don't feel like you have to parse the entire output in one enormous regex! Instead, start with the schema,
|
||||
and then add regexes to extract the relevant chunks as you go. Here's a schema that will parse it, with some
|
||||
explanatory comments:
|
||||
|
||||
```python
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
# "content" and "thinking" are both similar to the previous example, and just extract a single string
|
||||
# However, rather than using a single regex with named groups to extract both, we use a regex in each subkey.
|
||||
# When an object node has no parser/regex, the entire input string is passed to all of its children, so
|
||||
# parsing can either be done with named groups at the object level, or with separate regexes at the property level.
|
||||
"content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"},
|
||||
"thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
|
||||
"tool_calls": {
|
||||
# "x-regex-iterator" uses re.findall to find multiple possible manages, and returns them as an
|
||||
# array/list. You don't need to worry about array handling, though - each item in the array will be
|
||||
# parsed by the `items` schema, so just write the schema for a single item.
|
||||
"x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
# A const property is a fixed value, and the input has no effect on it.
|
||||
"type": {"const": "function"},
|
||||
# Here, we wrap the entire tool call dict in a `{"function": ...}` block. The input string is passed through to it unchanged.
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"x-regex": "<\|message\|>(.*)",
|
||||
# The "x-parser" field indicates that the extracted string should be parsed as JSON.
|
||||
# The output is then passed to the schema nodes below and recursive parsing continues.
|
||||
"x-parser": "json",
|
||||
"additionalProperties": {"type": "any"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
```
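Applied to the sample output above, this schema would produce a message dict along these lines (illustrative; the `content` key is absent because the sample contains no `final` channel):

```python
{
    "role": "assistant",
    "thinking": 'The user asks: "What is the weather like in SF?" So we need to get the current weather in San Francisco, CA. ...',
    "tool_calls": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": {"location": "San Francisco, CA"},
            },
        }
    ],
}
```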
|
||||
|
||||
## Developers: Understanding the parser logic
|
||||
|
||||
The parser follows a few simple rules:
|
||||
|
||||
1. Each level of the schema receives input from the level above, applies any regex or parser it has, and then passes the output to its children.
|
||||
2. The root level receives the entire decoded model output string as input.
|
||||
3. If a node has structured content after parsing (for example, if the regex has named groups and returns a dict, or if the parser returns a dict or list),
|
||||
then that structured content is mapped to the node's children, and each child node receives its corresponding value as input.
|
||||
4. If an `object` (dict) node has unstructured (string) output, then the entire string is passed to all of its children. This allows child nodes
|
||||
to handle parsing individually rather than requiring a single parent regex to extract all keys at once.
|
||||
5. If an `array` (list) node has unstructured (string) output, then this throws an error.
|
||||
|
||||
There is a small set of allowable `x-` keys that indicate how parsing should be done at each node:
|
||||
- `x-regex`: A regex string to apply to the input. If the regex has named groups, the output is a dict of group names to values. Named groups should only be used in `object` nodes.
|
||||
Otherwise, the regex must have exactly one unnamed capturing group, and the output is the value of that group as a string.
|
||||
- `x-regex-iterator`: A regex string to apply to the input using `re.findall()`. The output is a list of all matches.
|
||||
This should only be used in `array` nodes, and the regex must have exactly one unnamed capturing group. The output is distributed to
|
||||
the node's `items` schema.
|
||||
- `x-parser`: Calls a built-in parser to apply to the input. Currently, the only supported parser is `json`, which parses the input string as JSON.
|
||||
The output is passed to the child nodes for further parsing. Note that the `json` parser can return deeply nested output - in this case, the output
|
||||
will be progressively unwrapped as it is passed through child nodes. The child nodes do not need additional `x-parser` or `x-regex` fields in this case,
|
||||
but their structure must match the structure of the parsed JSON.
|
||||
- `x-parser-args`: Only allowed in conjunction with `x-parser`. This is a dict of additional arguments that control parsing. Right now, the only supported
|
||||
argument is `transform`, which specifies a `jmespath` transformation to apply to the output. This is useful when the JSON parser returns a structure
|
||||
that needs to be modified to match the schema.
|
||||
- `x-regex-key-value`: This is rarely necessary, but it can be useful when parsing key-value pairs in non-JSON format where the names of the keys are not known
|
||||
in advance, such as when a model emits XML tool calls with arbitrary argument names. The regex must have exactly two named capturing groups,
|
||||
`key` and `value`, and the output is a dict mapping keys to values. This should only be used in `object` nodes.
|
||||
|
||||
In general, multiple regexes/parsers cannot be combined at the same level. The exception is that `x-regex`, returning a single string, can be combined with the other parsers. In this case,
|
||||
`x-regex` is applied first, and then the output is passed to the other parser, either `x-regex-iterator`, `x-parser`, or `x-regex-key-value`.
|
||||
|
||||
Putting these ideas together, you can see that the input flows through the schema, being parsed at each level and then distributed to child nodes. Each level
|
||||
only needs to extract the input content that is relevant for that part of the schema, and can then let its child nodes handle the rest. Internally, this is handled
|
||||
with a parser function that receives input, applies any regexes/parsers at the current level, then maps the result to its child nodes before recursively calling itself on each of them.
|
||||
Recursion terminates when it reaches leaf nodes, usually primitive types like `string` or `number`, which simply return the input they receive.
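As a mental model, the whole procedure fits in one short recursive function. The sketch below is not the Transformers implementation (it ignores `x-parser-args`, `x-regex-key-value`, and most error handling), but it follows the rules listed above:

```python
import json
import re


def parse_node(schema: dict, value):
    """Toy recursive parser following the rules above (not the real implementation)."""
    # 1. Apply any extraction step at this level. x-regex runs first and may feed the others.
    if "x-regex" in schema and isinstance(value, str):
        match = re.search(schema["x-regex"], value, re.DOTALL)
        groups = match.groupdict() if match else {}
        value = groups if groups else (match.group(1) if match else None)
    if "x-regex-iterator" in schema and isinstance(value, str):
        value = re.findall(schema["x-regex-iterator"], value, re.DOTALL)
    if schema.get("x-parser") == "json" and isinstance(value, str):
        value = json.loads(value)

    # 2. Distribute the result to child nodes.
    if schema.get("type") == "object" and "properties" in schema:
        if not isinstance(value, dict):
            # Unstructured output: every child sees the whole string and extracts its own piece.
            value = dict.fromkeys(schema["properties"], value)
        out = {}
        for key, subschema in schema["properties"].items():
            if "const" in subschema:
                out[key] = subschema["const"]
            elif value.get(key) is not None:
                child = parse_node(subschema, value[key])
                if child is not None:
                    out[key] = child
        return out
    if schema.get("type") == "array" and "items" in schema:
        if isinstance(value, str):
            raise ValueError("array nodes need structured input, e.g. from x-regex-iterator")
        return [parse_node(schema["items"], item) for item in value]

    # 3. Leaf nodes just return whatever they received.
    return value
```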
|
@ -147,13 +147,6 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
- post_process_keypoint_matching
|
||||
- visualize_keypoint_matching
|
||||
|
||||
## LightGlueImageProcessorFast
|
||||
|
||||
[[autodoc]] LightGlueImageProcessorFast
|
||||
- preprocess
|
||||
- post_process_keypoint_matching
|
||||
- visualize_keypoint_matching
|
||||
|
||||
## LightGlueForKeypointMatching
|
||||
|
||||
[[autodoc]] LightGlueForKeypointMatching
|
||||
|
@ -383,6 +383,30 @@ transformers serve \
|
||||
--attn_implementation "sdpa"
|
||||
```
|
||||
|
||||
### Quantization
|
||||
|
||||
`transformers serve` is compatible with all [quantization methods](https://huggingface.co/docs/transformers/main/quantization/overview) supported in transformers. Quantization can significantly reduce memory usage and improve inference speed, with two main workflows: pre-quantized models and on-the-fly quantization.
|
||||
|
||||
#### Pre-quantized models
|
||||
|
||||
For models that are already quantized (e.g., GPTQ, AWQ, bitsandbytes), simply choose a quantized model name for serving.
|
||||
Make sure to install the required libraries listed in the quantization documentation.
|
||||
|
||||
> [!TIP]
|
||||
> Pre-quantized models generally provide the best balance of performance and accuracy.
|
||||
|
||||
#### On-the-fly quantization
|
||||
|
||||
If you want to quantize a model at runtime, you can specify the `--quantization` flag in the CLI. Note that not all quantization methods support on-the-fly conversion. The full list of supported methods is available in the quantization [overview](https://huggingface.co/docs/transformers/main/quantization/overview).
|
||||
|
||||
Currently, `transformers serve` only supports the following methods: `"bitsandbytes-4bit"` and `"bitsandbytes-8bit"`.
|
||||
|
||||
For example, to enable 4-bit quantization with bitsandbytes, pass `--quantization bitsandbytes-4bit`:
|
||||
|
||||
```sh
|
||||
transformers serve --quantization bitsandbytes-4bit
|
||||
```
|
||||
|
||||
### Performance tips
|
||||
|
||||
- Use an efficient attention backend when available:
|
||||
@ -397,6 +421,4 @@ transformers serve \
|
||||
|
||||
- `--dtype {bfloat16|float16}` typically improves throughput and memory use vs. `float32`
|
||||
|
||||
- `--load_in_4bit`/`--load_in_8bit` can reduce memory footprint for LoRA setups
|
||||
|
||||
- `--force-model <repo_id>` avoids per-request model hints and helps produce stable, repeatable runs
|
||||
|
@ -125,23 +125,15 @@ def token_type_ids_mask_function(
|
||||
# If it's 1 for both query and key/value, we are in an image block
|
||||
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
|
||||
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
|
||||
safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
|
||||
safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
||||
|
||||
token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
|
||||
token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
|
||||
|
||||
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
|
||||
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
||||
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
|
||||
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
|
||||
|
||||
image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
|
||||
image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
|
||||
|
||||
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
|
||||
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
|
||||
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
|
||||
|
||||
is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
|
||||
same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
|
||||
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
|
||||
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
|
||||
|
||||
# This is bidirectional attention whenever we are dealing with image tokens
|
||||
return is_image_block & same_image_block
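For reference, the `torch.where` clamp-then-mask idiom used above can be seen in isolation in this small standalone sketch (not taken from the model code):

```python
import torch

token_type_ids = torch.tensor([[1, 1, 0, 1]])  # (batch=1, seq_len=4)
kv_idx = torch.tensor([2, 5])                  # 5 is past the sequence length (e.g. a static-cache slot)

in_bounds = kv_idx < token_type_ids.shape[1]
safe_kv_idx = torch.where(in_bounds, kv_idx, 0)   # [2, 0] -- always a valid index
gathered = token_type_ids[0, safe_kv_idx]         # gather never goes out of bounds
gathered = torch.where(in_bounds, gathered, 0)    # out-of-range slots are padded with 0
print(gathered)  # tensor([0, 0])
```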
|
||||
|
setup.py (3)
@ -117,7 +117,6 @@ _deps = [
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
"jinja2>=3.1.0",
|
||||
"jmespath>=1.0.1",
|
||||
"kenlm",
|
||||
"kernels>=0.10.2,<0.11",
|
||||
"librosa",
|
||||
@ -295,7 +294,7 @@ extras["num2words"] = deps_list("num2words")
|
||||
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
|
||||
extras["tiktoken"] = deps_list("tiktoken", "blobfile")
|
||||
extras["mistral-common"] = deps_list("mistral-common[opencv]")
|
||||
extras["chat_template"] = deps_list("jinja2", "jmespath")
|
||||
extras["chat_template"] = deps_list("jinja2")
|
||||
extras["testing"] = (
|
||||
deps_list(
|
||||
"pytest",
|
||||
|
@ -377,14 +377,10 @@ class Serve:
|
||||
help="Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
|
||||
),
|
||||
] = None,
|
||||
load_in_8bit: Annotated[
|
||||
bool, typer.Option(help="Whether to use 8 bit precision for the base model - works only with LoRA.")
|
||||
] = False,
|
||||
load_in_4bit: Annotated[
|
||||
bool, typer.Option(help="Whether to use 4 bit precision for the base model - works only with LoRA.")
|
||||
] = False,
|
||||
bnb_4bit_quant_type: Annotated[str, typer.Option(help="Quantization type.")] = "nf4",
|
||||
use_bnb_nested_quant: Annotated[bool, typer.Option(help="Whether to use nested quantization.")] = False,
|
||||
quantization: Annotated[
|
||||
Optional[str],
|
||||
typer.Option(help="Which quantization method to use. choices: 'bitsandbytes-4bit', 'bitsandbytes-8bit'"),
|
||||
] = None,
|
||||
host: Annotated[str, typer.Option(help="Interface the server will listen to.")] = "localhost",
|
||||
port: Annotated[int, typer.Option(help="Port the server will listen to.")] = 8000,
|
||||
model_timeout: Annotated[
|
||||
@ -424,10 +420,7 @@ class Serve:
|
||||
self.dtype = dtype
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.attn_implementation = attn_implementation
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.load_in_4bit = load_in_4bit
|
||||
self.bnb_4bit_quant_type = bnb_4bit_quant_type
|
||||
self.use_bnb_nested_quant = use_bnb_nested_quant
|
||||
self.quantization = quantization
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.model_timeout = model_timeout
|
||||
@ -1688,19 +1681,14 @@ class Serve:
|
||||
Returns:
|
||||
`Optional[BitsAndBytesConfig]`: The quantization config.
|
||||
"""
|
||||
if self.load_in_4bit:
|
||||
if self.quantization == "bitsandbytes-4bit":
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
# For consistency with model weights, we use the same value as `dtype`
|
||||
bnb_4bit_compute_dtype=self.dtype,
|
||||
bnb_4bit_quant_type=self.bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=self.use_bnb_nested_quant,
|
||||
bnb_4bit_quant_storage=self.dtype,
|
||||
)
|
||||
elif self.load_in_8bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
elif self.quantization == "bitsandbytes-8bit":
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
else:
|
||||
quantization_config = None
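For reference, the `bitsandbytes-4bit` branch above builds roughly the following config. This standalone sketch assumes the CLI defaults (`nf4` quant type, no nested quantization) and a bfloat16 dtype:

```python
import torch

from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # kept consistent with the model dtype
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_storage=torch.bfloat16,
)
```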
|
||||
|
||||
@ -1750,7 +1738,6 @@ class Serve:
|
||||
revision=revision,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
|
||||
dtype = self.dtype if self.dtype in ["auto", None] else getattr(torch, self.dtype)
|
||||
quantization_config = self.get_quantization_config()
|
||||
|
||||
@ -1758,19 +1745,15 @@ class Serve:
|
||||
"revision": revision,
|
||||
"attn_implementation": self.attn_implementation,
|
||||
"dtype": dtype,
|
||||
"device_map": "auto",
|
||||
"device_map": self.device,
|
||||
"trust_remote_code": self.trust_remote_code,
|
||||
"quantization_config": quantization_config,
|
||||
}
|
||||
if quantization_config is not None:
|
||||
model_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
|
||||
architecture = getattr(transformers, config.architectures[0])
|
||||
model = architecture.from_pretrained(model_id, **model_kwargs)
|
||||
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
model = model.to(self.device)
|
||||
|
||||
has_default_max_length = (
|
||||
model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20
|
||||
)
|
||||
|
@ -27,7 +27,6 @@ deps = {
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
"jinja2": "jinja2>=3.1.0",
|
||||
"jmespath": "jmespath>=1.0.1",
|
||||
"kenlm": "kenlm",
|
||||
"kernels": "kernels>=0.10.2,<0.11",
|
||||
"librosa": "librosa",
|
||||
|
@ -27,6 +27,7 @@ from ...utils.metrics import traced
|
||||
logger = logging.getLogger("ContinuousBatchingLogger")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
|
@ -164,7 +164,6 @@ except ImportError:
|
||||
|
||||
_HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
|
||||
"causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
|
||||
"mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "revision": "clean-mamba-ssm"},
|
||||
}
|
||||
|
||||
_KERNEL_MODULE_MAPPING: dict[str, Optional[ModuleType]] = {}
|
||||
@ -236,7 +235,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]
|
||||
if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
|
||||
return mapping[kernel_name]
|
||||
if kernel_name not in _HUB_KERNEL_MAPPING:
|
||||
logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
|
||||
logger.warning(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
|
||||
mapping[kernel_name] = None
|
||||
return None
|
||||
if _kernels_available:
|
||||
@ -244,9 +243,8 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, Optional[ModuleType]]
|
||||
|
||||
try:
|
||||
repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
|
||||
revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
|
||||
version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
|
||||
kernel = get_kernel(repo_id, revision=revision, version=version)
|
||||
kernel = get_kernel(repo_id, version=version)
|
||||
mapping[kernel_name] = kernel
|
||||
except FileNotFoundError:
|
||||
mapping[kernel_name] = None
|
||||
|
src/transformers/kernels/__init__.py (new file, 0 lines)
src/transformers/kernels/falcon_mamba/__init__.py (new file, 15 lines)
@ -0,0 +1,15 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .selective_scan_with_ln_interface import mamba_inner_fn
|
@ -0,0 +1,525 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
try:
|
||||
import causal_conv1d_cuda
|
||||
except ImportError:
|
||||
causal_conv1d_cuda = None
|
||||
|
||||
import mamba_ssm
|
||||
import selective_scan_cuda
|
||||
|
||||
|
||||
# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
|
||||
if hasattr(mamba_ssm.ops.triton, "layernorm"):
|
||||
from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
|
||||
else:
|
||||
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
|
||||
|
||||
|
||||
class SelectiveScanFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
|
||||
):
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3:
|
||||
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
||||
ctx.squeeze_B = True
|
||||
if C.dim() == 3:
|
||||
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
||||
ctx.squeeze_C = True
|
||||
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
||||
ctx.delta_softplus = delta_softplus
|
||||
ctx.has_z = z is not None
|
||||
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
||||
if not ctx.has_z:
|
||||
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
else:
|
||||
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
||||
out_z = rest[0]
|
||||
return out_z if not return_last_state else (out_z, last_state)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
if not ctx.has_z:
|
||||
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
||||
z = None
|
||||
out = None
|
||||
else:
|
||||
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
||||
# backward of selective_scan_cuda with the backward of chunk).
|
||||
# Here we just pass in None and dz will be allocated in the C++ code.
|
||||
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
||||
u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
dout,
|
||||
x,
|
||||
out,
|
||||
None,
|
||||
ctx.delta_softplus,
|
||||
False, # option to recompute out_z, not used here
|
||||
)
|
||||
dz = rest[0] if ctx.has_z else None
|
||||
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
||||
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
||||
return (
|
||||
du,
|
||||
ddelta,
|
||||
dA,
|
||||
dB,
|
||||
dC,
|
||||
dD if D is not None else None,
|
||||
dz,
|
||||
ddelta_bias if delta_bias is not None else None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_forward(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps=1e-6,
|
||||
is_rms_norm=True,
|
||||
):
|
||||
# x (b l) d
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0]
|
||||
# y (b l) d
|
||||
return y
|
||||
|
||||
|
||||
def selective_scan_fn(
|
||||
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
|
||||
):
|
||||
"""if return_last_state is True, returns (out, last_state)
|
||||
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
||||
not considered in the backward pass.
|
||||
"""
|
||||
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
|
||||
|
||||
|
||||
def selective_scan_ref(
|
||||
u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
|
||||
):
|
||||
"""
|
||||
u: r(B D L)
|
||||
delta: r(B D L)
|
||||
A: c(D N) or r(D N)
|
||||
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
D: r(D)
|
||||
z: r(B D L)
|
||||
delta_bias: r(D), fp32
|
||||
|
||||
out: r(B D L)
|
||||
last_state (optional): r(B D dstate) or c(B D dstate)
|
||||
"""
|
||||
dtype_in = u.dtype
|
||||
u = u.float()
|
||||
delta = delta.float()
|
||||
if delta_bias is not None:
|
||||
delta = delta + delta_bias[..., None].float()
|
||||
if delta_softplus:
|
||||
delta = F.softplus(delta)
|
||||
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
||||
is_variable_B = B.dim() >= 3
|
||||
is_variable_C = C.dim() >= 3
|
||||
if A.is_complex():
|
||||
if is_variable_B:
|
||||
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
||||
if is_variable_C:
|
||||
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
||||
else:
|
||||
B = B.float()
|
||||
C = C.float()
|
||||
x = A.new_zeros((batch, dim, dstate))
|
||||
ys = []
|
||||
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
|
||||
if not is_variable_B:
|
||||
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
|
||||
else:
|
||||
if B.dim() == 3:
|
||||
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
|
||||
else:
|
||||
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
||||
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
|
||||
if is_variable_C and C.dim() == 4:
|
||||
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
||||
last_state = None
|
||||
for i in range(u.shape[2]):
|
||||
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
||||
if not is_variable_C:
|
||||
y = torch.einsum("bdn,dn->bd", x, C)
|
||||
else:
|
||||
if C.dim() == 3:
|
||||
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
|
||||
else:
|
||||
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
|
||||
if i == u.shape[2] - 1:
|
||||
last_state = x
|
||||
if y.is_complex():
|
||||
y = y.real * 2
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=2) # (batch dim L)
|
||||
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
||||
if z is not None:
|
||||
out = out * F.silu(z)
|
||||
out = out.to(dtype=dtype_in)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
|
||||
|
||||
class MambaInnerFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
A,
|
||||
B=None,
|
||||
C=None,
|
||||
D=None,
|
||||
delta_bias=None,
|
||||
B_proj_bias=None,
|
||||
C_proj_bias=None,
|
||||
delta_softplus=True,
|
||||
checkpoint_lvl=1,
|
||||
b_rms_weight=None,
|
||||
c_rms_weight=None,
|
||||
dt_rms_weight=None,
|
||||
b_c_dt_rms_eps=1e-6,
|
||||
):
|
||||
"""
|
||||
xz: (batch, dim, seqlen)
|
||||
"""
|
||||
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
||||
assert checkpoint_lvl in [0, 1]
|
||||
L = xz.shape[-1]
|
||||
delta_rank = delta_proj_weight.shape[1]
|
||||
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
||||
if torch.is_autocast_enabled():
|
||||
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
||||
out_proj_bias = (
|
||||
out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None
|
||||
)
|
||||
if xz.stride(-1) != 1:
|
||||
xz = xz.contiguous()
|
||||
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
||||
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
|
||||
# We're being very careful here about the layout, to avoid extra transposes.
|
||||
# We want delta to have d as the slowest moving dimension
|
||||
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||
x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
|
||||
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
|
||||
ctx.is_variable_B = B is None
|
||||
ctx.is_variable_C = C is None
|
||||
ctx.B_proj_bias_is_None = B_proj_bias is None
|
||||
ctx.C_proj_bias_is_None = C_proj_bias is None
|
||||
if B is None: # variable B
|
||||
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
|
||||
if B_proj_bias is not None:
|
||||
B = B + B_proj_bias.to(dtype=B.dtype)
|
||||
if not A.is_complex():
|
||||
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
else:
|
||||
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
||||
else:
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C is None: # variable C
|
||||
C = x_dbl[:, -d_state:] # (bl dstate)
|
||||
if C_proj_bias is not None:
|
||||
C = C + C_proj_bias.to(dtype=C.dtype)
|
||||
if not A.is_complex():
|
||||
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
||||
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
else:
|
||||
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
|
||||
else:
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
|
||||
if b_rms_weight is not None:
|
||||
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
||||
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
if c_rms_weight is not None:
|
||||
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
||||
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
if dt_rms_weight is not None:
|
||||
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
|
||||
delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
||||
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
|
||||
|
||||
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
|
||||
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
||||
)
|
||||
ctx.delta_softplus = delta_softplus
|
||||
ctx.out_proj_bias_is_None = out_proj_bias is None
|
||||
ctx.checkpoint_lvl = checkpoint_lvl
|
||||
ctx.b_rms_weight = b_rms_weight
|
||||
ctx.c_rms_weight = c_rms_weight
|
||||
ctx.dt_rms_weight = dt_rms_weight
|
||||
ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
|
||||
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
|
||||
conv1d_out, delta = None, None
|
||||
ctx.save_for_backward(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_dbl,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
conv1d_out,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
delta_bias,
|
||||
scan_intermediates,
|
||||
b_rms_weight,
|
||||
c_rms_weight,
|
||||
dt_rms_weight,
|
||||
out,
|
||||
)
|
||||
return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, dout):
|
||||
# dout: (batch, seqlen, dim)
|
||||
assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
||||
(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_dbl,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
conv1d_out,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
delta_bias,
|
||||
scan_intermediates,
|
||||
b_rms_weight,
|
||||
c_rms_weight,
|
||||
dt_rms_weight,
|
||||
out,
|
||||
) = ctx.saved_tensors
|
||||
L = xz.shape[-1]
|
||||
delta_rank = delta_proj_weight.shape[1]
|
||||
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
if dout.stride(-1) != 1:
|
||||
dout = dout.contiguous()
|
||||
if ctx.checkpoint_lvl == 1:
|
||||
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
|
||||
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
|
||||
if dt_rms_weight is not None:
|
||||
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
|
||||
delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
|
||||
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
|
||||
if b_rms_weight is not None:
|
||||
# Recompute & RMSNorm B
|
||||
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
|
||||
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
if c_rms_weight is not None:
|
||||
# Recompute & RMSNorm C
|
||||
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
||||
C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
|
||||
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
||||
|
||||
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
||||
# backward of selective_scan_cuda with the backward of chunk).
|
||||
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
||||
dx, dz = dxz.chunk(2, dim=1)
|
||||
dout = rearrange(dout, "b l e -> e (b l)")
|
||||
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
||||
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
|
||||
conv1d_out,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
dout_y,
|
||||
scan_intermediates,
|
||||
out,
|
||||
dz,
|
||||
ctx.delta_softplus,
|
||||
True, # option to recompute out_z
|
||||
)
|
||||
dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
|
||||
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
||||
dD = dD if D is not None else None
|
||||
dx_dbl = torch.empty_like(x_dbl)
|
||||
dB_proj_bias = None
|
||||
if ctx.is_variable_B:
|
||||
if not A.is_complex():
|
||||
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
||||
else:
|
||||
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
||||
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
||||
dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
|
||||
dB = None
|
||||
dC_proj_bias = None
|
||||
if ctx.is_variable_C:
|
||||
if not A.is_complex():
|
||||
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
||||
else:
|
||||
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
|
||||
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
||||
dx_dbl[:, -d_state:] = dC # (bl d)
|
||||
dC = None
|
||||
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
||||
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
||||
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
||||
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
||||
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
|
||||
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
|
||||
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
|
||||
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
||||
# backward of conv1d with the backward of chunk).
|
||||
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
||||
x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
|
||||
)
|
||||
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
||||
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
||||
return (
|
||||
dxz,
|
||||
dconv1d_weight,
|
||||
dconv1d_bias,
|
||||
dx_proj_weight,
|
||||
ddelta_proj_weight,
|
||||
dout_proj_weight,
|
||||
dout_proj_bias,
|
||||
dA,
|
||||
dB,
|
||||
dC,
|
||||
dD,
|
||||
ddelta_bias if delta_bias is not None else None,
|
||||
# 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
|
||||
dB_proj_bias,
|
||||
dC_proj_bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def mamba_inner_fn(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
A,
|
||||
B=None,
|
||||
C=None,
|
||||
D=None,
|
||||
delta_bias=None,
|
||||
B_proj_bias=None,
|
||||
C_proj_bias=None,
|
||||
delta_softplus=True,
|
||||
checkpoint_lvl=1,
|
||||
b_rms_weight=None,
|
||||
c_rms_weight=None,
|
||||
dt_rms_weight=None,
|
||||
b_c_dt_rms_eps=1e-6,
|
||||
):
|
||||
return MambaInnerFn.apply(
|
||||
xz,
|
||||
conv1d_weight,
|
||||
conv1d_bias,
|
||||
x_proj_weight,
|
||||
delta_proj_weight,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
delta_bias,
|
||||
B_proj_bias,
|
||||
C_proj_bias,
|
||||
delta_softplus,
|
||||
checkpoint_lvl,
|
||||
b_rms_weight,
|
||||
c_rms_weight,
|
||||
dt_rms_weight,
|
||||
b_c_dt_rms_eps,
|
||||
)
|
@ -121,7 +121,7 @@ else:
|
||||
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
||||
("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
|
||||
("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
|
||||
("lightglue", ("LightGlueImageProcessor", "LightGlueImageProcessorFast")),
|
||||
("lightglue", ("LightGlueImageProcessor", None)),
|
||||
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
|
||||
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
|
||||
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
|
||||
|
@ -1318,7 +1318,7 @@ class BarkFineModel(BarkPreTrainedModel):
|
||||
output sound according to specific predefined voice.
|
||||
"""
|
||||
)
|
||||
class BarkModel(BarkPreTrainedModel, GenerationMixin):
|
||||
class BarkModel(BarkPreTrainedModel):
|
||||
config: BarkConfig
|
||||
|
||||
def __init__(self, config):
|
||||
|
@ -34,7 +34,10 @@ from ...integrations.hub_kernels import lazy_load_kernel
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils.import_utils import is_mambapy_available
|
||||
from ...utils.import_utils import (
|
||||
is_mamba_ssm_available,
|
||||
is_mambapy_available,
|
||||
)
|
||||
from .configuration_falcon_mamba import FalconMambaConfig
|
||||
|
||||
|
||||
@ -43,6 +46,14 @@ if is_mambapy_available():
|
||||
else:
|
||||
pscan = None
|
||||
|
||||
if is_mamba_ssm_available():
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||
|
||||
from ...kernels.falcon_mamba import mamba_inner_fn
|
||||
else:
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -235,12 +246,6 @@ class FalconMambaMixer(nn.Module):
|
||||
if causal_conv1d is not None
|
||||
else (None, None)
|
||||
)
|
||||
mamba_ssm = lazy_load_kernel("mamba-ssm")
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = (
|
||||
(mamba_ssm.selective_state_update, mamba_ssm.selective_scan_fn, mamba_ssm.mamba_inner_fn)
|
||||
if mamba_ssm is not None
|
||||
else (None, None, None)
|
||||
)
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
@ -272,12 +277,7 @@ class FalconMambaMixer(nn.Module):
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states).transpose(1, 2)
|
||||
mamba_ssm = lazy_load_kernel("mamba-ssm")
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = (
|
||||
mamba_ssm.selective_state_update,
|
||||
mamba_ssm.selective_scan_fn,
|
||||
mamba_ssm.mamba_inner_fn,
|
||||
)
|
||||
|
||||
if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
|
||||
contextualized_states = mamba_inner_fn(
|
||||
projected_states,
|
||||
@ -506,16 +506,6 @@ class FalconMambaMixer(nn.Module):
|
||||
if causal_conv1d is not None
|
||||
else (None, None)
|
||||
)
|
||||
mamba_ssm = lazy_load_kernel("mamba-ssm")
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = (
|
||||
(
|
||||
mamba_ssm.selective_state_update,
|
||||
mamba_ssm.selective_scan_fn,
|
||||
mamba_ssm.mamba_inner_fn,
|
||||
)
|
||||
if mamba_ssm is not None
|
||||
else (None, None, None)
|
||||
)
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
|
@ -22,6 +22,7 @@ from torch import nn
|
||||
from ...integrations.hub_kernels import lazy_load_kernel
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.import_utils import (
|
||||
is_mamba_ssm_available,
|
||||
is_mambapy_available,
|
||||
)
|
||||
from ..mamba.configuration_mamba import MambaConfig
|
||||
@ -45,6 +46,14 @@ if is_mambapy_available():
|
||||
else:
|
||||
pscan = None
|
||||
|
||||
if is_mamba_ssm_available():
|
||||
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||
|
||||
from ...kernels.falcon_mamba import mamba_inner_fn
|
||||
else:
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
|
||||
|
||||
|
||||
class FalconMambaConfig(MambaConfig):
|
||||
"""
|
||||
@ -251,12 +260,6 @@ class FalconMambaMixer(MambaMixer):
|
||||
if causal_conv1d is not None
|
||||
else (None, None)
|
||||
)
|
||||
mamba_ssm = lazy_load_kernel("mamba-ssm")
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = (
|
||||
(mamba_ssm.selective_state_update, mamba_ssm.selective_scan_fn, mamba_ssm.mamba_inner_fn)
|
||||
if mamba_ssm is not None
|
||||
else (None, None, None)
|
||||
)
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
@ -299,12 +302,7 @@ class FalconMambaMixer(MambaMixer):
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states).transpose(1, 2)
|
||||
mamba_ssm = lazy_load_kernel("mamba-ssm")
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = (
|
||||
mamba_ssm.selective_state_update,
|
||||
mamba_ssm.selective_scan_fn,
|
||||
mamba_ssm.mamba_inner_fn,
|
||||
)
|
||||
|
||||
if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
|
||||
contextualized_states = mamba_inner_fn(
|
||||
projected_states,
|
||||
@ -532,16 +530,6 @@ class FalconMambaMixer(MambaMixer):
|
||||
if causal_conv1d is not None
|
||||
else (None, None)
|
||||
)
|
||||
mamba_ssm = lazy_load_kernel("mamba-ssm")
|
||||
selective_state_update, selective_scan_fn, mamba_inner_fn = (
|
||||
(
|
||||
mamba_ssm.selective_state_update,
|
||||
mamba_ssm.selective_scan_fn,
|
||||
mamba_ssm.mamba_inner_fn,
|
||||
)
|
||||
if mamba_ssm is not None
|
||||
else (None, None, None)
|
||||
)
|
||||
is_fast_path_available = all(
|
||||
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
|
||||
)
|
||||
|
@ -447,7 +447,7 @@ def convert_transformer_weights(
|
||||
return zip([], [])
|
||||
else:
|
||||
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
|
||||
elif f"{_TRANSFORMER_EMBEDDER}/mm_" in path:
|
||||
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
|
||||
if not _INCLUDE_VISION_ENCODER.value:
|
||||
return zip([], [])
|
||||
|
||||
@ -553,7 +553,7 @@ def convert(
|
||||
continue
|
||||
|
||||
path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
|
||||
update_tree(f"model.{path}", weights, config.vision_config.dtype)
|
||||
update_tree(path, weights, config.vision_config.dtype)
|
||||
else:
|
||||
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
|
||||
if not _INCLUDE_VISION_ENCODER.value:
|
||||
|
@ -768,23 +768,15 @@ def token_type_ids_mask_function(
# If it's 1 for both query and key/value, we are in an image block
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)

token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)

token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)

image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)

image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)

is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx

# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & same_image_block

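The clamp-then-overwrite dance above is easy to misread, so here is a tiny self-contained illustration of the same `torch.where` trick; the tensor values are made up for the example:

```python
import torch

# Indices beyond the token_type_ids length are clamped to 0 for the gather, then their
# gathered values are overwritten afterwards, so no data-dependent `if` is needed under vmap.
token_type_ids = torch.tensor([[0, 1, 1, 0]])           # (batch, seq_len)
batch_idx = 0
kv_idx = torch.tensor([1, 3, 7])                        # 7 points past seq_len (static cache slot)

seq_len = token_type_ids.shape[1]
safe_kv_idx = torch.where(kv_idx < seq_len, kv_idx, 0)  # [1, 3, 0]
tt_at_kv = token_type_ids[batch_idx, safe_kv_idx]       # gather at safe positions
tt_at_kv = torch.where(kv_idx < seq_len, tt_at_kv, 0)   # out-of-range -> "not an image token"
print(tt_at_kv)  # tensor([1, 0, 0])
```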
@ -20,7 +20,6 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_lightglue import *
from .image_processing_lightglue import *
from .image_processing_lightglue_fast import *
from .modeling_lightglue import *
else:
import sys

@ -17,6 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Union

import numpy as np
@ -39,28 +40,20 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import TensorType, logging, requires_backends
from ...utils import TensorType, is_matplotlib_available, logging, requires_backends
from ...utils.import_utils import requires
from .modeling_lightglue import LightGlueKeypointMatchingOutput


if is_vision_available():
import PIL
from PIL import Image, ImageDraw

if is_vision_available():
import PIL

logger = logging.get_logger(__name__)


class LightGlueImageProcessorKwargs(ImagesKwargs, total=False):
r"""
do_grayscale (`bool`, *optional*, defaults to `True`):
Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
"""

do_grayscale: bool


def is_grayscale(
image: np.ndarray,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
@ -468,5 +461,60 @@ class LightGlueImageProcessor(BaseImageProcessor):
b = 0
return (r, g, b)

def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed.

.. deprecated::
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.

Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255.
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model.
"""
warnings.warn(
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
"Use `visualize_keypoint_matching` instead.",
FutureWarning,
)

if is_matplotlib_available():
import matplotlib.pyplot as plt
else:
raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")

images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]

for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
plot_image[:height0, :width0] = image_pair[0] / 255.0
plot_image[:height1, width0:] = image_pair[1] / 255.0
plt.imshow(plot_image)
plt.axis("off")

keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
plt.plot(
[keypoint0_x, keypoint1_x + width0],
[keypoint0_y, keypoint1_y],
color=plt.get_cmap("RdYlGn")(matching_score.item()),
alpha=0.9,
linewidth=0.5,
)
plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
plt.show()


__all__ = ["LightGlueImageProcessor"]

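For context on how this processor is meant to be driven end to end, here is a hedged usage sketch of the post-processing and visualization methods defined above. The checkpoint id, image paths, and threshold are assumptions for illustration; the processor methods and `AutoModelForKeypointMatching` are the names appearing in this diff:

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForKeypointMatching

checkpoint = "ETH-CVG/lightglue_superpoint"  # assumed repo id
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForKeypointMatching.from_pretrained(checkpoint)

image0 = Image.open("scene_view0.jpg").convert("RGB")  # placeholder paths
image1 = Image.open("scene_view1.jpg").convert("RGB")

inputs = processor([[image0, image1]], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Map keypoints back to absolute pixel coordinates and keep confident matches only.
target_sizes = [((image0.height, image0.width), (image1.height, image1.width))]
matches = processor.post_process_keypoint_matching(outputs, target_sizes=target_sizes, threshold=0.2)

# `visualize_keypoint_matching` (the non-deprecated path) returns side-by-side PIL images.
side_by_side = processor.visualize_keypoint_matching([[image0, image1]], matches)
side_by_side[0].save("matches.png")
```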
@ -1,302 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_lightglue.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import BaseImageProcessorFast
|
||||
from ...image_transforms import group_images_by_shape, reorder_images
|
||||
from ...image_utils import (
|
||||
ImageInput,
|
||||
ImageType,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_type,
|
||||
is_pil_image,
|
||||
is_valid_image,
|
||||
is_vision_available,
|
||||
to_numpy_array,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TensorType, auto_docstring
|
||||
from .image_processing_lightglue import LightGlueImageProcessorKwargs
|
||||
from .modeling_lightglue import LightGlueKeypointMatchingOutput
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def _is_valid_image(image):
|
||||
return is_pil_image(image) or (
|
||||
is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
|
||||
)
|
||||
|
||||
|
||||
def flatten_pair_images(images):
|
||||
# Handle the pair validation and flattening similar to slow processor
|
||||
if isinstance(images, list):
|
||||
if len(images) == 2 and all((_is_valid_image(image) or isinstance(image, torch.Tensor)) for image in images):
|
||||
# Single pair of images - keep as is, they'll be processed by the base class
|
||||
return images
|
||||
elif all(
|
||||
isinstance(image_pair, list)
|
||||
and len(image_pair) == 2
|
||||
and all(_is_valid_image(image) or isinstance(image, torch.Tensor) for image in image_pair)
|
||||
for image_pair in images
|
||||
):
|
||||
# Multiple pairs - flatten them
|
||||
images = [image for image_pair in images for image in image_pair]
|
||||
return images
|
||||
raise ValueError(
|
||||
"Input images must be a one of the following :",
|
||||
" - A pair of PIL images.",
|
||||
" - A pair of 3D arrays.",
|
||||
" - A list of pairs of PIL images.",
|
||||
" - A list of pairs of 3D arrays.",
|
||||
)
|
||||
|
||||
|
||||
def is_grayscale(
|
||||
image: "torch.Tensor",
|
||||
):
|
||||
"""Checks if an image is grayscale (all RGB channels are identical)."""
|
||||
if image.ndim < 3 or image.shape[0 if image.ndim == 3 else 1] == 1:
|
||||
return True
|
||||
return torch.all(image[..., 0, :, :] == image[..., 1, :, :]) and torch.all(
|
||||
image[..., 1, :, :] == image[..., 2, :, :]
|
||||
)
|
||||
|
||||
|
||||
def convert_to_grayscale(
|
||||
image: "torch.Tensor",
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Converts an image to grayscale format using the NTSC formula. Only support torch.Tensor.
|
||||
|
||||
This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
|
||||
channel, because of an issue that is discussed in :
|
||||
https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
|
||||
|
||||
Args:
|
||||
image (torch.Tensor):
|
||||
The image to convert.
|
||||
"""
|
||||
if is_grayscale(image):
|
||||
return image
|
||||
return F.rgb_to_grayscale(image, num_output_channels=3)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class LightGlueImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BILINEAR
|
||||
size = {"height": 480, "width": 640}
|
||||
default_to_square = False
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
rescale_factor = 1 / 255
|
||||
do_normalize = None
|
||||
valid_kwargs = LightGlueImageProcessorKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[LightGlueImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[LightGlueImageProcessorKwargs]) -> BatchFeature:
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs,
|
||||
) -> ImageInput:
|
||||
# we need to handle image pairs validation and flattening
|
||||
return flatten_pair_images(images)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list["torch.Tensor"],
|
||||
size: Union[dict[str, int], SizeDict],
|
||||
rescale_factor: float,
|
||||
do_rescale: bool,
|
||||
do_resize: bool,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_grayscale: bool,
|
||||
disable_grouping: bool,
|
||||
return_tensors: Union[str, TensorType],
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
resized_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_rescale:
|
||||
stacked_images = self.rescale(stacked_images, rescale_factor)
|
||||
if do_grayscale:
|
||||
stacked_images = convert_to_grayscale(stacked_images)
|
||||
processed_images_grouped[shape] = stacked_images
|
||||
|
||||
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
||||
|
||||
# Convert back to pairs format
|
||||
image_pairs = [processed_images[i : i + 2] for i in range(0, len(processed_images), 2)]
|
||||
|
||||
# Stack each pair into a single tensor to match slow processor format
|
||||
stacked_pairs = [torch.stack(pair, dim=0) for pair in image_pairs]
|
||||
|
||||
# Return in same format as slow processor
|
||||
image_pairs = torch.stack(stacked_pairs, dim=0) if return_tensors else stacked_pairs
|
||||
|
||||
return BatchFeature(data={"pixel_values": image_pairs})
|
||||
|
||||
def post_process_keypoint_matching(
|
||||
self,
|
||||
outputs: LightGlueKeypointMatchingOutput,
|
||||
target_sizes: Union[TensorType, list[tuple]],
|
||||
threshold: float = 0.0,
|
||||
) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors
|
||||
with coordinates absolute to the original image sizes.
|
||||
Args:
|
||||
outputs ([`KeypointMatchingOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
|
||||
target size `(height, width)` of each image in the batch. This must be the original image size (before
|
||||
any processing).
|
||||
threshold (`float`, *optional*, defaults to 0.0):
|
||||
Threshold to filter out the matches with low scores.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
|
||||
of the pair, the matching scores and the matching indices.
|
||||
"""
|
||||
if outputs.matches.shape[0] != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
|
||||
if not all(len(target_size) == 2 for target_size in target_sizes):
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
if isinstance(target_sizes, list):
|
||||
image_pair_sizes = torch.tensor(target_sizes, device=outputs.matches.device)
|
||||
else:
|
||||
if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
|
||||
raise ValueError(
|
||||
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
|
||||
)
|
||||
image_pair_sizes = target_sizes
|
||||
|
||||
keypoints = outputs.keypoints.clone()
|
||||
keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
|
||||
keypoints = keypoints.to(torch.int32)
|
||||
|
||||
results = []
|
||||
for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
|
||||
# Filter out matches with low scores
|
||||
valid_matches = torch.logical_and(scores > threshold, matches > -1)
|
||||
|
||||
matched_keypoints0 = keypoints_pair[0][valid_matches[0]]
|
||||
matched_keypoints1 = keypoints_pair[1][valid_matches[1]]
|
||||
matching_scores = scores[0][valid_matches[0]]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"keypoints0": matched_keypoints0,
|
||||
"keypoints1": matched_keypoints1,
|
||||
"matching_scores": matching_scores,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def visualize_keypoint_matching(
|
||||
self,
|
||||
images,
|
||||
keypoint_matching_output: list[dict[str, torch.Tensor]],
|
||||
) -> list["Image.Image"]:
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
|
||||
|
||||
Args:
|
||||
images:
|
||||
Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2
|
||||
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
|
||||
A post processed keypoint matching output
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
|
||||
keypoints as well as the matching between them.
|
||||
"""
|
||||
from .image_processing_lightglue import validate_and_format_image_pairs
|
||||
|
||||
images = validate_and_format_image_pairs(images)
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
|
||||
|
||||
results = []
|
||||
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
|
||||
height0, width0 = image_pair[0].shape[:2]
|
||||
height1, width1 = image_pair[1].shape[:2]
|
||||
plot_image = torch.zeros((max(height0, height1), width0 + width1, 3), dtype=torch.uint8)
|
||||
plot_image[:height0, :width0] = torch.from_numpy(image_pair[0])
|
||||
plot_image[:height1, width0:] = torch.from_numpy(image_pair[1])
|
||||
|
||||
plot_image_pil = Image.fromarray(plot_image.numpy())
|
||||
draw = ImageDraw.Draw(plot_image_pil)
|
||||
|
||||
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
|
||||
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
|
||||
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
||||
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
|
||||
):
|
||||
color = self._get_color(matching_score)
|
||||
draw.line(
|
||||
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
|
||||
fill=color,
|
||||
width=3,
|
||||
)
|
||||
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
|
||||
draw.ellipse(
|
||||
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
results.append(plot_image_pil)
|
||||
return results
|
||||
|
||||
def _get_color(self, score):
|
||||
"""Maps a score to a color."""
|
||||
r = int(255 * (1 - score))
|
||||
g = int(255 * score)
|
||||
b = 0
|
||||
return r, g, b
|
||||
|
||||
|
||||
__all__ = ["LightGlueImageProcessorFast"]
|
@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
@ -21,24 +22,25 @@ from torch import nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...image_utils import ImageInput, is_vision_available, to_numpy_array
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import ModelOutput, TensorType, auto_docstring, logging
|
||||
from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging
|
||||
from ...utils.generic import can_return_tuple
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
from ..auto.modeling_auto import AutoModelForKeypointDetection
|
||||
from ..clip.modeling_clip import CLIPMLP
|
||||
from ..cohere.modeling_cohere import apply_rotary_pos_emb
|
||||
from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
|
||||
from ..superglue.image_processing_superglue import (
|
||||
SuperGlueImageProcessor,
|
||||
SuperGlueImageProcessorKwargs,
|
||||
)
|
||||
from ..superglue.image_processing_superglue_fast import SuperGlueImageProcessorFast
|
||||
from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs
|
||||
from ..superpoint import SuperPointConfig
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -215,10 +217,6 @@ class LightGlueKeypointMatchingOutput(ModelOutput):
|
||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class LightGlueImageProcessorKwargs(SuperGlueImageProcessorKwargs):
|
||||
pass
|
||||
|
||||
|
||||
class LightGlueImageProcessor(SuperGlueImageProcessor):
|
||||
def post_process_keypoint_matching(
|
||||
self,
|
||||
@ -228,15 +226,123 @@ class LightGlueImageProcessor(SuperGlueImageProcessor):
|
||||
) -> list[dict[str, torch.Tensor]]:
|
||||
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
|
||||
|
||||
|
||||
class LightGlueImageProcessorFast(SuperGlueImageProcessorFast):
|
||||
def post_process_keypoint_matching(
|
||||
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
|
||||
def visualize_keypoint_matching(
|
||||
self,
|
||||
outputs: LightGlueKeypointMatchingOutput,
|
||||
target_sizes: Union[TensorType, list[tuple]],
|
||||
threshold: float = 0.0,
|
||||
) -> list[dict[str, torch.Tensor]]:
|
||||
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
|
||||
images: ImageInput,
|
||||
keypoint_matching_output: list[dict[str, torch.Tensor]],
|
||||
) -> list["Image.Image"]:
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
|
||||
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
|
||||
A post processed keypoint matching output
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
|
||||
keypoints as well as the matching between them.
|
||||
"""
|
||||
images = validate_and_format_image_pairs(images)
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
|
||||
|
||||
results = []
|
||||
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
|
||||
height0, width0 = image_pair[0].shape[:2]
|
||||
height1, width1 = image_pair[1].shape[:2]
|
||||
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
|
||||
plot_image[:height0, :width0] = image_pair[0]
|
||||
plot_image[:height1, width0:] = image_pair[1]
|
||||
|
||||
plot_image_pil = Image.fromarray(plot_image)
|
||||
draw = ImageDraw.Draw(plot_image_pil)
|
||||
|
||||
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
|
||||
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
|
||||
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
||||
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
|
||||
):
|
||||
color = self._get_color(matching_score)
|
||||
draw.line(
|
||||
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
|
||||
fill=color,
|
||||
width=3,
|
||||
)
|
||||
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
|
||||
draw.ellipse(
|
||||
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
results.append(plot_image_pil)
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
|
||||
def _get_color(self, score):
|
||||
"""Maps a score to a color."""
|
||||
r = int(255 * (1 - score))
|
||||
g = int(255 * score)
|
||||
b = 0
|
||||
return (r, g, b)
|
||||
|
||||
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
|
||||
matplotlib to be installed.
|
||||
|
||||
.. deprecated::
|
||||
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
|
||||
a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
|
||||
Raw outputs of the model.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
|
||||
"Use `visualize_keypoint_matching` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
else:
|
||||
raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")
|
||||
|
||||
images = validate_and_format_image_pairs(images)
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
|
||||
|
||||
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
|
||||
height0, width0 = image_pair[0].shape[:2]
|
||||
height1, width1 = image_pair[1].shape[:2]
|
||||
plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
|
||||
plot_image[:height0, :width0] = image_pair[0] / 255.0
|
||||
plot_image[:height1, width0:] = image_pair[1] / 255.0
|
||||
plt.imshow(plot_image)
|
||||
plt.axis("off")
|
||||
|
||||
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
|
||||
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
|
||||
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
||||
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
|
||||
):
|
||||
plt.plot(
|
||||
[keypoint0_x, keypoint1_x + width0],
|
||||
[keypoint0_y, keypoint1_y],
|
||||
color=plt.get_cmap("RdYlGn")(matching_score.item()),
|
||||
alpha=0.9,
|
||||
linewidth=0.5,
|
||||
)
|
||||
plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
|
||||
plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
|
||||
plt.show()
|
||||
|
||||
|
||||
class LightGluePositionalEncoder(nn.Module):
|
||||
@ -975,10 +1081,4 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LightGluePreTrainedModel",
|
||||
"LightGlueForKeypointMatching",
|
||||
"LightGlueConfig",
|
||||
"LightGlueImageProcessor",
|
||||
"LightGlueImageProcessorFast",
|
||||
]
|
||||
__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"]
|
||||
|
@ -116,23 +116,15 @@ def token_type_ids_mask_function(
|
||||
# If it's 1 for both query and key/value, we are in an image block
|
||||
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
|
||||
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
|
||||
safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
|
||||
safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
||||
|
||||
token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
|
||||
token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)
|
||||
|
||||
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
|
||||
safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
|
||||
token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
|
||||
token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
|
||||
|
||||
image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
|
||||
image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)
|
||||
|
||||
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
|
||||
image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
|
||||
image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
|
||||
|
||||
is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
|
||||
same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx
|
||||
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
|
||||
same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
|
||||
|
||||
# This is bidirectional attention whenever we are dealing with image tokens
|
||||
return is_image_block & same_image_block
|
||||
|
@ -78,7 +78,7 @@ def _pad(items, key, padding_value, padding_side):
if isinstance(items[0][key], torch.Tensor):
# Others include `attention_mask` etc...
shape = items[0][key].shape
dim = items[0][key].ndim
dim = len(shape)
if dim == 1:
# We have a list of 1-dim torch tensors, which can be stacked without padding
return torch.cat([item[key] for item in items], dim=0)
@ -93,18 +93,37 @@ def _pad(items, key, padding_value, padding_side):
min_length = min(item[key].shape[1] for item in items)
dtype = items[0][key].dtype

if dim == 2 and max_length == min_length:
tensor = None
if dim == 2:
if max_length == min_length:
# Bypass for `ImageGPT` which doesn't provide a padding value, yet
# we can consistently pad since the size should be matching
return torch.cat([item[key] for item in items], dim=0)
else:
tensor = torch.full([batch_size, max_length] + list(shape[2:]), fill_value=padding_value, dtype=dtype)
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
elif dim == 3:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
elif dim == 4:
tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value

if tensor is None:
raise ValueError(f"Unable to create tensor for padding from {key} with dimension {dim}")

for i, item in enumerate(items):
if dim == 2:
if padding_side == "left":
tensor[i, -len(item[key][0]) :] = item[key][0]
tensor[i, -len(item[key][0]) :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0])] = item[key][0]
tensor[i, : len(item[key][0])] = item[key][0].clone()
elif dim == 3:
if padding_side == "left":
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
elif dim == 4:
if padding_side == "left":
tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()

return tensor
else:
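To make the per-dimension padding logic above concrete, here is a minimal sketch of what the `dim == 2` branch produces; the toy inputs and the `input_ids` key are made up for the example:

```python
import torch

# Rows shorter than max_length are left- or right-aligned into a padded batch tensor.
items = [{"input_ids": torch.tensor([[1, 2, 3]])}, {"input_ids": torch.tensor([[4, 5]])}]
padding_value, padding_side, key = 0, "left", "input_ids"

batch_size = len(items)
max_length = max(item[key].shape[1] for item in items)
tensor = torch.zeros((batch_size, max_length), dtype=items[0][key].dtype) + padding_value
for i, item in enumerate(items):
    if padding_side == "left":
        tensor[i, -len(item[key][0]) :] = item[key][0].clone()
    else:
        tensor[i, : len(item[key][0])] = item[key][0].clone()
print(tensor)  # tensor([[1, 2, 3], [0, 4, 5]])
```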
@ -152,8 +152,6 @@ class TextGenerationPipeline(Pipeline):
|
||||
continue_final_message=None,
|
||||
skip_special_tokens=None,
|
||||
tokenizer_encode_kwargs=None,
|
||||
tools=None,
|
||||
documents=None,
|
||||
**generate_kwargs,
|
||||
):
|
||||
# preprocess kwargs
|
||||
@ -172,11 +170,6 @@ class TextGenerationPipeline(Pipeline):
|
||||
preprocess_params["max_length"] = max_length
|
||||
generate_kwargs["max_length"] = max_length
|
||||
|
||||
if tools is not None:
|
||||
preprocess_params["tools"] = tools
|
||||
if documents is not None:
|
||||
preprocess_params["documents"] = documents
|
||||
|
||||
if prefix is not None:
|
||||
preprocess_params["prefix"] = prefix
|
||||
if prefix:
|
||||
@ -342,8 +335,6 @@ class TextGenerationPipeline(Pipeline):
|
||||
max_length=None,
|
||||
continue_final_message=None,
|
||||
tokenizer_encode_kwargs=None,
|
||||
tools=None,
|
||||
documents=None,
|
||||
**generate_kwargs,
|
||||
):
|
||||
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
||||
@ -368,8 +359,6 @@ class TextGenerationPipeline(Pipeline):
|
||||
continue_final_message=continue_final_message,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
tools=tools,
|
||||
documents=documents,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
else:
|
||||
@ -525,12 +514,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
]
|
||||
else:
|
||||
# When we're not starting from a prefill, the output is a new assistant message
|
||||
if self.tokenizer.response_schema:
|
||||
assistant_message = self.tokenizer.parse_response(all_text)
|
||||
else:
|
||||
# If there's no schema, then we have to assume it's all content
|
||||
assistant_message = {"role": "assistant", "content": all_text}
|
||||
all_text = list(prompt_text.messages) + [assistant_message]
|
||||
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
|
||||
record = {"generated_text": all_text}
|
||||
for key, values in split_keys.items():
|
||||
record[key] = values[idx]
|
||||
|
@ -772,6 +772,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
save_jinja_files = kwargs.pop("save_jinja_files", True)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
@ -794,7 +796,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
|
||||
# Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
|
||||
if attribute_name == "tokenizer":
|
||||
attribute.save_pretrained(save_directory)
|
||||
# Propagate save_jinja_files to tokenizer to ensure we don't get conflicts
|
||||
attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files)
|
||||
elif attribute._auto_class is not None:
|
||||
custom_object_save(attribute, save_directory, config=attribute)
|
||||
|
||||
@ -809,16 +812,19 @@ class ProcessorMixin(PushToHubMixin):
|
||||
# plus we save chat_template in its own file
|
||||
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
|
||||
output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE)
|
||||
output_chat_template_file_legacy = os.path.join(save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE)
|
||||
chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
|
||||
|
||||
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
|
||||
# to avoid serializing chat template in json config file. So let's get it from `self` directly
|
||||
if isinstance(self.chat_template, str):
|
||||
if self.chat_template is not None:
|
||||
is_single_template = isinstance(self.chat_template, str)
|
||||
if save_jinja_files and is_single_template:
|
||||
# New format for single templates is to save them as chat_template.jinja
|
||||
with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
|
||||
f.write(self.chat_template)
|
||||
logger.info(f"chat template saved in {output_chat_template_file_jinja}")
|
||||
elif isinstance(self.chat_template, dict):
|
||||
elif save_jinja_files and not is_single_template:
|
||||
# New format for multiple templates is to save the default as chat_template.jinja
|
||||
# and the other templates in the chat_templates/ directory
|
||||
for template_name, template in self.chat_template.items():
|
||||
@ -832,6 +838,21 @@ class ProcessorMixin(PushToHubMixin):
|
||||
with open(template_filepath, "w", encoding="utf-8") as f:
|
||||
f.write(template)
|
||||
logger.info(f"chat template saved in {template_filepath}")
|
||||
elif is_single_template:
|
||||
# Legacy format for single templates: Put them in chat_template.json
|
||||
chat_template_json_string = (
|
||||
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
|
||||
)
|
||||
with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer:
|
||||
writer.write(chat_template_json_string)
|
||||
logger.info(f"chat template saved in {output_chat_template_file_legacy}")
|
||||
elif self.chat_template is not None:
|
||||
# At this point we have multiple templates in the legacy format, which is not supported
|
||||
# chat template dicts are saved to chat_template.json as lists of dicts with fixed key names.
|
||||
raise ValueError(
|
||||
"Multiple chat templates are not supported in the legacy format. Please save them as "
|
||||
"separate files using the `save_jinja_files` argument."
|
||||
)
|
||||
|
||||
# Create a unified `preprocessor_config.json` and save all attributes as a composite config, except for tokenizers
|
||||
self.to_json_file(output_processor_file)
|
||||
@ -1045,6 +1066,9 @@ class ProcessorMixin(PushToHubMixin):
|
||||
if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1:
|
||||
chat_templates = chat_templates["default"] # Flatten when we just have a single template/file
|
||||
|
||||
if chat_templates:
|
||||
kwargs["chat_template"] = chat_templates
|
||||
|
||||
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
|
||||
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
|
||||
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
|
||||
@ -1069,13 +1093,14 @@ class ProcessorMixin(PushToHubMixin):
|
||||
else:
|
||||
logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}")
|
||||
|
||||
if processor_dict.get("chat_template") is not None:
|
||||
if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
|
||||
logger.warning_once(
|
||||
"Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' "
|
||||
"in the processor's config. Make sure to move your template to its own file."
|
||||
)
|
||||
elif chat_templates:
|
||||
processor_dict["chat_template"] = chat_templates
|
||||
|
||||
if "chat_template" in kwargs:
|
||||
processor_dict["chat_template"] = kwargs.pop("chat_template")
|
||||
|
||||
# Audio tokenizer needs to load the model checkpoint first, because the saved
|
||||
# json file contains only references to the model path and repo id
|
||||
|
@ -102,7 +102,6 @@ from .utils import (
|
||||
is_huggingface_hub_greater_or_equal,
|
||||
is_ipex_available,
|
||||
is_jinja_available,
|
||||
is_jmespath_available,
|
||||
is_jumanpp_available,
|
||||
is_kernels_available,
|
||||
is_levenshtein_available,
|
||||
@ -509,13 +508,6 @@ def require_jinja(test_case):
|
||||
return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
|
||||
|
||||
|
||||
def require_jmespath(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires jmespath. These tests are skipped when jmespath isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_jmespath_available(), "test requires jmespath")(test_case)
|
||||
|
||||
|
||||
def require_onnx(test_case):
|
||||
return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
|
||||
|
||||
|
@ -1864,6 +1864,8 @@ class MistralCommonTokenizer(PushToHubMixin):
|
||||
Returns:
|
||||
A tuple of `str`: The files saved.
|
||||
"""
|
||||
# `save_jinja_files`` must be skipped to be able to save from a processor
|
||||
kwargs.pop("save_jinja_files", None)
|
||||
if kwargs:
|
||||
raise ValueError(
|
||||
f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.save_pretrained`."
|
||||
|
@ -61,7 +61,6 @@ from .utils import (
|
||||
requires_backends,
|
||||
to_py_obj,
|
||||
)
|
||||
from .utils.chat_parsing_utils import recursive_parse
|
||||
from .utils.chat_template_utils import render_jinja_template
|
||||
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
|
||||
|
||||
@ -1430,8 +1429,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
# we reconstruct that into a single dict while loading them.
|
||||
self.chat_template = {template["name"]: template["template"] for template in self.chat_template}
|
||||
|
||||
self.response_schema = kwargs.pop("response_schema", None)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.extra_special_tokens = kwargs.pop("extra_special_tokens", {})
|
||||
@ -1858,28 +1855,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
|
||||
return chat_template
|
||||
|
||||
def parse_response(self, response: str, schema: Optional[Union[list, dict]] = None):
|
||||
"""
|
||||
Converts an output string created by generating text from a model into a parsed message dictionary.
|
||||
This method is intended for use with chat models, and will read the tokenizer's `response_schema` attribute to
|
||||
control parsing, although this can be overridden by passing a `response_schema` argument directly.
|
||||
|
||||
For more information, see the
|
||||
[response parsing](https://huggingface.co/docs/transformers/main/en/chat_response_parsing) documentation.
|
||||
|
||||
Args:
|
||||
response (`str`):
|
||||
The output string generated by the model. This should be the decoded string, not raw tokens.
|
||||
schema (`Union[list, dict]`, *optional*):
|
||||
A response schema that indicates the expected output format and how parsing should be performed.
|
||||
If not provided, the tokenizer's `response_schema` attribute will be used.
|
||||
"""
|
||||
if schema is None:
|
||||
if getattr(self, "response_schema", None) is None:
|
||||
raise AttributeError("This tokenizer does not have a `response_schema` for parsing chat responses!")
|
||||
schema = self.response_schema
|
||||
return recursive_parse(response, schema)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@ -2431,6 +2406,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
save_directory: Union[str, os.PathLike],
|
||||
tokenizer_config: dict,
|
||||
filename_prefix: Optional[str],
|
||||
save_jinja_files: bool,
|
||||
):
|
||||
"""
|
||||
Writes chat templates out to the save directory if we're using the new format, and removes them from
|
||||
@ -2445,7 +2421,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
saved_raw_chat_template_files = []
|
||||
if isinstance(self.chat_template, str):
|
||||
if save_jinja_files and isinstance(self.chat_template, str):
|
||||
# New format for single templates is to save them as chat_template.jinja
|
||||
with open(chat_template_file, "w", encoding="utf-8") as f:
|
||||
f.write(self.chat_template)
|
||||
@ -2453,7 +2429,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
saved_raw_chat_template_files.append(chat_template_file)
|
||||
if "chat_template" in tokenizer_config:
|
||||
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
|
||||
elif isinstance(self.chat_template, dict):
|
||||
elif save_jinja_files and isinstance(self.chat_template, dict):
|
||||
# New format for multiple templates is to save the default as chat_template.jinja
|
||||
# and the other templates in the chat_templates/ directory
|
||||
for template_name, template in self.chat_template.items():
|
||||
@ -2471,6 +2447,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
saved_raw_chat_template_files.append(template_filepath)
|
||||
if "chat_template" in tokenizer_config:
|
||||
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
|
||||
elif isinstance(self.chat_template, dict):
|
||||
# Legacy format for multiple templates:
|
||||
# chat template dicts are saved to the config as lists of dicts with fixed key names.
|
||||
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
|
||||
elif self.chat_template is not None:
|
||||
# Legacy format for single templates: Just make them a key in tokenizer_config.json
|
||||
tokenizer_config["chat_template"] = self.chat_template
|
||||
return tokenizer_config, saved_raw_chat_template_files
|
||||
|
||||
def save_pretrained(
|
||||
@ -2516,6 +2499,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
Returns:
|
||||
A tuple of `str`: The files saved.
|
||||
"""
|
||||
save_jinja_files = kwargs.pop("save_jinja_files", True)
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@ -2553,10 +2538,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
tokenizer_config.update(self.extra_special_tokens)
|
||||
|
||||
tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates(
|
||||
save_directory, tokenizer_config, filename_prefix
|
||||
save_directory, tokenizer_config, filename_prefix, save_jinja_files
|
||||
)
|
||||
if getattr(self, "response_schema", None) is not None:
|
||||
tokenizer_config["response_schema"] = self.response_schema
|
||||
|
||||
if len(self.init_inputs) > 0:
|
||||
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
|
||||
|
@ -160,7 +160,6 @@ from .import_utils import (
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_jinja_available,
|
||||
is_jmespath_available,
|
||||
is_jumanpp_available,
|
||||
is_kenlm_available,
|
||||
is_kernels_available,
|
||||
|
@ -1,236 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from transformers.utils import is_jmespath_available
|
||||
|
||||
|
||||
if is_jmespath_available():
|
||||
import jmespath
|
||||
else:
|
||||
jmespath = None
|
||||
|
||||
|
||||
def _parse_re_match(node_match: re.Match) -> dict | str:
|
||||
# If the regex has named groups, return a dict of those groups
|
||||
if node_match.groupdict():
|
||||
return {key: val for key, val in node_match.groupdict().items() if val is not None}
|
||||
# Otherwise the regex must have exactly one unnamed group, and we return that
|
||||
else:
|
||||
groups = list(node_match.groups())
|
||||
if len(groups) > 1:
|
||||
raise ValueError(f"Regex has multiple unnamed groups!\nGroups: {groups}\n")
|
||||
elif len(groups) == 0:
|
||||
raise ValueError(f"Regex has no capture groups:\n\n{node_match.group(0)}")
|
||||
return groups[0]
|
||||
|
||||
|
||||
def recursive_parse(
|
||||
node_content: str | list | dict,
|
||||
node_schema: dict,
|
||||
):
|
||||
"""
|
||||
This function takes content and a JSON schema which includes
|
||||
regex extractors, and recursively parses the content. The output
|
||||
should be a data structure matching the schema.
|
||||
|
||||
Args:
|
||||
node_content: The content corresponding to this node. Usually a string, but can be something else
|
||||
if the parent node has multiple capture groups or named groups. In that case,
|
||||
we generally pass the capture groups straight through to the children of this node
|
||||
and don't do any parsing at this level.
|
||||
node_schema: The schema node controlling the parsing.
|
||||
|
||||
Returns:
|
||||
The parsed data structure for the current node.
|
||||
"""
|
||||
|
||||
# If the schema has a const, we just return that value and do absolutely nothing else
|
||||
if "const" in node_schema:
|
||||
return node_schema["const"]
|
||||
|
||||
# If the node content is None, we return None. EZ.
|
||||
if node_content is None:
|
||||
return None
|
||||
|
||||
# If not, we have to do a little parsing. First, set some vars and do basic validation
|
||||
node_type = node_schema["type"]
|
||||
has_regex = "x-regex" in node_schema or "x-regex-iterator" in node_schema or "x-regex-key-value" in node_schema
|
||||
if has_regex and not isinstance(node_content, str):
|
||||
raise TypeError(
|
||||
"Schema node got a non-string input, but has a regex for parsing.\n"
|
||||
f"Input: {node_content}\n"
|
||||
f"Schema: {node_schema}"
|
||||
)
|
||||
|
||||
node_regex = node_schema.get("x-regex")
|
||||
node_regex_iterator = node_schema.get("x-regex-iterator")
|
||||
node_regex_to_dict = node_schema.get("x-regex-key-value")
|
||||
if node_regex is not None:
|
||||
node_match = re.search(node_regex, node_content, flags=re.DOTALL)
|
||||
if not node_match:
|
||||
return None
|
||||
node_content = _parse_re_match(node_match)
|
||||
if node_regex_iterator is not None:
|
||||
if node_type != "array":
|
||||
raise TypeError(f"Schema node with type {node_type} cannot use x-regex-iterator.\nSchema: {node_schema}")
|
||||
# Note that this can be applied after a standard node-regex search
|
||||
node_content = [
|
||||
_parse_re_match(node_match)
|
||||
for node_match in re.finditer(node_regex_iterator, node_content, flags=re.DOTALL)
|
||||
]
|
||||
if not node_content:
|
||||
return None
|
||||
if node_regex_to_dict is not None:
|
||||
if node_type != "object":
|
||||
raise TypeError(f"Schema node with type {node_type} cannot use x-regex-key-value.\nSchema: {node_schema}")
|
||||
# Note that this can be applied after a standard node-regex search
|
||||
output_content = {}
|
||||
for node_match in re.finditer(node_regex_to_dict, node_content, flags=re.DOTALL):
|
||||
match_groups = _parse_re_match(node_match)
|
||||
if not isinstance(match_groups, dict) or "key" not in match_groups or "value" not in match_groups:
|
||||
raise ValueError(
|
||||
f"Regex for x-regex-key-value must have named groups 'key' and 'value'.\n"
|
||||
f"Match groups: {match_groups}\n"
|
||||
f"Schema: {node_schema}"
|
||||
)
|
||||
output_content[match_groups["key"]] = match_groups["value"]
|
||||
node_content = output_content
|
||||
if not node_content:
|
||||
return None
|
||||
|
||||
# Next, if the node has a parser, apply it. We do this after regexes so that the regex can extract
|
||||
# a substring to parse, if needed.
|
||||
if "x-parser" in node_schema:
|
||||
parser = node_schema["x-parser"]
|
||||
if parser == "json":
|
||||
if not isinstance(node_content, str):
|
||||
raise TypeError(
|
||||
f"Node has JSON parser but got non-string input: {node_content}\nSchema: {node_schema}"
|
||||
)
|
||||
parser_args = node_schema.get("x-parser-args", {})
|
||||
transform = parser_args.get("transform")
|
||||
allow_non_json = parser_args.get("allow_non_json", False)
|
||||
try:
|
||||
parsed_json = json.loads(node_content)
|
||||
except json.JSONDecodeError as e:
|
||||
if allow_non_json:
|
||||
parsed_json = node_content
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Node has JSON parser but could not parse its contents as JSON. You can use the `allow_non_json` parser arg for nodes which may contain JSON or string content.\n\nContent: {node_content}\n\nError: {e}"
|
||||
)
|
||||
if transform is not None:
|
||||
if jmespath is None:
|
||||
raise ImportError(
|
||||
"Chat response schema includes a jmespath transformation, but jmespath is not installed. You can install it with `pip install jmespath`."
|
||||
)
|
||||
parsed_json = jmespath.search(parser_args["transform"], parsed_json)
|
||||
node_content = parsed_json
|
||||
else:
|
||||
raise ValueError(f"Unknown parser {parser} for schema node: {node_schema}")
|
||||
|
||||
# If there's a mapping, apply it now
|
||||
if "x-mapping" in node_schema:
|
||||
if not isinstance(node_content, str):
|
||||
raise TypeError(
|
||||
f"Schema node with type {node_type} cannot use x-mapping on non-string content.\n"
|
||||
f"Content: {node_content}\n"
|
||||
f"Schema: {node_schema}"
|
||||
)
|
||||
mapping = node_schema["x-mapping"]
|
||||
if node_content in mapping:
|
||||
node_content = mapping[node_content]
|
||||
|
||||
if "x-mapping-regex" in node_schema:
|
||||
if not isinstance(node_content, str):
|
||||
raise TypeError(
|
||||
f"Schema node with type {node_type} cannot use x-mapping-regex on non-string content.\n"
|
||||
f"Content: {node_content}\n"
|
||||
f"Schema: {node_schema}"
|
||||
)
|
||||
mapping_regex = node_schema["x-mapping-regex"]
|
||||
for pattern, replacement in mapping_regex.items():
|
||||
node_content = re.sub(pattern, replacement, node_content, flags=re.DOTALL)
|
||||
|
||||
# Finally, handle parsed content based on schema type and recurse if required
|
||||
if node_type == "object":
|
||||
parsed_schema = {}
|
||||
if isinstance(node_content, str):
|
||||
# This means we don't have a regex at this level, so all of our child nodes need to parse the whole
|
||||
# string themselves to extract their value.
|
||||
if "properties" not in node_schema:
|
||||
raise ValueError(
|
||||
f"Object node received string content but has no regex or parser to handle it.\n"
|
||||
f"Content: {node_content}\n"
|
||||
f"Schema: {node_schema}"
|
||||
)
|
||||
for key, child_node in node_schema["properties"].items():
|
||||
child_node_content = recursive_parse(node_content, node_schema["properties"][key])
|
||||
if child_node_content is not None:
|
||||
parsed_schema[key] = child_node_content
|
||||
return parsed_schema
|
||||
elif isinstance(node_content, dict):
|
||||
for key, child_node in node_schema.get("properties", {}).items():
|
||||
if key in node_content:
|
||||
parsed_schema[key] = recursive_parse(node_content[key], child_node)
|
||||
elif "default" in child_node:
|
||||
parsed_schema[key] = child_node["default"]
|
||||
else:
|
||||
pass
|
||||
if "additionalProperties" in node_schema:
|
||||
for key, value in node_content.items():
|
||||
if key not in node_schema.get("properties", {}):
|
||||
parsed_schema[key] = recursive_parse(value, node_schema["additionalProperties"])
|
||||
return parsed_schema
|
||||
else:
|
||||
raise TypeError(f"Expected a dict or str for schema node with type object, got {node_content}")
|
||||
elif node_type == "array":
|
||||
if not node_content:
|
||||
return []
|
||||
parsed_schema = []
|
||||
if "items" in node_schema:
|
||||
if not isinstance(node_content, list):
|
||||
raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
|
||||
for item in node_content:
|
||||
parsed_schema.append(recursive_parse(item, node_schema["items"]))
|
||||
return parsed_schema
|
||||
elif "prefixItems" in node_schema:
|
||||
if not isinstance(node_content, list):
|
||||
if len(node_schema["prefixItems"]) == 1:
|
||||
# If there's only one prefix item, this is a single item array, we can just wrap the string
|
||||
node_content = [node_content]
|
||||
else:
|
||||
raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
|
||||
if len(node_content) != len(node_schema["prefixItems"]):
|
||||
raise ValueError(
|
||||
f"Array node has {len(node_content)} items, but schema only has "
|
||||
f"{len(node_schema['prefixItems'])} prefixItems defined.\n"
|
||||
f"Content: {node_content}\n"
|
||||
f"Schema: {node_schema}"
|
||||
)
|
||||
for item, item_schema in zip(node_content, node_schema["prefixItems"]):
|
||||
parsed_schema.append(recursive_parse(item, item_schema))
|
||||
return parsed_schema
|
||||
else:
|
||||
raise ValueError(f"Array node has no items or prefixItems schema defined.\nSchema: {node_schema}")
|
||||
elif node_type in ("string", "integer", "number", "boolean"):
|
||||
if not isinstance(node_content, str):
|
||||
raise TypeError(f"Expected a string for schema node with type {node_type}, got {node_content}")
|
||||
if node_type == "integer":
|
||||
return int(node_content)
|
||||
elif node_type == "number":
|
||||
return float(node_content)
|
||||
elif node_type == "boolean":
|
||||
if node_content.lower() in ("true", "1"):
|
||||
return True
|
||||
elif node_content.lower() in ("false", "0"):
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Invalid boolean value: {node_content}")
|
||||
else:
|
||||
# String type
|
||||
return node_content
|
||||
else:
|
||||
raise TypeError(f"Unsupported schema type {node_type} for node: {node_content}")
|
@ -882,6 +882,10 @@ class PushToHubMixin:
```
"""
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
save_jinja_files = deprecated_kwargs.pop(
"save_jinja_files", None
) # TODO: This is only used for testing and should be removed once save_jinja_files becomes the default

repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
if repo_path_or_name is not None:
# Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
@ -927,11 +931,15 @@ class PushToHubMixin:
files_timestamps = self._get_files_timestamps(work_dir)

# Save all files.
if save_jinja_files:
self.save_pretrained(
work_dir,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
save_jinja_files=True,
)
else:
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)

# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))
@ -1134,11 +1134,6 @@ def is_jinja_available() -> bool:
return _is_package_available("jinja2")


@lru_cache
def is_jmespath_available() -> bool:
return _is_package_available("jmespath")


@lru_cache
def is_mlx_available() -> bool:
return _is_package_available("mlx")
@ -476,6 +476,20 @@ class ProcessorPushToHubTester(unittest.TestCase):
tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template
)
self.assertEqual(processor.chat_template, chat_template)

existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
with TemporaryHubRepo(token=self._token) as tmp_repo:
processor.save_pretrained(
tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
)
reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we don't use single-file chat template saving, processor and tokenizer chat templates
# should remain separate
self.assertEqual(
getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template
)

with TemporaryHubRepo(token=self._token) as tmp_repo:
processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True)
reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
@ -18,7 +18,7 @@ from tests.models.superglue.test_image_processing_superglue import (
SuperGlueImageProcessingTester,
)
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from transformers.utils import is_torch_available, is_vision_available


if is_torch_available():
@ -30,9 +30,6 @@ if is_torch_available():
if is_vision_available():
from transformers import LightGlueImageProcessor

if is_torchvision_available():
from transformers import LightGlueImageProcessorFast


def random_array(size):
return np.random.randint(255, size=size)
@ -93,7 +90,7 @@ class LightGlueImageProcessingTester(SuperGlueImageProcessingTester):
@require_vision
class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase):
image_processing_class = LightGlueImageProcessor if is_vision_available() else None
fast_image_processing_class = LightGlueImageProcessorFast if is_torchvision_available() else None
fast_image_processing_class = None

def setUp(self) -> None:
super().setUp()
@ -263,31 +263,6 @@ class TextGenerationPipelineTests(unittest.TestCase):
],
)

@require_torch
def test_small_chat_model_with_response_parsing(self):
text_generator = pipeline(
task="text-generation",
model="hf-internal-testing/tiny-gpt2-with-chatml-template",
)
# Using `do_sample=False` to force deterministic output
chat = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
]
text_generator.tokenizer.response_schema = {
# A real response schema should probably have things like "role" and "content"
# and "reasoning_content" but it's unlikely we'd get a tiny model to reliably
# output anything like that, so let's keep it simple.
"type": "object",
"properties": {
"first_word": {"type": "string", "x-regex": r"^\s*([a-zA-Z]+)"},
"last_word": {"type": "string", "x-regex": r"([a-zA-Z]+)\s*$"},
},
}
outputs = text_generator(chat, do_sample=False, max_new_tokens=10)
parsed_message = outputs[0]["generated_text"][-1]
self.assertEqual(parsed_message, {"first_word": "factors", "last_word": "factors"})

def get_test_pipeline(
self,
model,
@ -15,7 +15,6 @@

import inspect
import json
import os
import random
import sys
import tempfile
@ -943,15 +942,17 @@ class ProcessorTesterMixin:
if "chat_template" not in {*signature.parameters.keys()}:
self.skipTest("Processor doesn't accept chat templates at input")

existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
processor.chat_template = "test template"
with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname)
with open(Path(tmpdirname, "chat_template.json"), "w") as fp:
json.dump({"chat_template": processor.chat_template}, fp)
os.remove(Path(tmpdirname, "chat_template.jinja"))

processor.save_pretrained(tmpdirname, save_jinja_files=False)
self.assertTrue(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
reloaded_processor = self.processor_class.from_pretrained(tmpdirname)
self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
# When we don't use single-file chat template saving, processor and tokenizer chat templates
# should remain separate
self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template)

with tempfile.TemporaryDirectory() as tmpdirname:
processor.save_pretrained(tmpdirname)
@ -976,6 +977,12 @@ class ProcessorTesterMixin:
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)

with self.assertRaises(ValueError):
# Saving multiple templates in the legacy format is not permitted
with tempfile.TemporaryDirectory() as tmpdirname:
processor.chat_template = {"default": "a", "secondary": "b"}
processor.save_pretrained(tmpdirname, save_jinja_files=False)

@require_torch
def _test_apply_chat_template(
self,
@ -1093,13 +1093,9 @@ class TokenizerTesterMixin:
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

with tempfile.TemporaryDirectory() as tmp_dir_name:
save_files = tokenizer.save_pretrained(tmp_dir_name)
with open(Path(tmp_dir_name, "tokenizer_config.json"), "r") as fp:
tokenizer_config = json.load(fp)
tokenizer_config["chat_template"] = tokenizer.chat_template
with open(Path(tmp_dir_name, "tokenizer_config.json"), "w") as fp:
json.dump(tokenizer_config, fp)
os.remove(Path(tmp_dir_name, "chat_template.jinja"))
save_files = tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False)
# Check we aren't saving a chat_template.jinja file
self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)

self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted
@ -1159,16 +1155,10 @@ class TokenizerTesterMixin:

with tempfile.TemporaryDirectory() as tmpdirname:
tokenizer.chat_template = {"default": "a", "secondary": "b"}
tokenizer.save_pretrained(tmpdirname)
with open(Path(tmpdirname, "tokenizer_config.json"), "r") as fp:
tokenizer_config = json.load(fp)
tokenizer_config["chat_template"] = [
{"name": k, "template": v} for k, v in tokenizer.chat_template
]
with open(Path(tmpdirname, "tokenizer_config.json"), "w") as fp:
json.dump(tokenizer_config, fp)
os.remove(Path(tmpdirname, "chat_template.jinja"))
os.remove(Path(tmpdirname, "additional_chat_templates"))
tokenizer.save_pretrained(tmpdirname, save_jinja_files=False)
self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file())
self.assertFalse(Path(tmpdirname, "chat_template.json").is_file())
self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir())
reloaded_tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
# When we save as single files, tokenizers and tokenizers share a chat template, which means
@ -1662,16 +1652,29 @@ class TokenizerTesterMixin:
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
for save_jinja_files in (True, False):
tokenizer.chat_template = {"default": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
# Test that a dict of multiple templates can be serialized and loaded back
tokenizer.save_pretrained(tmp_dir_name)
# Test that save_jinja_files is ignored when there's a dict of multiple templates
tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=save_jinja_files)
if save_jinja_files:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
self.assertNotIn("chat_template", config_dict)
self.assertTrue(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
self.assertTrue(
os.path.exists(os.path.join(tmp_dir_name, "additional_chat_templates/template2.jinja"))
)
else:
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[
{"name": "default", "template": "{{'a'}}"},
{"name": "template2", "template": "{{'b'}}"},
],
)
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
@ -1685,14 +1688,7 @@ class TokenizerTesterMixin:
with self.subTest(f"{tokenizer.__class__.__name__}"):
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.chat_template = dummy_template1
tokenizer.save_pretrained(tmp_dir_name)
# Save first template in tokenizer config and second template in jinja file
# Priority should be given to jinja when loading
with open(Path(tmp_dir_name, "tokenizer_config.json"), "r") as fp:
tokenizer_config = json.load(fp)
tokenizer_config["chat_template"] = tokenizer.chat_template
with open(Path(tmp_dir_name, "tokenizer_config.json"), "w") as fp:
json.dump(tokenizer_config, fp)
tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False)
with Path(tmp_dir_name, "chat_template.jinja").open("w") as f:
f.write(dummy_template2)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
@ -1,349 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import AutoProcessor, AutoTokenizer
from transformers.testing_utils import require_jmespath
from transformers.utils.chat_parsing_utils import recursive_parse


cohere_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"},
"thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"},
"tool_calls": {
"x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)",
"x-parser": "json",
"x-parser-args": {
"transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}"
},
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}

ernie_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": "<response>\n(.*?)\n?</response>"},
"thinking": {"type": "string", "x-regex": r"(?:^|<think>\s*)(.*?)\s*<\/think>"},
"tool_calls": {
"x-regex-iterator": "<tool_call>(.*?)</tool_call>",
"type": "array",
"items": {
"type": "object",
"x-parser": "json",
"x-parser-args": {"transform": "{type: 'function', function: @}"},
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}

gpt_oss_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"},
"thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
"tool_calls": {
"x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
"arguments": {
"type": "object",
"x-regex": r"<\|message\|>(.*)",
"x-parser": "json",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}

smollm_schema = {
"x-regex": r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?:<tool_call>(?P<tool_calls>.+?)</tool_call>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string"},
"thinking": {"type": "string"},
"tool_calls": {
"x-parser": "json",
"x-parser-args": {"transform": "[{type: 'function', function: @}]"},
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {"type": "any"},
},
},
},
},
},
},
},
}

qwen3_schema = {
"x-regex": r"^(?:(?:<think>)?\s*(?P<thinking>.+?)\s*</think>)?\s*(?:<tool_call>(?P<tool_calls>.*?)\s*</tool_call>)?\s*(?P<content>.+?)?\s*$",
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string"},
"thinking": {"type": "string"},
"tool_calls": {
"x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string", "x-regex": r"<function=(\w+)>"},
"arguments": {
"type": "object",
"x-regex-key-value": r"<parameter=(?P<key>\w+)>\n(?P<value>.*?)\n</parameter>",
"additionalProperties": {
"x-parser": "json",
"x-parser-args": {"allow_non_json": True},
},
},
},
},
},
},
},
},
}

@require_jmespath
class ChatSchemaParserTest(unittest.TestCase):
def test_schema_save_load(self):
# Has no schema by default
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.response_schema = ernie_schema
with tempfile.TemporaryDirectory() as tmpdir:
tokenizer.save_pretrained(tmpdir)
reloaded_tokenizer = AutoTokenizer.from_pretrained(tmpdir)
self.assertEqual(reloaded_tokenizer.response_schema, ernie_schema)

# Has no schema by default
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
processor.response_schema = ernie_schema
with tempfile.TemporaryDirectory() as tmpdir:
processor.save_pretrained(tmpdir)
reloaded_processor = AutoProcessor.from_pretrained(tmpdir)
self.assertEqual(reloaded_processor.response_schema, ernie_schema)

def test_tokenizer_method(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
parsed_chat = recursive_parse(model_out, cohere_schema)
tokenizer.response_schema = cohere_schema
tokenizer_parsed_chat = tokenizer.parse_response(model_out)
self.assertEqual(tokenizer_parsed_chat, parsed_chat)

def test_cohere_template(self):
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
parsed_chat = recursive_parse(model_out, cohere_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": "I should call a tool.",
"tool_calls": [
{
"type": "function",
"function": {"name": "simple_tool", "arguments": {"temperature_format": "Celsius"}},
}
],
},
)

def test_ernie_template_with_tools(self):
model_out = 'The user is asking about the weather in Paris today. Let me check the available tools. There\'s a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to "Paris". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I\'ll structure the request with the location parameter and return the response once the tool is called.\n</think>\n\n<tool_call>\n{"name": "get_current_temperature", "arguments": {"location": "Paris"}}\n</tool_call>\n</s>'
parsed_chat = recursive_parse(model_out, ernie_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": "The user is asking about the weather in Paris today. Let me check the available tools. There's a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to \"Paris\". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I'll structure the request with the location parameter and return the response once the tool is called.",
"tool_calls": [
{
"type": "function",
"function": {"name": "get_current_temperature", "arguments": {"location": "Paris"}},
}
],
},
)

def test_ernie_template_no_tools(self):
model_out = "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.\n</think>\n\n<response>\nHello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!\n</response>\n</s>"
parsed_chat = recursive_parse(model_out, ernie_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"content": "Hello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!",
"thinking": "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.",
},
)

def test_gpt_oss_template_with_tool_call(self):
model_out = '<|channel|>analysis<|message|>We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\n "location": "San Francisco, CA"\n}'
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": 'We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.',
"tool_calls": [
{
"type": "function",
"function": {"name": "get_current_weather", "arguments": {"location": "San Francisco, CA"}},
}
],
},
)

def test_gpt_oss_template_no_tool_call(self):
model_out = "<|channel|>analysis<|message|>User asks a simple math question: 2+2 = 4. Provide answer.<|end|><|start|>assistant<|channel|>final<|message|>2"
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"content": "2",
"thinking": "User asks a simple math question: 2+2 = 4. Provide answer.",
},
)

def test_smollm_template_thinking_and_tool_call(self):
model_out = '<think>\nOkay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.\n</think>\n\n<tool_call>{"name": "greet_user", "arguments": {"greeting": "Hello! I\'m doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"}}</tool_call>'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"thinking": 'Okay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.',
"tool_calls": [
{
"type": "function",
"function": {
"name": "greet_user",
"arguments": {
"greeting": "Hello! I'm doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"
},
},
}
],
},
)

def test_smollm_template_tool_call_no_thinking(self):
model_out = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"tool_calls": [
{"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
]
},
)

def test_smollm_template_thinking_no_tool_call(self):
model_out = '<think>\nOkay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.</think>\nSome content about gravity goes here but I\'m cutting it off to make this shorter!'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"content": "Some content about gravity goes here but I'm cutting it off to make this shorter!",
"thinking": 'Okay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.',
},
)

def test_qwen3_tool_calls(self):
model_out = '<tool_call>\n<function=get_weather>\n<parameter=locations>\n[{"country": "France", "city": "Paris"}]\n</parameter>\n<parameter=temp_units>\ncelsius\n</parameter>\n</function>\n</tool_call>'
parsed_chat = recursive_parse(model_out, qwen3_schema)
self.assertEqual(
parsed_chat,
{
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": {
"locations": [{"country": "France", "city": "Paris"}],
"temp_units": "celsius",
},
},
}
]
},
)
@ -138,6 +138,14 @@ class TokenizerPushToHubTester(unittest.TestCase):
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.chat_template = "test template"

with TemporaryHubRepo(token=self._token) as tmp_repo:
tokenizer.save_pretrained(
tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
)
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)

with TemporaryHubRepo(token=self._token) as tmp_repo:
tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)