More informative error when using Transformers backend (#16988)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-04-24 03:54:03 +01:00
committed by GitHub
parent ed50f46641
commit 2c8ed8ee48
2 changed files with 38 additions and 34 deletions

View File

@ -40,33 +40,37 @@ You can force the use of `TransformersForCausalLM` by setting `model_impl="trans
vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
:::
#### Supported features
#### Custom models
The Transformers modeling backend explicitly supports the following features:
If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM!
- <project:#quantization-index> (except GGUF)
- <project:#lora-adapter>
- <project:#distributed-serving>
For a model to be compatible with the Transformers backend for vLLM it must:
#### Remote Code
- be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)):
* The model directory must have the correct structure (e.g. `config.json` is present).
* `config.json` must contain `auto_map.AutoModel`.
- be a Transformers backend for vLLM compatible model (see <project:#writing-custom-models>):
* Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`).
If your model is neither supported natively by vLLM or Transformers, you can still run it in vLLM!
If the compatible model is:
Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for <project:#offline-inference> or `--trust-remode-code` for the <project:#openai-compatible-server>.
- in a local directory, simply pass directory path to `model=<MODEL_DIR>` for <project:#offline-inference> or `vllm serve <MODEL_DIR>` for the <project:#openai-compatible-server>.
:::{tip}
If you have not yet created your custom model, you can follow this guide on [customising models in Transformers](https://huggingface.co/docs/transformers/en/custom_models).
:::
This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM!
```python
from vllm import LLM
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
```
(writing-custom-models)=
#### Writing custom models
This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)).
To make your model compatible with the Transformers backend, it needs:
1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
3. `MyModel` must contain `_supports_attention_backend = True`.
```{code-block} python
:caption: modeling_my_model.py
@ -75,7 +79,7 @@ from torch import nn
class MyAttention(nn.Module):
def forward(self, hidden_states, **kwargs): # <- kwargs are required
def forward(self, hidden_states, **kwargs):
...
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
@ -91,11 +95,11 @@ class MyModel(PreTrainedModel):
_supports_attention_backend = True
```
Here is what happens in the background:
Here is what happens in the background when this model is loaded:
1. The config is loaded
2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
1. The config is loaded.
2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
That's it!

View File

@ -30,15 +30,6 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)
def is_transformers_impl_compatible(
arch: str,
module: Optional["transformers.PreTrainedModel"] = None) -> bool:
mod = module or getattr(transformers, arch, None)
if mod is None:
return False
return mod.is_backend_compatible()
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
@ -61,17 +52,26 @@ def resolve_transformers_arch(model_config: ModelConfig,
revision=model_config.revision)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
custom_model_module = auto_modules.get("AutoModel")
model_module = getattr(transformers, arch, None)
if model_module is None:
if "AutoModel" not in auto_map:
raise ValueError(
f"Cannot find model module. '{arch}' is not a registered "
"model in the Transformers library (only relevant if the "
"model is meant to be in Transformers) and 'AutoModel' is "
"not present in the model config's 'auto_map' (relevant "
"if the model is custom).")
model_module = auto_modules["AutoModel"]
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_model_module):
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_model_module):
if not model_module.is_backend_compatible():
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting "