mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-12 01:04:36 +08:00
Compare commits
36 Commits
v4.52.2
...
ci-test-hu
| Author | SHA1 | Date | |
|---|---|---|---|
| d6c76b64ed | |||
| b01984a51d | |||
| 2b585419b4 | |||
| b59386dc0a | |||
| 211f2b0875 | |||
| 73286d8e29 | |||
| d95c864a25 | |||
| 9895819514 | |||
| dfbee79ca3 | |||
| 1234683309 | |||
| 03a4c024dc | |||
| dcaf47dde5 | |||
| 163138a911 | |||
| f8630c778c | |||
| aa02a5d902 | |||
| b26157d64c | |||
| b369a65480 | |||
| 28d3148b07 | |||
| 7b7bb8df97 | |||
| 5c13cc0f94 | |||
| 71009e4b68 | |||
| d6c34cdcd0 | |||
| ae3e4e2d97 | |||
| 174684a9b6 | |||
| e4decee9c0 | |||
| ddf67d2d73 | |||
| 9a962dd9ed | |||
| 101b3fa4ea | |||
| a21f11fca2 | |||
| 4542086db7 | |||
| 6829936ee0 | |||
| e288ee00d8 | |||
| 711d78d104 | |||
| feec294dea | |||
| cb513e35f9 | |||
| f4ef41c45e |
51
.github/workflows/check_failed_model_tests.yml
vendored
51
.github/workflows/check_failed_model_tests.yml
vendored
@ -39,55 +39,100 @@ jobs:
|
||||
name: ci_results_run_models_gpu
|
||||
path: /transformers/ci_results_run_models_gpu
|
||||
|
||||
- name: Check file
|
||||
working-directory: /transformers
|
||||
run: |
|
||||
if [ -f ci_results_run_models_gpu/new_model_failures.json ]; then
|
||||
echo "`ci_results_run_models_gpu/new_model_failures.json` exists, continue ..."
|
||||
echo "process=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "`ci_results_run_models_gpu/new_model_failures.json` doesn't exist, abort."
|
||||
echo "process=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- uses: actions/download-artifact@v4
|
||||
if: ${{ env.process == 'true' }}
|
||||
with:
|
||||
pattern: setup_values*
|
||||
path: setup_values
|
||||
merge-multiple: true
|
||||
|
||||
- name: Prepare some setup values
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
if [ -f setup_values/prev_workflow_run_id.txt ]; then
|
||||
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
if [ -f setup_values/other_workflow_run_id.txt ]; then
|
||||
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: git fetch && git checkout ${{ github.sha }}
|
||||
|
||||
- name: Get target commit
|
||||
working-directory: /transformers/utils
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"]); print(commit)')" >> $GITHUB_ENV
|
||||
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"], workflow_run_id=os.environ["PREV_WORKFLOW_RUN_ID"]); print(commit)')" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout to `start_sha`
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: git fetch && git checkout ${{ inputs.start_sha }}
|
||||
|
||||
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Environment
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
python3 utils/print_env.py
|
||||
|
||||
- name: Show installed libraries and their versions
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: pip freeze
|
||||
|
||||
- name: Check failed tests
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_run_models_gpu/new_model_failures.json --output_file new_model_failures_with_bad_commit.json
|
||||
|
||||
- name: Show results
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
ls -l new_model_failures_with_bad_commit.json
|
||||
cat new_model_failures_with_bad_commit.json
|
||||
|
||||
- name: Checkout back
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
git checkout ${{ inputs.start_sha }}
|
||||
|
||||
- name: Process report
|
||||
shell: bash
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
env:
|
||||
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
|
||||
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
|
||||
run: |
|
||||
python3 utils/process_bad_commit_report.py
|
||||
@ -95,7 +140,9 @@ jobs:
|
||||
- name: Process report
|
||||
shell: bash
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
env:
|
||||
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
|
||||
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
|
||||
run: |
|
||||
{
|
||||
@ -105,7 +152,7 @@ jobs:
|
||||
} >> "$GITHUB_ENV"
|
||||
|
||||
- name: Send processed report
|
||||
if: ${{ !endsWith(env.REPORT_TEXT, '{}') }}
|
||||
if: ${{ env.process == 'true' && !endsWith(env.REPORT_TEXT, '{}') }}
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
|
||||
35
.github/workflows/self-scheduled-caller.yml
vendored
35
.github/workflows/self-scheduled-caller.yml
vendored
@ -8,8 +8,43 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- run_scheduled_ci*
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
prev_workflow_run_id:
|
||||
description: 'previous workflow run id to compare'
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
other_workflow_run_id:
|
||||
description: 'other workflow run id to compare'
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
|
||||
|
||||
# Used for `push` to easily modiffy the target workflow runs to compare against
|
||||
env:
|
||||
prev_workflow_run_id: ""
|
||||
other_workflow_run_id: ""
|
||||
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Setup
|
||||
run: |
|
||||
mkdir "setup_values"
|
||||
echo "${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}" > "setup_values/prev_workflow_run_id.txt"
|
||||
echo "${{ inputs.other_workflow_run_id || env.other_workflow_run_id }}" > "setup_values/other_workflow_run_id.txt"
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: setup_values
|
||||
path: setup_values
|
||||
|
||||
model-ci:
|
||||
name: Model CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
|
||||
18
.github/workflows/slack-report.yml
vendored
18
.github/workflows/slack-report.yml
vendored
@ -39,6 +39,21 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/download-artifact@v4
|
||||
|
||||
- name: Prepare some setup values
|
||||
run: |
|
||||
if [ -f setup_values/prev_workflow_run_id.txt ]; then
|
||||
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
if [ -f setup_values/other_workflow_run_id.txt ]; then
|
||||
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Send message to Slack
|
||||
if: ${{ inputs.job != 'run_quantization_torch_gpu' }}
|
||||
env:
|
||||
@ -50,7 +65,6 @@ jobs:
|
||||
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
|
||||
CI_EVENT: ${{ inputs.ci_event }}
|
||||
CI_SHA: ${{ github.sha }}
|
||||
CI_WORKFLOW_REF: ${{ github.workflow_ref }}
|
||||
CI_TEST_JOB: ${{ inputs.job }}
|
||||
SETUP_STATUS: ${{ inputs.setup_status }}
|
||||
# We pass `needs.setup.outputs.matrix` as the argument. A processing in `notification_service.py` to change
|
||||
@ -58,7 +72,6 @@ jobs:
|
||||
# For a job that doesn't depend on (i.e. `needs`) `setup`, the value for `inputs.folder_slices` would be an
|
||||
# empty string, and the called script still get one argument (which is the emtpy string).
|
||||
run: |
|
||||
sudo apt-get install -y curl
|
||||
pip install huggingface_hub
|
||||
pip install slack_sdk
|
||||
pip show slack_sdk
|
||||
@ -86,7 +99,6 @@ jobs:
|
||||
# We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change
|
||||
# `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`.
|
||||
run: |
|
||||
sudo apt-get install -y curl
|
||||
pip install huggingface_hub
|
||||
pip install slack_sdk
|
||||
pip show slack_sdk
|
||||
|
||||
@ -71,6 +71,9 @@ RUN python3 -m pip install --no-cache-dir g2p-en
|
||||
# For Some bitsandbytes tests
|
||||
RUN python3 -m pip install --no-cache-dir einops
|
||||
|
||||
# For Some tests with `@require_liger_kernel`
|
||||
RUN python3 -m pip install --no-cache-dir liger-kernel
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
FROM rocm/dev-ubuntu-22.04:6.2.4
|
||||
FROM rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.6.0
|
||||
LABEL maintainer="Hugging Face"
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
@ -11,9 +11,6 @@ RUN apt update && \
|
||||
RUN git lfs install
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip numpy
|
||||
|
||||
RUN python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade importlib-metadata setuptools ninja git+https://github.com/facebookresearch/detectron2.git pytesseract "itsdangerous<2.1.0"
|
||||
|
||||
ARG REF=main
|
||||
|
||||
@ -455,6 +455,8 @@
|
||||
title: Falcon
|
||||
- local: model_doc/falcon3
|
||||
title: Falcon3
|
||||
- local: model_doc/falcon_h1
|
||||
title: FalconH1
|
||||
- local: model_doc/falcon_mamba
|
||||
title: FalconMamba
|
||||
- local: model_doc/flan-t5
|
||||
|
||||
@ -125,4 +125,44 @@ would expect from a usual Python dictionary:
|
||||
|
||||
# You can also globally `register` a new function directly on it
|
||||
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
|
||||
```
|
||||
```
|
||||
|
||||
## Attention Mask Interface
|
||||
|
||||
Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
|
||||
the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
|
||||
the `AttentionInterface`:
|
||||
|
||||
```python
|
||||
from transformers import AttentionMaskInterface
|
||||
from transformers.masking_utils import sdpa_mask
|
||||
import torch
|
||||
|
||||
def my_new_sdpa_mask(*args, **kwargs):
|
||||
print("I just entered the attention mask computation")
|
||||
return sdpa_mask(*args, **kwargs)
|
||||
|
||||
AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
|
||||
```
|
||||
|
||||
The reason you have to register it is because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
|
||||
By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
|
||||
and `attention_mask=None` will be passed along to the Attention layers.
|
||||
|
||||
The default signature of the attention mask functions is the following:
|
||||
|
||||
```python
|
||||
def custom_attention_mask(
|
||||
batch_size: int, # required arg
|
||||
cache_position: torch.Tensor, # required arg
|
||||
kv_length: int, # required arg
|
||||
kv_offset: int = 0, # required arg
|
||||
mask_function: Callable = causal_mask_function, # required arg
|
||||
attention_mask: Optional[torch.Tensor] = None, # required arg
|
||||
**kwargs, # a few additional args may be passed as kwargs, especially the model's config is always passed
|
||||
) -> Optional[torch.Tensor]:
|
||||
```
|
||||
|
||||
It mostly works thanks to the `mask_function`, which is a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
|
||||
|
||||
If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
|
||||
@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the
|
||||
[[autodoc]] AttentionInterface
|
||||
- register
|
||||
|
||||
## Attention Mask Functions
|
||||
|
||||
[[autodoc]] AttentionMaskInterface
|
||||
- register
|
||||
|
||||
## Rotary Position Embedding Functions
|
||||
|
||||
[[autodoc]] dynamic_rope_update
|
||||
|
||||
@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
@ -40,13 +41,13 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```
|
||||
@ -109,7 +110,7 @@ we saw the following speedups during inference.
|
||||
[[autodoc]] BioGptForCausalLM
|
||||
- forward
|
||||
|
||||
|
||||
|
||||
## BioGptForTokenClassification
|
||||
|
||||
[[autodoc]] BioGptForTokenClassification
|
||||
|
||||
@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
Note that [`BlenderbotSmallModel`] and
|
||||
@ -52,7 +54,7 @@ found [here](https://github.com/facebookresearch/ParlAI).
|
||||
|
||||
## Usage tips
|
||||
|
||||
Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
|
||||
Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
|
||||
the left.
|
||||
|
||||
|
||||
|
||||
@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
@ -45,7 +47,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
|
||||
|
||||
## Usage tips and example
|
||||
|
||||
Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
|
||||
Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
|
||||
rather than the left.
|
||||
|
||||
An example:
|
||||
@ -71,7 +73,7 @@ An example:
|
||||
`facebook/blenderbot_small_90M`, have a different architecture and consequently should be used with
|
||||
[BlenderbotSmall](blenderbot-small).
|
||||
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
|
||||
65
docs/source/en/model_doc/falcon_h1.md
Normal file
65
docs/source/en/model_doc/falcon_h1.md
Normal file
@ -0,0 +1,65 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# FalconH1
|
||||
|
||||
## Overview
|
||||
|
||||
The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series in [this website](https://github.com/tiiuae/Falcon-H1).
|
||||
|
||||
## Contributors
|
||||
|
||||
This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm).
|
||||
The original code can be found [here](https://github.com/tiiuae/Falcon-H1).
|
||||
|
||||
|
||||
## FalconH1Config
|
||||
|
||||
| Model | Depth | Dim | Attn Heads | KV | Mamba Heads | d_head | d_state | Ctx Len |
|
||||
|-----------|--------|------|------------|----|--------------|--------------|------|-----------------|
|
||||
| H1 0.5B | 36 | 1024 | 8 | 2 | 24 | 64 / 64 | 128 | 4K, 16K-SFT |
|
||||
| H1 1.5B | 24 | 2048 | 8 | 2 | 48 | 128 / 64 | 256 | 128K |
|
||||
| H1 1.5B-d | 66 | 1280 | 6 | 2 | 24 | 128 / 64 | 256 | 128K |
|
||||
| H1 3B | 32 | 2560 | 10 | 2 | 32 | 128 / 128 | 256 | 128K |
|
||||
| H1 7B | 44 | 3072 | 12 | 2 | 24 | 128 / 128 | 256 | 256K |
|
||||
| H1 34B | 72 | 5120 | 20 | 4 | 32 | 128 / 128 | 256 | 256K |
|
||||
|
||||
|
||||
|
||||
[[autodoc]] FalconH1Config
|
||||
|
||||
<!---
|
||||
## Usage Tips
|
||||
Tips:
|
||||
- The architecture is based on Mamba-2 models.
|
||||
## FalconH1Model
|
||||
[[autodoc]] FalconH1Model
|
||||
- forward
|
||||
-->
|
||||
|
||||
## FalconH1ForCausalLM
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
|
||||
|
||||
message = ["Mamba is a snake with following properties "]
|
||||
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
|
||||
response = model.generate(**inputs, max_new_tokens=64)
|
||||
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
||||
```
|
||||
|
||||
[[autodoc]] FalconH1ForCausalLM
|
||||
- forward
|
||||
|
||||
This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem).
|
||||
@ -147,7 +147,7 @@ print(processor.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
### Multi image inference
|
||||
|
||||
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it:
|
||||
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. For multi-image cases, we recommend using a **nested list of images** as input. Otherwise, every image will be patchified and consume a lot of memory. Here is how you can do it:
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
@ -14,85 +14,124 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Mamba
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Mamba
|
||||
|
||||
The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
|
||||
[Mamba](https://huggingface.co/papers/2312.00752) is a selective structured state space model (SSMs) designed to work around Transformers computational inefficiency when dealing with long sequences. It is a completely attention-free architecture, and comprised of a combination of H3 and gated MLP blocks (Mamba block). Mamba's "content-based reasoning" allows it to focus on specific parts of an input depending on the current token. Mamba also uses a new hardware-aware parallel algorithm to compensate for the lack of convolutional operations. As a result, Mamba has fast inference and can scale to very long sequences.
|
||||
|
||||
This model is a new paradigm architecture based on `state-space-models`. You can read more about the intuition behind these [here](https://srush.github.io/annotated-s4/).
|
||||
You can find all the original Mamba checkpoints under the [State Space Models](https://huggingface.co/state-spaces) organization.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*
|
||||
> [!TIP]
|
||||
> Click on the Mamba models in the right sidebar for more examples of how to apply Mamba to different language tasks.
|
||||
|
||||
Tips:
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
- Mamba is a new `state space model` architecture that rivals the classic Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
|
||||
- Mamba stacks `mixer` layers, which are the equivalent of `Attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
|
||||
- Two implementations cohabit: one is optimized and uses fast cuda kernels, while the other one is naive but can run on any device!
|
||||
- The current implementation leverages the original cuda kernels: the equivalent of flash attention for Mamba are hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and the [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them!
|
||||
- Contributions to make the naive path faster are welcome 🤗
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
|
||||
The original code can be found [here](https://github.com/state-spaces/mamba).
|
||||
|
||||
# Usage
|
||||
|
||||
### A simple generation example:
|
||||
```python
|
||||
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="state-spaces/mamba-130m-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16, device_map="auto",)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
print(tokenizer.batch_decode(out))
|
||||
output = model.generate(**input_ids)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
```
|
||||
|
||||
### Peft finetuning
|
||||
The slow version is not very stable for training, and the fast one needs `float32`!
|
||||
</hfoption>
|
||||
<hfoption id="transformers CLI">
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
||||
model_id = "state-spaces/mamba-130m-hf"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
logging_dir='./logs',
|
||||
logging_steps=10,
|
||||
learning_rate=2e-3
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="quote",
|
||||
)
|
||||
trainer.train()
|
||||
```bash
|
||||
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model state-spaces/mamba-130m-hf --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to 4-bit integers.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
quantization_config = Int4WeightOnlyConfig(group_size=128)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto",)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
## Notes
|
||||
|
||||
- The current implementation uses the original CUDA kernels. The FlashAttention equivalent implementation is hosted in the [mamba-ssm](https://github.com/state-spaces/mamba) and [causal_conv1d](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports it!
|
||||
- Mamba stacks `mixer` layers which are equivalent to `Attention` layers. You can find the main logic of Mamba in the `MambaMixer` class.
|
||||
- The example below demonstrates how to fine-tune Mamba with [PEFT](https://huggingface.co/docs/peft).
|
||||
|
||||
```py
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model_id = "state-spaces/mamba-130m-hf"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
logging_dir='./logs',
|
||||
logging_steps=10,
|
||||
learning_rate=2e-3
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="quote",
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## MambaConfig
|
||||
|
||||
[[autodoc]] MambaConfig
|
||||
|
||||
@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
@ -155,7 +157,7 @@ Example of translating english to many romance languages, using old-style 2 char
|
||||
>>> model = MarianMTModel.from_pretrained(model_name)
|
||||
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
|
||||
>>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
'Isto deve ir para o português.',
|
||||
'Y esto al español']
|
||||
```
|
||||
|
||||
@ -51,10 +51,10 @@ The original code can be found [here](https://github.com/facebookresearch/fairse
|
||||
|
||||
## Implementation differences with SwitchTransformers
|
||||
|
||||
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
|
||||
highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
|
||||
which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
|
||||
states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
|
||||
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
|
||||
highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
|
||||
which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
|
||||
states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
|
||||
|
||||
## Generating with NLLB-MoE
|
||||
|
||||
|
||||
@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
@ -29,7 +31,7 @@ on Java, Python and English.
|
||||
According to the abstract
|
||||
|
||||
*Code summarization and generation empower conversion between programming language (PL) and natural language (NL),
|
||||
while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART,
|
||||
while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART,
|
||||
a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks.
|
||||
PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding.
|
||||
Experiments on code summarization in the English language, code generation, and code translation in seven programming languages
|
||||
@ -50,7 +52,7 @@ target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
|
||||
|
||||
However, for fine-tuning, in some cases no language token is provided in cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this.
|
||||
|
||||
In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format
|
||||
In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format
|
||||
when you pass texts as the first argument or with the keyword argument `text`, and will encode target text format if
|
||||
it's passed with the `text_target` keyword argument.
|
||||
|
||||
|
||||
@ -14,59 +14,77 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Swin Transformer
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Swin Transformer
|
||||
|
||||
The Swin Transformer was proposed in [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
|
||||
by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
|
||||
[Swin Transformer](https://huggingface.co/papers/2103.14030) is a hierarchical vision transformer. Images are processed in patches and windowed self-attention is used to capture local information. These windows are shifted across the image to allow for cross-window connections, capturing global information more efficiently. This hierarchical approach with shifted windows allows the Swin Transformer to process images effectively at different scales and achieve linear computational complexity relative to image size, making it a versatile backbone for various vision tasks like image classification and object detection.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all official Swin Transformer checkpoints under the [Microsoft](https://huggingface.co/microsoft?search_models=swin) organization.
|
||||
|
||||
*This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone
|
||||
for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains,
|
||||
such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.
|
||||
To address these differences, we propose a hierarchical Transformer whose representation is computed with \bold{S}hifted
|
||||
\bold{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping
|
||||
local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at
|
||||
various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it
|
||||
compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense
|
||||
prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation
|
||||
(53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and
|
||||
+2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones.
|
||||
The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.*
|
||||
> [!TIP]
|
||||
> Click on the Swin Transformer models in the right sidebar for more examples of how to apply Swin Transformer to different image tasks.
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
|
||||
alt="drawing" width="600"/>
|
||||
The example below demonstrates how to classify an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
<small> Swin Transformer architecture. Taken from the <a href="https://arxiv.org/abs/2102.03334">original paper</a>.</small>
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
This model was contributed by [novice03](https://huggingface.co/novice03). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/microsoft/Swin-Transformer).
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
## Usage tips
|
||||
pipeline = pipeline(
|
||||
task="image-classification",
|
||||
model="microsoft/swin-tiny-patch4-window7-224",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
- Swin pads the inputs supporting any input height and width (if divisible by `32`).
|
||||
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
## Resources
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Swin Transformer.
|
||||
image_processor = AutoImageProcessor.from_pretrained(
|
||||
"microsoft/swin-tiny-patch4-window7-224",
|
||||
use_fast=True,
|
||||
)
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
"microsoft/swin-tiny-patch4-window7-224",
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = image_processor(image, return_tensors="pt").to("cuda")
|
||||
|
||||
- [`SwinForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
- See also: [Image classification task guide](../tasks/image_classification)
|
||||
with torch.no_grad():
|
||||
logits = model(**inputs).logits
|
||||
predicted_class_id = logits.argmax(dim=-1).item()
|
||||
|
||||
Besides that:
|
||||
class_labels = model.config.id2label
|
||||
predicted_class_label = class_labels[predicted_class_id]
|
||||
print(f"The predicted class label is: {predicted_class_label}")
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
- [`SwinForMaskedImageModeling`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
|
||||
## Notes
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
- Swin can pad the inputs for any input height and width divisible by `32`.
|
||||
- Swin can be used as a [backbone](../backbones). When `output_hidden_states = True`, it outputs both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
|
||||
|
||||
## SwinConfig
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ transcription[0]
|
||||
|
||||
## Notes
|
||||
|
||||
- Whisper relies on [`~GenerationMixin.generate`] for inference.
|
||||
- Whisper relies a custom [`generate`] for inference, make sure to check the docs below.
|
||||
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
|
||||
|
||||
## WhisperConfig
|
||||
|
||||
@ -29,8 +29,6 @@
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: tasks/sequence_classification
|
||||
title: テキストの分類
|
||||
- local: tasks/token_classification
|
||||
title: トークンの分類
|
||||
- local: tasks/question_answering
|
||||
|
||||
@ -47,7 +47,7 @@ ALBERTモデルは、「[ALBERT: A Lite BERT for Self-supervised Learning of Lan
|
||||
|
||||
## 参考資料
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問応答タスクガイド](../tasks/question_answering)
|
||||
- [マスクされた言語モデルタスクガイド](../tasks/masked_language_modeling)
|
||||
|
||||
@ -129,7 +129,7 @@ BART を始めるのに役立つ公式 Hugging Face およびコミュニティ
|
||||
- [翻訳タスクガイド](../tasks/translation)
|
||||
|
||||
以下も参照してください。
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
- [抽出されたチェックポイント](https://huggingface.co/models?search=distilbart) は、この [論文](https://arxiv.org/abs/2010.13002) で説明されています。
|
||||
|
||||
@ -76,7 +76,7 @@ BERT を始めるのに役立つ公式 Hugging Face およびコミュニティ
|
||||
- [`BertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb)。
|
||||
- [`TFBertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb)。
|
||||
- [`FlaxBertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb)。
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ BigBird は、質問応答や要約などのさまざまな NLP タスクのパ
|
||||
|
||||
## ドキュメント リソース
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
||||
@ -58,7 +58,7 @@ BigBird は、質問応答や要約などのさまざまな NLP タスクのパ
|
||||
|
||||
## ドキュメント リソース
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
- [翻訳タスクガイド](../tasks/translation)
|
||||
|
||||
@ -39,7 +39,7 @@ BLOOM を使い始めるのに役立つ公式 Hugging Face およびコミュニ
|
||||
|
||||
以下も参照してください。
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ Bi-direction Encoders for Transformers (BERT) のフランス語版である Cam
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
||||
@ -98,7 +98,7 @@ CANINE は生の文字で動作するため、**トークナイザーなし**で
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [多肢選択タスク ガイド](../tasks/multiple_choice)
|
||||
|
||||
@ -53,7 +53,7 @@ ConvBERT トレーニングのヒントは BERT のヒントと似ています
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [マスクされた言語モデリング タスク ガイド](../tasks/masked_lang_modeling)
|
||||
|
||||
@ -61,7 +61,7 @@ CTRL モデルは、Nitish Shirish Keskar*、Bryan McCann*、Lav R. Varshney、C
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
||||
## CTRLConfig
|
||||
|
||||
@ -58,7 +58,7 @@ Data2Vec の使用を開始するのに役立つ公式 Hugging Face およびコ
|
||||
- カスタム データセットで [`TFData2VecVisionForImageClassification`] を微調整するには、[このノートブック](https://colab.research.google.com/github/sayakpaul/TF-2.0-Hacks/blob/master/data2vec_vision_image_classification.ipynb) を参照してください。 )。
|
||||
|
||||
**Data2VecText ドキュメント リソース**
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
||||
@ -61,7 +61,7 @@ v2 の新機能:
|
||||
[kamalkraj](https://huggingface.co/kamalkraj) による投稿。元のコードは [こちら](https://github.com/microsoft/DeBERTa) にあります。
|
||||
|
||||
## Resources
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [マスク言語モデリング タスク ガイド](../tasks/masked_language_modeling)
|
||||
|
||||
@ -52,7 +52,7 @@ DeBERTa を使い始めるのに役立つ公式 Hugging Face およびコミュ
|
||||
- DeBERTa による [機械学習によるスーパーチャージされた顧客サービス](https://huggingface.co/blog/supercharge-customer-service-with-machine-learning) に関するブログ投稿。
|
||||
- [`DebertaForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb)。
|
||||
- [`TFDebertaForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb)。
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
|
||||
<PipelineTag pipeline="token-classification" />
|
||||
|
||||
|
||||
@ -1,604 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Sequence classification
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
<Youtube id="dKE8SIt9C-w"/>
|
||||
|
||||
セマンティック セグメンテーションでは、画像の個々のピクセルにラベルまたはクラスを割り当てます。セグメンテーションにはいくつかのタイプがありますが、セマンティック セグメンテーションの場合、同じオブジェクトの一意のインスタンス間の区別は行われません。両方のオブジェクトに同じラベルが付けられます (たとえば、「car-1」と「car-2」の代わりに「car」)。セマンティック セグメンテーションの一般的な現実世界のアプリケーションには、歩行者や重要な交通情報を識別するための自動運転車のトレーニング、医療画像内の細胞と異常の識別、衛星画像からの環境変化の監視などが含まれます。
|
||||
|
||||
このガイドでは、次の方法を説明します。
|
||||
|
||||
1. [SceneParse150](https://huggingface.co/datasets/scene_parse_150) データセットの [SegFormer](https://huggingface.co/docs/transformers/main/en/model_doc/segformer#segformer) を微調整します。
|
||||
2. 微調整したモデルを推論に使用します。
|
||||
|
||||
<Tip>
|
||||
|
||||
このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/text-classification) を確認することをお勧めします。
|
||||
|
||||
</Tip>
|
||||
|
||||
始める前に、必要なライブラリがすべてインストールされていることを確認してください。
|
||||
|
||||
```bash
|
||||
pip install -q datasets transformers evaluate
|
||||
```
|
||||
|
||||
モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。
|
||||
|
||||
```py
|
||||
>>> from huggingface_hub import notebook_login
|
||||
|
||||
>>> notebook_login()
|
||||
```
|
||||
|
||||
## Load SceneParse150 dataset
|
||||
|
||||
|
||||
まず、SceneParse150 データセットの小さいサブセットを 🤗 データセット ライブラリから読み込みます。これにより、完全なデータセットのトレーニングにさらに時間を費やす前に、実験してすべてが機能することを確認する機会が得られます。
|
||||
|
||||
```py
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> ds = load_dataset("scene_parse_150", split="train[:50]")
|
||||
```
|
||||
|
||||
[`~datasets.Dataset.train_test_split`] メソッドを使用して、データセットの `train` 分割をトレイン セットとテスト セットに分割します。
|
||||
|
||||
```py
|
||||
>>> ds = ds.train_test_split(test_size=0.2)
|
||||
>>> train_ds = ds["train"]
|
||||
>>> test_ds = ds["test"]
|
||||
```
|
||||
|
||||
次に、例を見てみましょう。
|
||||
|
||||
```py
|
||||
>>> train_ds[0]
|
||||
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x683 at 0x7F9B0C201F90>,
|
||||
'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x683 at 0x7F9B0C201DD0>,
|
||||
'scene_category': 368}
|
||||
```
|
||||
|
||||
- `image`: シーンの PIL イメージ。
|
||||
- `annotation`: セグメンテーション マップの PIL イメージ。モデルのターゲットでもあります。
|
||||
- `scene_category`: 「キッチン」や「オフィス」などの画像シーンを説明するカテゴリ ID。このガイドでは、「image」と「annotation」のみが必要になります。どちらも PIL イメージです。
|
||||
|
||||
また、ラベル ID をラベル クラスにマップする辞書を作成することもできます。これは、後でモデルを設定するときに役立ちます。ハブからマッピングをダウンロードし、`id2label` および `label2id` ディクショナリを作成します。
|
||||
|
||||
```py
|
||||
>>> import json
|
||||
>>> from pathlib import Path
|
||||
>>> from huggingface_hub import hf_hub_download
|
||||
|
||||
>>> repo_id = "huggingface/label-files"
|
||||
>>> filename = "ade20k-id2label.json"
|
||||
>>> id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
|
||||
>>> id2label = {int(k): v for k, v in id2label.items()}
|
||||
>>> label2id = {v: k for k, v in id2label.items()}
|
||||
>>> num_labels = len(id2label)
|
||||
```
|
||||
|
||||
## Preprocess
|
||||
|
||||
次のステップでは、SegFormer 画像プロセッサをロードして、モデルの画像と注釈を準備します。このデータセットのような一部のデータセットは、バックグラウンド クラスとしてゼロインデックスを使用します。ただし、実際には背景クラスは 150 個のクラスに含まれていないため、`do_reduce_labels=True`を設定してすべてのラベルから 1 つを引く必要があります。ゼロインデックスは `255` に置き換えられるため、SegFormer の損失関数によって無視されます。
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoImageProcessor
|
||||
|
||||
>>> checkpoint = "nvidia/mit-b0"
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
|
||||
```
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
モデルを過学習に対してより堅牢にするために、画像データセットにいくつかのデータ拡張を適用するのが一般的です。このガイドでは、[torchvision](https://pytorch.org) の [`ColorJitter`](https://pytorch.org/vision/stable/generated/torchvision.transforms.ColorJitter.html) 関数を使用します。 /vision/stable/index.html) を使用して画像の色のプロパティをランダムに変更しますが、任意の画像ライブラリを使用することもできます。
|
||||
|
||||
```py
|
||||
>>> from torchvision.transforms import ColorJitter
|
||||
|
||||
>>> jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
|
||||
```
|
||||
|
||||
次に、モデルの画像と注釈を準備するための 2 つの前処理関数を作成します。これらの関数は、画像を`pixel_values`に変換し、注釈を`labels`に変換します。トレーニング セットの場合、画像を画像プロセッサに提供する前に`jitter`が適用されます。テスト セットの場合、テスト中にデータ拡張が適用されないため、画像プロセッサは`images`を切り取って正規化し、`labels` のみを切り取ります。
|
||||
|
||||
```py
|
||||
>>> def train_transforms(example_batch):
|
||||
... images = [jitter(x) for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
|
||||
|
||||
>>> def val_transforms(example_batch):
|
||||
... images = [x for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
```
|
||||
|
||||
データセット全体に`jitter`を適用するには、🤗 Datasets [`~datasets.Dataset.set_transform`] 関数を使用します。変換はオンザフライで適用されるため、高速で消費するディスク容量が少なくなります。
|
||||
|
||||
```py
|
||||
>>> train_ds.set_transform(train_transforms)
|
||||
>>> test_ds.set_transform(val_transforms)
|
||||
```
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
モデルを過学習に対してより堅牢にするために、画像データセットにいくつかのデータ拡張を適用するのが一般的です。
|
||||
このガイドでは、[`tf.image`](https://www.tensorflow.org/api_docs/python/tf/image) を使用して画像の色のプロパティをランダムに変更しますが、任意のプロパティを使用することもできます。画像
|
||||
好きな図書館。
|
||||
2 つの別々の変換関数を定義します。
|
||||
- 画像拡張を含むトレーニング データ変換
|
||||
- 🤗 Transformers のコンピューター ビジョン モデルはチャネル優先のレイアウトを想定しているため、画像を転置するだけの検証データ変換
|
||||
|
||||
```py
|
||||
>>> import tensorflow as tf
|
||||
|
||||
|
||||
>>> def aug_transforms(image):
|
||||
... image = tf.keras.utils.img_to_array(image)
|
||||
... image = tf.image.random_brightness(image, 0.25)
|
||||
... image = tf.image.random_contrast(image, 0.5, 2.0)
|
||||
... image = tf.image.random_saturation(image, 0.75, 1.25)
|
||||
... image = tf.image.random_hue(image, 0.1)
|
||||
... image = tf.transpose(image, (2, 0, 1))
|
||||
... return image
|
||||
|
||||
|
||||
>>> def transforms(image):
|
||||
... image = tf.keras.utils.img_to_array(image)
|
||||
... image = tf.transpose(image, (2, 0, 1))
|
||||
... return image
|
||||
```
|
||||
|
||||
次に、モデルの画像と注釈のバッチを準備する 2 つの前処理関数を作成します。これらの機能が適用されます
|
||||
画像変換を行い、以前にロードされた `image_processor` を使用して画像を `pixel_values` に変換し、
|
||||
`labels`への注釈。 `ImageProcessor` は、画像のサイズ変更と正規化も処理します。
|
||||
|
||||
```py
|
||||
>>> def train_transforms(example_batch):
|
||||
... images = [aug_transforms(x.convert("RGB")) for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
|
||||
|
||||
>>> def val_transforms(example_batch):
|
||||
... images = [transforms(x.convert("RGB")) for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
```
|
||||
|
||||
データセット全体に前処理変換を適用するには、🤗 Datasets [`~datasets.Dataset.set_transform`] 関数を使用します。
|
||||
変換はオンザフライで適用されるため、高速で消費するディスク容量が少なくなります。
|
||||
|
||||
```py
|
||||
>>> train_ds.set_transform(train_transforms)
|
||||
>>> test_ds.set_transform(val_transforms)
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## Evaluate
|
||||
|
||||
トレーニング中にメトリクスを含めると、多くの場合、モデルのパフォーマンスを評価するのに役立ちます。 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) ライブラリを使用して、評価メソッドをすばやくロードできます。このタスクでは、[Mean Intersection over Union](https://huggingface.co/spaces/evaluate-metric/accuracy) (IoU) メトリックをロードします (🤗 Evaluate [クイック ツアー](https://huggingface.co) を参照してください) /docs/evaluate/a_quick_tour) を参照して、メトリクスをロードして計算する方法の詳細を確認してください)。
|
||||
|
||||
```py
|
||||
>>> import evaluate
|
||||
|
||||
>>> metric = evaluate.load("mean_iou")
|
||||
```
|
||||
|
||||
次に、メトリクスを [`~evaluate.EvaluationModule.compute`] する関数を作成します。予測を次のように変換する必要があります
|
||||
最初にロジットを作成し、次に [`~evaluate.EvaluationModule.compute`] を呼び出す前にラベルのサイズに一致するように再形成します。
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
```py
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
>>> from torch import nn
|
||||
|
||||
>>> def compute_metrics(eval_pred):
|
||||
... with torch.no_grad():
|
||||
... logits, labels = eval_pred
|
||||
... logits_tensor = torch.from_numpy(logits)
|
||||
... logits_tensor = nn.functional.interpolate(
|
||||
... logits_tensor,
|
||||
... size=labels.shape[-2:],
|
||||
... mode="bilinear",
|
||||
... align_corners=False,
|
||||
... ).argmax(dim=1)
|
||||
|
||||
... pred_labels = logits_tensor.detach().cpu().numpy()
|
||||
... metrics = metric.compute(
|
||||
... predictions=pred_labels,
|
||||
... references=labels,
|
||||
... num_labels=num_labels,
|
||||
... ignore_index=255,
|
||||
... reduce_labels=False,
|
||||
... )
|
||||
... for key, value in metrics.items():
|
||||
... if type(value) is np.ndarray:
|
||||
... metrics[key] = value.tolist()
|
||||
... return metrics
|
||||
```
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
```py
|
||||
>>> def compute_metrics(eval_pred):
|
||||
... logits, labels = eval_pred
|
||||
... logits = tf.transpose(logits, perm=[0, 2, 3, 1])
|
||||
... logits_resized = tf.image.resize(
|
||||
... logits,
|
||||
... size=tf.shape(labels)[1:],
|
||||
... method="bilinear",
|
||||
... )
|
||||
|
||||
... pred_labels = tf.argmax(logits_resized, axis=-1)
|
||||
... metrics = metric.compute(
|
||||
... predictions=pred_labels,
|
||||
... references=labels,
|
||||
... num_labels=num_labels,
|
||||
... ignore_index=-1,
|
||||
... reduce_labels=image_processor.do_reduce_labels,
|
||||
... )
|
||||
|
||||
... per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
|
||||
... per_category_iou = metrics.pop("per_category_iou").tolist()
|
||||
|
||||
... metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
|
||||
... metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
|
||||
... return {"val_" + k: v for k, v in metrics.items()}
|
||||
```
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
これで`compute_metrics`関数の準備が整いました。トレーニングをセットアップするときにこの関数に戻ります。
|
||||
|
||||
## Train
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
<Tip>
|
||||
|
||||
[`Trainer`] を使用したモデルの微調整に慣れていない場合は、[こちら](../training#finetune-with-trainer) の基本的なチュートリアルをご覧ください。
|
||||
|
||||
|
||||
</Tip>
|
||||
|
||||
これでモデルのトレーニングを開始する準備が整いました。 [`AutoModelForSemanticSegmentation`] を使用して SegFormer をロードし、ラベル ID とラベル クラス間のマッピングをモデルに渡します。
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer
|
||||
|
||||
>>> model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
|
||||
```
|
||||
|
||||
この時点で残っている手順は次の 3 つだけです。
|
||||
|
||||
1. [`TrainingArguments`] でトレーニング ハイパーパラメータを定義します。 `image` 列が削除されるため、未使用の列を削除しないことが重要です。 `image` 列がないと、`pixel_values` を作成できません。この動作を防ぐには、`remove_unused_columns=False`を設定してください。他に必要なパラメータは、モデルの保存場所を指定する `output_dir` だけです。 `push_to_hub=True`を設定して、このモデルをハブにプッシュします (モデルをアップロードするには、Hugging Face にサインインする必要があります)。各エポックの終了時に、[`Trainer`] は IoU メトリックを評価し、トレーニング チェックポイントを保存します。
|
||||
2. トレーニング引数を、モデル、データセット、トークナイザー、データ照合器、および `compute_metrics` 関数とともに [`Trainer`] に渡します。
|
||||
3. [`~Trainer.train`] を呼び出してモデルを微調整します。
|
||||
|
||||
|
||||
```py
|
||||
>>> training_args = TrainingArguments(
|
||||
... output_dir="segformer-b0-scene-parse-150",
|
||||
... learning_rate=6e-5,
|
||||
... num_train_epochs=50,
|
||||
... per_device_train_batch_size=2,
|
||||
... per_device_eval_batch_size=2,
|
||||
... save_total_limit=3,
|
||||
... eval_strategy="steps",
|
||||
... save_strategy="steps",
|
||||
... save_steps=20,
|
||||
... eval_steps=20,
|
||||
... logging_steps=1,
|
||||
... eval_accumulation_steps=5,
|
||||
... remove_unused_columns=False,
|
||||
... push_to_hub=True,
|
||||
... )
|
||||
|
||||
>>> trainer = Trainer(
|
||||
... model=model,
|
||||
... args=training_args,
|
||||
... train_dataset=train_ds,
|
||||
... eval_dataset=test_ds,
|
||||
... compute_metrics=compute_metrics,
|
||||
... )
|
||||
|
||||
>>> trainer.train()
|
||||
```
|
||||
|
||||
トレーニングが完了したら、 [`~transformers.Trainer.push_to_hub`] メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。
|
||||
|
||||
```py
|
||||
>>> trainer.push_to_hub()
|
||||
```
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
<Tip>
|
||||
|
||||
Keras を使用したモデルの微調整に慣れていない場合は、まず [基本チュートリアル](./training#train-a-tensorflow-model-with-keras) を確認してください。
|
||||
|
||||
</Tip>
|
||||
|
||||
TensorFlow でモデルを微調整するには、次の手順に従います。
|
||||
1. トレーニングのハイパーパラメータを定義し、オプティマイザーと学習率スケジュールを設定します。
|
||||
2. 事前トレーニングされたモデルをインスタンス化します。
|
||||
3. 🤗 データセットを `tf.data.Dataset` に変換します。
|
||||
4. モデルをコンパイルします。
|
||||
5. コールバックを追加してメトリクスを計算し、モデルを 🤗 Hub にアップロードします
|
||||
6. `fit()` メソッドを使用してトレーニングを実行します。
|
||||
|
||||
まず、ハイパーパラメーター、オプティマイザー、学習率スケジュールを定義します。
|
||||
|
||||
|
||||
```py
|
||||
>>> from transformers import create_optimizer
|
||||
|
||||
>>> batch_size = 2
|
||||
>>> num_epochs = 50
|
||||
>>> num_train_steps = len(train_ds) * num_epochs
|
||||
>>> learning_rate = 6e-5
|
||||
>>> weight_decay_rate = 0.01
|
||||
|
||||
>>> optimizer, lr_schedule = create_optimizer(
|
||||
... init_lr=learning_rate,
|
||||
... num_train_steps=num_train_steps,
|
||||
... weight_decay_rate=weight_decay_rate,
|
||||
... num_warmup_steps=0,
|
||||
... )
|
||||
```
|
||||
|
||||
次に、ラベル マッピングとともに [`TFAutoModelForSemanticSegmentation`] を使用して SegFormer をロードし、それをコンパイルします。
|
||||
オプティマイザ。 Transformers モデルにはすべてデフォルトのタスク関連の損失関数があるため、次の場合を除き、損失関数を指定する必要はないことに注意してください。
|
||||
|
||||
```py
|
||||
>>> from transformers import TFAutoModelForSemanticSegmentation
|
||||
|
||||
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
|
||||
... checkpoint,
|
||||
... id2label=id2label,
|
||||
... label2id=label2id,
|
||||
... )
|
||||
>>> model.compile(optimizer=optimizer) # No loss argument!
|
||||
```
|
||||
|
||||
[`~datasets.Dataset.to_tf_dataset`] と [`DefaultDataCollator`] を使用して、データセットを `tf.data.Dataset` 形式に変換します。
|
||||
|
||||
```py
|
||||
>>> from transformers import DefaultDataCollator
|
||||
|
||||
>>> data_collator = DefaultDataCollator(return_tensors="tf")
|
||||
|
||||
>>> tf_train_dataset = train_ds.to_tf_dataset(
|
||||
... columns=["pixel_values", "label"],
|
||||
... shuffle=True,
|
||||
... batch_size=batch_size,
|
||||
... collate_fn=data_collator,
|
||||
... )
|
||||
|
||||
>>> tf_eval_dataset = test_ds.to_tf_dataset(
|
||||
... columns=["pixel_values", "label"],
|
||||
... shuffle=True,
|
||||
... batch_size=batch_size,
|
||||
... collate_fn=data_collator,
|
||||
... )
|
||||
```
|
||||
|
||||
予測から精度を計算し、モデルを 🤗 ハブにプッシュするには、[Keras callbacks](../main_classes/keras_callbacks) を使用します。
|
||||
`compute_metrics` 関数を [`KerasMetricCallback`] に渡します。
|
||||
そして [`PushToHubCallback`] を使用してモデルをアップロードします。
|
||||
|
||||
```py
|
||||
>>> from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
|
||||
|
||||
>>> metric_callback = KerasMetricCallback(
|
||||
... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
|
||||
... )
|
||||
|
||||
>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)
|
||||
|
||||
>>> callbacks = [metric_callback, push_to_hub_callback]
|
||||
```
|
||||
|
||||
ついに、モデルをトレーニングする準備が整いました。`fit()`トレーニングおよび検証データセット、エポック数、
|
||||
モデルを微調整するためのコールバック:
|
||||
|
||||
```py
|
||||
>>> model.fit(
|
||||
... tf_train_dataset,
|
||||
... validation_data=tf_eval_dataset,
|
||||
... callbacks=callbacks,
|
||||
... epochs=num_epochs,
|
||||
... )
|
||||
```
|
||||
|
||||
おめでとう!モデルを微調整し、🤗 Hub で共有しました。これで推論に使用できるようになりました。
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
## Inference
|
||||
|
||||
モデルを微調整したので、それを推論に使用できるようになりました。
|
||||
|
||||
推論のために画像をロードします。
|
||||
|
||||
```py
|
||||
>>> image = ds[0]["image"]
|
||||
>>> image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-image.png" alt="Image of bedroom"/>
|
||||
</div>
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
推論用に微調整されたモデルを試す最も簡単な方法は、それを [`pipeline`] で使用することです。モデルを使用して画像セグメンテーション用の `pipeline` をインスタンス化し、それに画像を渡します。
|
||||
|
||||
```py
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> segmenter = pipeline("image-segmentation", model="my_awesome_seg_model")
|
||||
>>> segmenter(image)
|
||||
[{'score': None,
|
||||
'label': 'wall',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062690>},
|
||||
{'score': None,
|
||||
'label': 'sky',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A50>},
|
||||
{'score': None,
|
||||
'label': 'floor',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062B50>},
|
||||
{'score': None,
|
||||
'label': 'ceiling',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A10>},
|
||||
{'score': None,
|
||||
'label': 'bed ',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E90>},
|
||||
{'score': None,
|
||||
'label': 'windowpane',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062390>},
|
||||
{'score': None,
|
||||
'label': 'cabinet',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062550>},
|
||||
{'score': None,
|
||||
'label': 'chair',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062D90>},
|
||||
{'score': None,
|
||||
'label': 'armchair',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E10>}]
|
||||
```
|
||||
|
||||
必要に応じて、`pipeline` の結果を手動で複製することもできます。画像プロセッサで画像を処理し、`pixel_values`を GPU に配置します。
|
||||
|
||||
```py
|
||||
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use GPU if available, otherwise use a CPU
|
||||
>>> encoding = image_processor(image, return_tensors="pt")
|
||||
>>> pixel_values = encoding.pixel_values.to(device)
|
||||
```
|
||||
|
||||
入力をモデルに渡し、「logits」を返します。
|
||||
|
||||
```py
|
||||
>>> outputs = model(pixel_values=pixel_values)
|
||||
>>> logits = outputs.logits.cpu()
|
||||
```
|
||||
|
||||
次に、ロジットを元の画像サイズに再スケールします。
|
||||
|
||||
|
||||
```py
|
||||
>>> upsampled_logits = nn.functional.interpolate(
|
||||
... logits,
|
||||
... size=image.size[::-1],
|
||||
... mode="bilinear",
|
||||
... align_corners=False,
|
||||
... )
|
||||
|
||||
>>> pred_seg = upsampled_logits.argmax(dim=1)[0]
|
||||
```
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
画像プロセッサをロードして画像を前処理し、入力を TensorFlow テンソルとして返します。
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoImageProcessor
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("MariaK/scene_segmentation")
|
||||
>>> inputs = image_processor(image, return_tensors="tf")
|
||||
```
|
||||
|
||||
入力をモデルに渡し、`logits`を返します。
|
||||
|
||||
```py
|
||||
>>> from transformers import TFAutoModelForSemanticSegmentation
|
||||
|
||||
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("MariaK/scene_segmentation")
|
||||
>>> logits = model(**inputs).logits
|
||||
```
|
||||
|
||||
次に、ロジットを元の画像サイズに再スケールし、クラス次元に argmax を適用します。
|
||||
|
||||
```py
|
||||
>>> logits = tf.transpose(logits, [0, 2, 3, 1])
|
||||
|
||||
>>> upsampled_logits = tf.image.resize(
|
||||
... logits,
|
||||
... # We reverse the shape of `image` because `image.size` returns width and height.
|
||||
... image.size[::-1],
|
||||
... )
|
||||
|
||||
>>> pred_seg = tf.math.argmax(upsampled_logits, axis=-1)[0]
|
||||
```
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
結果を視覚化するには、[データセット カラー パレット](https://github.com/tensorflow/models/blob/3f1ca33afe3c1631b733ea7e40c294273b9e406d/research/deeplab/utils/get_dataset_colormap.py#L51) を、それぞれをマップする `ade_palette()` としてロードします。クラスを RGB 値に変換します。次に、画像と予測されたセグメンテーション マップを組み合わせてプロットできます。
|
||||
|
||||
```py
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> import numpy as np
|
||||
|
||||
>>> color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
|
||||
>>> palette = np.array(ade_palette())
|
||||
>>> for label, color in enumerate(palette):
|
||||
... color_seg[pred_seg == label, :] = color
|
||||
>>> color_seg = color_seg[..., ::-1] # convert to BGR
|
||||
|
||||
>>> img = np.array(image) * 0.5 + color_seg * 0.5 # plot the image with the segmentation map
|
||||
>>> img = img.astype(np.uint8)
|
||||
|
||||
>>> plt.figure(figsize=(15, 10))
|
||||
>>> plt.imshow(img)
|
||||
>>> plt.show()
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-preds.png" alt="Image of bedroom overlaid with segmentation map"/>
|
||||
</div>
|
||||
@ -221,7 +221,7 @@ Transformerは最初に機械翻訳のために設計され、それ以降、ほ
|
||||
|
||||
事前訓練済みモデルをテキスト分類に使用するには、ベースのBERTモデルの上にシーケンス分類ヘッドを追加します。シーケンス分類ヘッドは最終的な隠れた状態を受け入れ、それらをロジットに変換するための線形層です。クロスエントロピー損失は、ロジットとターゲット間で最も可能性の高いラベルを見つけるために計算されます。
|
||||
|
||||
テキスト分類を試してみる準備はできましたか?DistilBERTを微調整し、推論に使用する方法を学ぶために、完全な[テキスト分類ガイド](tasks/sequence_classification)をチェックしてみてください!
|
||||
テキスト分類を試してみる準備はできましたか?DistilBERTを微調整し、推論に使用する方法を学ぶために、完全な[テキスト分類ガイド(英語版)](../en/tasks/sequence_classification)をチェックしてみてください!
|
||||
|
||||
### Token classification
|
||||
|
||||
|
||||
@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""":
|
||||
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
||||
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
||||
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
||||
4
examples/metrics-monitoring/README.md
Normal file
4
examples/metrics-monitoring/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
# Metrics Monitoring
|
||||
|
||||
## Continuous Batching Metrics in Transformers
|
||||
|
||||
974
examples/metrics-monitoring/continuous-batching-dashboard.json
Normal file
974
examples/metrics-monitoring/continuous-batching-dashboard.json
Normal file
@ -0,0 +1,974 @@
|
||||
{
|
||||
"annotations": {
|
||||
"list": [
|
||||
{
|
||||
"builtIn": 1,
|
||||
"datasource": {
|
||||
"type": "grafana",
|
||||
"uid": "-- Grafana --"
|
||||
},
|
||||
"enable": true,
|
||||
"hide": true,
|
||||
"iconColor": "rgba(0, 211, 255, 1)",
|
||||
"name": "Annotations & Alerts",
|
||||
"target": {
|
||||
"limit": 100,
|
||||
"matchAny": false,
|
||||
"tags": [],
|
||||
"type": "dashboard"
|
||||
},
|
||||
"type": "dashboard"
|
||||
}
|
||||
]
|
||||
},
|
||||
"editable": true,
|
||||
"fiscalYearStartMonth": 0,
|
||||
"graphTooltip": 0,
|
||||
"id": 2,
|
||||
"links": [],
|
||||
"panels": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"description": "Memory usage of the PagedAttentionCache",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"mappings": [],
|
||||
"max": 10737418240,
|
||||
"min": 0,
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green"
|
||||
},
|
||||
{
|
||||
"color": "yellow",
|
||||
"value": 5368709120
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 8589934592
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": "bytes"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 6,
|
||||
"x": 0,
|
||||
"y": 0
|
||||
},
|
||||
"id": 2,
|
||||
"options": {
|
||||
"minVizHeight": 75,
|
||||
"minVizWidth": 75,
|
||||
"orientation": "auto",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"showThresholdLabels": false,
|
||||
"showThresholdMarkers": true,
|
||||
"sizing": "auto"
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "kv_cache_memory_bytes",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "KV Cache Memory Usage",
|
||||
"transparent": true,
|
||||
"type": "gauge"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "dark-blue"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 6,
|
||||
"x": 6,
|
||||
"y": 0
|
||||
},
|
||||
"id": 13,
|
||||
"options": {
|
||||
"colorMode": "value",
|
||||
"graphMode": "area",
|
||||
"justifyMode": "auto",
|
||||
"orientation": "auto",
|
||||
"percentChangeColorMode": "standard",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"showPercentChange": false,
|
||||
"textMode": "auto",
|
||||
"wideLayout": true
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "active_requests_count",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "Active Requests",
|
||||
"transparent": true,
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "dark-orange"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 6,
|
||||
"x": 12,
|
||||
"y": 0
|
||||
},
|
||||
"id": 14,
|
||||
"options": {
|
||||
"colorMode": "value",
|
||||
"graphMode": "area",
|
||||
"justifyMode": "auto",
|
||||
"orientation": "auto",
|
||||
"percentChangeColorMode": "standard",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"showPercentChange": false,
|
||||
"textMode": "auto",
|
||||
"wideLayout": true
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "waiting_requests_count",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "Waiting Requests",
|
||||
"transparent": true,
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"description": "Ratio of decode tokens to prefill tokens in a batch",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "blue"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 6,
|
||||
"x": 18,
|
||||
"y": 0
|
||||
},
|
||||
"id": 6,
|
||||
"options": {
|
||||
"colorMode": "value",
|
||||
"graphMode": "none",
|
||||
"justifyMode": "auto",
|
||||
"orientation": "auto",
|
||||
"percentChangeColorMode": "standard",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"showPercentChange": false,
|
||||
"textMode": "auto",
|
||||
"wideLayout": true
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "decode_prefill_ratio",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "Decode/Prefill Ratio",
|
||||
"transparent": true,
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"barWidthFactor": 0.6,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 1,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green"
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 8
|
||||
},
|
||||
"id": 10,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"hideZeros": false,
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"editorMode": "code",
|
||||
"expr": "rate(decode_tokens_processed_total[$__rate_interval])",
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Decode tokens throupught tok/s",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"barWidthFactor": 0.6,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 1,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green"
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 8
|
||||
},
|
||||
"id": 11,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"hideZeros": false,
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"editorMode": "code",
|
||||
"expr": "rate(prefill_tokens_processed_total[$__rate_interval])",
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Prefill rate tok/s",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"barWidthFactor": 0.6,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 1,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green"
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 16
|
||||
},
|
||||
"id": 9,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"hideZeros": false,
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.95, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))",
|
||||
"legendFormat": "p95",
|
||||
"range": true,
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.99, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))",
|
||||
"hide": false,
|
||||
"instant": false,
|
||||
"legendFormat": "p99",
|
||||
"range": true,
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.5, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))",
|
||||
"hide": false,
|
||||
"instant": false,
|
||||
"legendFormat": "p50",
|
||||
"range": true,
|
||||
"refId": "C"
|
||||
}
|
||||
],
|
||||
"title": "Batch fill percentage percentiles",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"description": "KV Cache Memory Usage Over Time",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"barWidthFactor": 0.6,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 20,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green"
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": "bytes"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 16
|
||||
},
|
||||
"id": 4,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"hideZeros": false,
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "kv_cache_memory_bytes",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"legendFormat": "Used memory",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "kv_cache_free_memory_bytes",
|
||||
"fullMetaSearch": false,
|
||||
"hide": false,
|
||||
"includeNullMetadata": true,
|
||||
"instant": false,
|
||||
"legendFormat": "free memory",
|
||||
"range": true,
|
||||
"refId": "B",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "KV Cache Memory Usage Over Time",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "thresholds"
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green"
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": "ms"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 24
|
||||
},
|
||||
"id": 8,
|
||||
"options": {
|
||||
"displayMode": "gradient",
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": false
|
||||
},
|
||||
"maxVizHeight": 300,
|
||||
"minVizHeight": 10,
|
||||
"minVizWidth": 0,
|
||||
"namePlacement": "auto",
|
||||
"orientation": "auto",
|
||||
"reduceOptions": {
|
||||
"calcs": [
|
||||
"lastNotNull"
|
||||
],
|
||||
"fields": "",
|
||||
"values": false
|
||||
},
|
||||
"showUnfilled": true,
|
||||
"sizing": "auto",
|
||||
"valueMode": "color"
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "histogram_quantile(0.95, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"legendFormat": "p95",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "histogram_quantile(0.5, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))",
|
||||
"fullMetaSearch": false,
|
||||
"hide": false,
|
||||
"includeNullMetadata": true,
|
||||
"legendFormat": "p50",
|
||||
"range": true,
|
||||
"refId": "B",
|
||||
"useBackend": false
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "histogram_quantile(0.99, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))",
|
||||
"fullMetaSearch": false,
|
||||
"hide": false,
|
||||
"includeNullMetadata": false,
|
||||
"instant": false,
|
||||
"legendFormat": "p99",
|
||||
"range": true,
|
||||
"refId": "C",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "Time to First Token (TTFT)",
|
||||
"type": "bargauge"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"barWidthFactor": 0.6,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 1,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green"
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
},
|
||||
"unit": "ms"
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 24
|
||||
},
|
||||
"id": 12,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"hideZeros": false,
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"pluginVersion": "12.0.0",
|
||||
"targets": [
|
||||
{
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.5, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))",
|
||||
"legendFormat": "p50",
|
||||
"range": true,
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.95, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))",
|
||||
"hide": false,
|
||||
"instant": false,
|
||||
"legendFormat": "p95",
|
||||
"range": true,
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"editorMode": "code",
|
||||
"expr": "histogram_quantile(0.99, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))",
|
||||
"hide": false,
|
||||
"instant": false,
|
||||
"legendFormat": "p99",
|
||||
"range": true,
|
||||
"refId": "C"
|
||||
}
|
||||
],
|
||||
"title": "Request latency percentiles",
|
||||
"type": "timeseries"
|
||||
}
|
||||
],
|
||||
"preload": false,
|
||||
"refresh": "5s",
|
||||
"schemaVersion": 41,
|
||||
"tags": [],
|
||||
"templating": {
|
||||
"list": []
|
||||
},
|
||||
"time": {
|
||||
"from": "now-15m",
|
||||
"to": "now"
|
||||
},
|
||||
"timepicker": {},
|
||||
"timezone": "",
|
||||
"title": "Transformers Continuous Batching Metrics",
|
||||
"uid": "Lw6CTvVSz",
|
||||
"version": 5
|
||||
}
|
||||
55
examples/metrics-monitoring/docker-compose.yml
Normal file
55
examples/metrics-monitoring/docker-compose.yml
Normal file
@ -0,0 +1,55 @@
|
||||
services:
|
||||
memcached:
|
||||
image: memcached:1.6.29
|
||||
container_name: memcached
|
||||
ports:
|
||||
- "11211:11211"
|
||||
environment:
|
||||
- MEMCACHED_MAX_MEMORY=64m # Set the maximum memory usage
|
||||
- MEMCACHED_THREADS=4 # Number of threads to use
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
command:
|
||||
- "--config.file=/etc/prometheus/prometheus.yml"
|
||||
- --web.enable-otlp-receiver # Enable OTLP receiver
|
||||
- --web.enable-remote-write-receiver
|
||||
- --enable-feature=exemplar-storage
|
||||
- --enable-feature=native-histograms
|
||||
volumes:
|
||||
- ./prometheus.yml:/etc/prometheus/prometheus.yml
|
||||
ports:
|
||||
- "9090:9090"
|
||||
|
||||
tempo:
|
||||
image: grafana/tempo:latest
|
||||
command: [ "-config.file=/etc/tempo.yaml" ]
|
||||
volumes:
|
||||
- ./tempo.yaml:/etc/tempo.yaml
|
||||
ports:
|
||||
- "14268:14268" # jaeger ingest
|
||||
- "3200:3200" # tempo
|
||||
- "9095:9095" # tempo grpc
|
||||
- "4317:4317" # otlp grpc
|
||||
- "4318:4318" # otlp http
|
||||
- "9411:9411" # zipkin
|
||||
depends_on:
|
||||
- memcached
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
volumes:
|
||||
- ./continuous-batching-dashboard.json:/etc/grafana/provisioning/dashboards/continuous-batching-dashboard.json
|
||||
- ./grafana-dashboard.yaml:/etc/grafana/provisioning/dashboards/grafana-dashboard.yaml
|
||||
- ./grafana-datasources.yaml:/etc/grafana/provisioning/datasources/datasources.yaml
|
||||
environment:
|
||||
- GF_AUTH_ANONYMOUS_ENABLED=true
|
||||
- GF_AUTH_ANONYMOUS_ORG_ROLE=Admin
|
||||
- GF_AUTH_DISABLE_LOGIN_FORM=true
|
||||
- GF_FEATURE_TOGGLES_ENABLE=traceqlEditor metricsSummary
|
||||
- GF_INSTALL_PLUGINS=https://storage.googleapis.com/integration-artifacts/grafana-exploretraces-app/grafana-exploretraces-app-latest.zip;grafana-traces-app
|
||||
ports:
|
||||
- "3000:3000"
|
||||
depends_on:
|
||||
- prometheus
|
||||
- tempo
|
||||
11
examples/metrics-monitoring/grafana-dashboard.yaml
Normal file
11
examples/metrics-monitoring/grafana-dashboard.yaml
Normal file
@ -0,0 +1,11 @@
|
||||
apiVersion: 1
|
||||
|
||||
providers:
|
||||
- name: 'Transformers Dashboards'
|
||||
orgId: 1
|
||||
folder: 'Transformers'
|
||||
type: file
|
||||
disableDeletion: false
|
||||
editable: true
|
||||
options:
|
||||
path: /etc/grafana/provisioning/dashboards
|
||||
14
examples/metrics-monitoring/grafana-datasources.yaml
Normal file
14
examples/metrics-monitoring/grafana-datasources.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
apiVersion: 1
|
||||
|
||||
datasources:
|
||||
- name: Prometheus
|
||||
type: prometheus
|
||||
access: proxy
|
||||
url: http://prometheus:9090
|
||||
isDefault: true
|
||||
|
||||
- name: Tempo
|
||||
type: tempo
|
||||
access: proxy
|
||||
url: http://tempo:3200
|
||||
uid: tempo
|
||||
48
examples/metrics-monitoring/metrics_example.py
Normal file
48
examples/metrics-monitoring/metrics_example.py
Normal file
@ -0,0 +1,48 @@
|
||||
# Example usage of the trace and attach_tracer decorators
|
||||
|
||||
from transformers.utils.metrics import attach_tracer, traced
|
||||
|
||||
|
||||
@attach_tracer()
|
||||
class ExampleClass:
|
||||
def __init__(self, name):
|
||||
# The attach_tracer decorator has already created self.tracer for us
|
||||
self.name = name
|
||||
|
||||
@traced # This method will use the tracer from the class instance
|
||||
def process_data(self, data):
|
||||
# This method is traced and can use self.tracer
|
||||
return f"Processed {data} with {self.name}"
|
||||
|
||||
@traced(span_name="custom_operation") # With custom span name
|
||||
def special_operation(self, value):
|
||||
# Also traced, with a custom span name
|
||||
return value * 2
|
||||
|
||||
@traced(
|
||||
additional_attributes=[
|
||||
("name", "object.name", lambda x: x.upper()), # Using a transform function
|
||||
("name", "object.fixed_value", "static_value"), # Using a fixed value
|
||||
]
|
||||
)
|
||||
def operation_with_attributes(self):
|
||||
# This will add the specified attributes to the span
|
||||
return "Operation completed"
|
||||
|
||||
|
||||
# For functions without a class, the traced decorator still works
|
||||
@traced
|
||||
def standalone_function(arg1, arg2):
|
||||
# For functions, a tracer is created based on the module name
|
||||
return arg1 + arg2
|
||||
|
||||
|
||||
# Usage:
|
||||
if __name__ == "__main__":
|
||||
# With OpenTelemetry configured, these will produce traces
|
||||
example = ExampleClass("test_object")
|
||||
example.process_data("sample")
|
||||
example.special_operation(42)
|
||||
example.operation_with_attributes()
|
||||
|
||||
result = standalone_function(1, 2)
|
||||
3
examples/metrics-monitoring/prometheus.yml
Normal file
3
examples/metrics-monitoring/prometheus.yml
Normal file
@ -0,0 +1,3 @@
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
|
||||
90
examples/metrics-monitoring/tempo.yaml
Normal file
90
examples/metrics-monitoring/tempo.yaml
Normal file
@ -0,0 +1,90 @@
|
||||
stream_over_http_enabled: true
|
||||
server:
|
||||
http_listen_port: 3200
|
||||
log_level: info
|
||||
|
||||
|
||||
cache:
|
||||
background:
|
||||
writeback_goroutines: 5
|
||||
caches:
|
||||
- roles:
|
||||
- frontend-search
|
||||
memcached:
|
||||
addresses: dns+memcached:11211
|
||||
|
||||
query_frontend:
|
||||
search:
|
||||
duration_slo: 5s
|
||||
throughput_bytes_slo: 1.073741824e+09
|
||||
metadata_slo:
|
||||
duration_slo: 5s
|
||||
throughput_bytes_slo: 1.073741824e+09
|
||||
trace_by_id:
|
||||
duration_slo: 100ms
|
||||
metrics:
|
||||
max_duration: 200h # maximum duration of a metrics query, increase for local setups
|
||||
query_backend_after: 5m
|
||||
duration_slo: 5s
|
||||
throughput_bytes_slo: 1.073741824e+09
|
||||
|
||||
distributor:
|
||||
receivers: # this configuration will listen on all ports and protocols that tempo is capable of.
|
||||
jaeger: # the receives all come from the OpenTelemetry collector. more configuration information can
|
||||
protocols: # be found there: https://github.com/open-telemetry/opentelemetry-collector/tree/main/receiver
|
||||
thrift_http: #
|
||||
endpoint: "tempo:14268" # for a production deployment you should only enable the receivers you need!
|
||||
grpc:
|
||||
endpoint: "tempo:14250"
|
||||
thrift_binary:
|
||||
endpoint: "tempo:6832"
|
||||
thrift_compact:
|
||||
endpoint: "tempo:6831"
|
||||
zipkin:
|
||||
endpoint: "tempo:9411"
|
||||
otlp:
|
||||
protocols:
|
||||
grpc:
|
||||
endpoint: "tempo:4317"
|
||||
http:
|
||||
endpoint: "tempo:4318"
|
||||
opencensus:
|
||||
endpoint: "tempo:55678"
|
||||
|
||||
ingester:
|
||||
max_block_duration: 5m # cut the headblock when this much time passes. this is being set for demo purposes and should probably be left alone normally
|
||||
|
||||
compactor:
|
||||
compaction:
|
||||
block_retention: 720h # overall Tempo trace retention. set for demo purposes
|
||||
|
||||
metrics_generator:
|
||||
registry:
|
||||
external_labels:
|
||||
source: tempo
|
||||
cluster: docker-compose
|
||||
storage:
|
||||
path: /var/tempo/generator/wal
|
||||
remote_write:
|
||||
- url: http://prometheus:9090/api/v1/write
|
||||
send_exemplars: true
|
||||
traces_storage:
|
||||
path: /var/tempo/generator/traces
|
||||
processor:
|
||||
local_blocks:
|
||||
filter_server_spans: false
|
||||
flush_to_storage: true
|
||||
|
||||
storage:
|
||||
trace:
|
||||
backend: local # backend configuration to use
|
||||
wal:
|
||||
path: /var/tempo/wal # where to store the wal locally
|
||||
local:
|
||||
path: /var/tempo/blocks
|
||||
|
||||
overrides:
|
||||
defaults:
|
||||
metrics_generator:
|
||||
processors: [service-graphs, span-metrics, local-blocks] # enables metrics generator
|
||||
generate_native_histograms: both
|
||||
@ -1,3 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""":
|
||||
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
||||
109
examples/pytorch/continuous_batching.py
Normal file
109
examples/pytorch/continuous_batching.py
Normal file
@ -0,0 +1,109 @@
|
||||
import time
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-3b-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=512,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
num_blocks=2048,
|
||||
block_size=128,
|
||||
do_sample=True,
|
||||
max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
|
||||
scheduler="prefill_first",
|
||||
)
|
||||
|
||||
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
|
||||
|
||||
# --- Example 1: Simple Version using generate_batch ---
|
||||
print("--- Running CB Generation Example ---")
|
||||
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["question"])
|
||||
|
||||
|
||||
tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
|
||||
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
|
||||
|
||||
start_time_simple = time.time()
|
||||
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True)
|
||||
batch_outputs = model.generate_batch(
|
||||
inputs=simple_batch_inputs,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
end_time_simple = time.time()
|
||||
|
||||
for request in batch_outputs:
|
||||
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
|
||||
try:
|
||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
|
||||
except Exception as e:
|
||||
print(f"Decoding failed for request {request}: {e}")
|
||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
|
||||
if len(output_text) > 0:
|
||||
print("-" * 20)
|
||||
print(f"{request} Input: {input_text}")
|
||||
print(f"{request} Output: {output_text}")
|
||||
else:
|
||||
print("", end="\r\r\r\r")
|
||||
print("-" * 20)
|
||||
print("--- Finished CB Generation Example ---\n\n")
|
||||
|
||||
|
||||
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
|
||||
|
||||
|
||||
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
|
||||
|
||||
# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512)
|
||||
# simple_batch_inputs = list(tokenized_test_prompts["input_ids"])
|
||||
|
||||
# def tokenize_function(examples):
|
||||
# # Truncate to avoid overly long prompts exceeding max context length
|
||||
# return tokenizer(examples["question"], padding=True, truncation=True, max_length=512)
|
||||
|
||||
|
||||
# tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
|
||||
# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
|
||||
|
||||
|
||||
# model.config.attn_implementation = "sdpa"
|
||||
# start_time_simple = time.time()
|
||||
# batch_size = 64
|
||||
# full_outputs = []
|
||||
# from tqdm import tqdm
|
||||
|
||||
# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)):
|
||||
# outputs = model.generate(
|
||||
# torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device),
|
||||
# generation_config=GenerationConfig(
|
||||
# max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
|
||||
# ),
|
||||
# )
|
||||
# full_outputs.extend(outputs.tolist())
|
||||
|
||||
# end_time_simple = time.time()
|
||||
# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds")
|
||||
|
||||
# print("\nResults from simple generate_batch:")
|
||||
# for i, request in enumerate(full_outputs):
|
||||
# output_text = tokenizer.decode(request, skip_special_tokens=False)
|
||||
# print("-" * 20)
|
||||
# print(f" Output: {output_text}")
|
||||
# print("-" * 20)
|
||||
# print("--- Finished Simple Batch Generation Example ---\n\n")
|
||||
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
||||
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
||||
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
||||
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
||||
14
setup.py
14
setup.py
@ -117,7 +117,7 @@ _deps = [
|
||||
"GitPython<3.1.19",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"hf_xet",
|
||||
"huggingface-hub>=0.30.0,<1.0",
|
||||
"huggingface-hub==v0.32.0.rc1",
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
"isort>=5.5.4",
|
||||
@ -125,7 +125,7 @@ _deps = [
|
||||
"jaxlib>=0.4.1,<=0.4.13",
|
||||
"jieba",
|
||||
"jinja2>=3.1.0",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm",
|
||||
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
|
||||
"keras>2.9,<2.16",
|
||||
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
|
||||
@ -201,6 +201,9 @@ _deps = [
|
||||
"pytest-rich",
|
||||
"libcst",
|
||||
"rich",
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-exporter-otlp",
|
||||
"opentelemetry-sdk",
|
||||
]
|
||||
|
||||
|
||||
@ -315,7 +318,7 @@ extras["audio"] = deps_list(
|
||||
"librosa",
|
||||
"pyctcdecode",
|
||||
"phonemizer",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm",
|
||||
)
|
||||
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
|
||||
extras["speech"] = deps_list("torchaudio") + extras["audio"]
|
||||
@ -435,6 +438,9 @@ extras["torchhub"] = deps_list(
|
||||
|
||||
extras["benchmark"] = deps_list("optimum-benchmark")
|
||||
|
||||
# OpenTelemetry dependencies for metrics collection in continuous batching
|
||||
extras["open-telemetry"] = deps_list("opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk")
|
||||
|
||||
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
|
||||
install_requires = [
|
||||
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
|
||||
@ -451,7 +457,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.52.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.53.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user