mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
More informative error when using Transformers backend (#16988)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@ -40,33 +40,37 @@ You can force the use of `TransformersForCausalLM` by setting `model_impl="trans
|
||||
vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
|
||||
:::
|
||||
|
||||
#### Supported features
|
||||
#### Custom models
|
||||
|
||||
The Transformers modeling backend explicitly supports the following features:
|
||||
If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM!
|
||||
|
||||
- <project:#quantization-index> (except GGUF)
|
||||
- <project:#lora-adapter>
|
||||
- <project:#distributed-serving>
|
||||
For a model to be compatible with the Transformers backend for vLLM it must:
|
||||
|
||||
#### Remote Code
|
||||
- be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)):
|
||||
* The model directory must have the correct structure (e.g. `config.json` is present).
|
||||
* `config.json` must contain `auto_map.AutoModel`.
|
||||
- be a Transformers backend for vLLM compatible model (see <project:#writing-custom-models>):
|
||||
* Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`).
|
||||
|
||||
If your model is neither supported natively by vLLM or Transformers, you can still run it in vLLM!
|
||||
If the compatible model is:
|
||||
|
||||
Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
|
||||
Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
|
||||
- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for <project:#offline-inference> or `--trust-remode-code` for the <project:#openai-compatible-server>.
|
||||
- in a local directory, simply pass directory path to `model=<MODEL_DIR>` for <project:#offline-inference> or `vllm serve <MODEL_DIR>` for the <project:#openai-compatible-server>.
|
||||
|
||||
:::{tip}
|
||||
If you have not yet created your custom model, you can follow this guide on [customising models in Transformers](https://huggingface.co/docs/transformers/en/custom_models).
|
||||
:::
|
||||
This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM!
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
|
||||
llm.apply_model(lambda model: print(model.__class__))
|
||||
```
|
||||
(writing-custom-models)=
|
||||
|
||||
#### Writing custom models
|
||||
|
||||
This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)).
|
||||
|
||||
To make your model compatible with the Transformers backend, it needs:
|
||||
|
||||
1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
|
||||
2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
|
||||
3. `MyModel` must contain `_supports_attention_backend = True`.
|
||||
|
||||
```{code-block} python
|
||||
:caption: modeling_my_model.py
|
||||
|
||||
@ -75,7 +79,7 @@ from torch import nn
|
||||
|
||||
class MyAttention(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, **kwargs): # <- kwargs are required
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
...
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
attn_output, attn_weights = attention_interface(
|
||||
@ -91,11 +95,11 @@ class MyModel(PreTrainedModel):
|
||||
_supports_attention_backend = True
|
||||
```
|
||||
|
||||
Here is what happens in the background:
|
||||
Here is what happens in the background when this model is loaded:
|
||||
|
||||
1. The config is loaded
|
||||
2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
|
||||
3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
|
||||
1. The config is loaded.
|
||||
2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
|
||||
3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
|
||||
|
||||
That's it!
|
||||
|
||||
|
@ -30,15 +30,6 @@ def set_default_torch_dtype(dtype: torch.dtype):
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def is_transformers_impl_compatible(
|
||||
arch: str,
|
||||
module: Optional["transformers.PreTrainedModel"] = None) -> bool:
|
||||
mod = module or getattr(transformers, arch, None)
|
||||
if mod is None:
|
||||
return False
|
||||
return mod.is_backend_compatible()
|
||||
|
||||
|
||||
def resolve_transformers_arch(model_config: ModelConfig,
|
||||
architectures: list[str]):
|
||||
for i, arch in enumerate(architectures):
|
||||
@ -61,17 +52,26 @@ def resolve_transformers_arch(model_config: ModelConfig,
|
||||
revision=model_config.revision)
|
||||
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
|
||||
}
|
||||
custom_model_module = auto_modules.get("AutoModel")
|
||||
model_module = getattr(transformers, arch, None)
|
||||
if model_module is None:
|
||||
if "AutoModel" not in auto_map:
|
||||
raise ValueError(
|
||||
f"Cannot find model module. '{arch}' is not a registered "
|
||||
"model in the Transformers library (only relevant if the "
|
||||
"model is meant to be in Transformers) and 'AutoModel' is "
|
||||
"not present in the model config's 'auto_map' (relevant "
|
||||
"if the model is custom).")
|
||||
model_module = auto_modules["AutoModel"]
|
||||
# TODO(Isotr0py): Further clean up these raises.
|
||||
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
|
||||
if model_config.model_impl == ModelImpl.TRANSFORMERS:
|
||||
if not is_transformers_impl_compatible(arch, custom_model_module):
|
||||
if not model_module.is_backend_compatible():
|
||||
raise ValueError(
|
||||
f"The Transformers implementation of {arch} is not "
|
||||
"compatible with vLLM.")
|
||||
architectures[i] = "TransformersForCausalLM"
|
||||
if model_config.model_impl == ModelImpl.AUTO:
|
||||
if not is_transformers_impl_compatible(arch, custom_model_module):
|
||||
if not model_module.is_backend_compatible():
|
||||
raise ValueError(
|
||||
f"{arch} has no vLLM implementation and the Transformers "
|
||||
"implementation is not compatible with vLLM. Try setting "
|
||||
|
Reference in New Issue
Block a user