[Model] GritLM supports other attention backends (#18109)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-05-14 18:33:19 +08:00 (committed by GitHub)
parent 259127f8b8, commit d62a076e84
4 changed files with 85 additions and 108 deletions

Changed file 1 of 4: GritLM tests

@@ -11,7 +11,6 @@ from scipy.spatial.distance import cosine
 from vllm import LLM, SamplingParams
 from vllm.config import ModelConfig
-from vllm.utils import STR_BACKEND_ENV_VAR
 
 from ....utils import RemoteOpenAIServer
 
@@ -117,44 +116,37 @@ def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
     assert math.isclose(cosine_sim_q1_d1, 0.534, abs_tol=0.001)
 
 
-def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch,
-                                  vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
-
-        queries, q_instruction, documents, d_instruction = get_test_data()
-
-        with vllm_runner(
-                MODEL_NAME,
-                task="embed",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
-
-            d_rep = run_llm_encode(
-                llm,
-                documents,
-                d_instruction,
-            )
-            q_rep = run_llm_encode(
-                llm,
-                queries,
-                q_instruction,
-            )
-
-        validate_embed_output(q_rep, d_rep)
+def test_gritlm_offline_embedding(vllm_runner):
+    queries, q_instruction, documents, d_instruction = get_test_data()
+
+    with vllm_runner(
+            MODEL_NAME,
+            task="embed",
+            max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
+
+        d_rep = run_llm_encode(
+            llm,
+            documents,
+            d_instruction,
+        )
+        q_rep = run_llm_encode(
+            llm,
+            queries,
+            q_instruction,
+        )
+
+    validate_embed_output(q_rep, d_rep)
 
 
 @pytest.mark.asyncio
 async def test_gritlm_api_server_embedding():
     queries, q_instruction, documents, d_instruction = get_test_data()
 
-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {STR_BACKEND_ENV_VAR: "XFORMERS"}
 
-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_embedding = server.get_async_client()
 
         d_rep = await run_client_embeddings(
@@ -172,35 +164,28 @@ async def test_gritlm_api_server_embedding():
 def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner):
-    # GritLM embedding implementation is only supported by XFormers backend.
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "0")
-        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
-
-        input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
-
-        with vllm_runner(
-                MODEL_NAME,
-                task="generate",
-                max_model_len=MAX_MODEL_LEN,
-        ) as vllm_model:
-            llm = vllm_model.model
-
-            sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
-            outputs = llm.generate(input, sampling_params=sampling_params)
-
-        assert outputs[0].outputs[0].text == "The capital of France is Paris."
+    input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
+
+    with vllm_runner(
+            MODEL_NAME,
+            task="generate",
+            max_model_len=MAX_MODEL_LEN,
+    ) as vllm_model:
+        llm = vllm_model.model
+
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
+        outputs = llm.generate(input, sampling_params=sampling_params)
+
+    assert outputs[0].outputs[0].text == "The capital of France is Paris."
 
 
 @pytest.mark.asyncio
 async def test_gritlm_api_server_generate():
     input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n"
 
-    # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
-    env_dict = {"VLLM_USE_V1": "0", STR_BACKEND_ENV_VAR: "XFORMERS"}
 
-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as server:
         client_generate = server.get_async_client()
 
         outputs = await client_generate.completions.create(
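
With the XFORMERS pin removed above, these tests now run on whichever attention backend vLLM selects for the platform. For reference, a minimal standalone sketch of the same offline embedding flow outside pytest; the prompt template follows the GritLM model card, and the model length and sample document are illustrative values rather than constants from the test file:

# Sketch only: GritLM embeddings without forcing VLLM_ATTENTION_BACKEND.
from vllm import LLM

def gritlm_instruction(instruction: str) -> str:
    # GritLM's embedding prompt template (documents use an empty instruction).
    return ("<|user|>\n" + instruction + "\n<|embed|>\n"
            if instruction else "<|embed|>\n")

llm = LLM(model="parasail-ai/GritLM-7B-vllm", task="embed", max_model_len=4096)

documents = ["Paris is the capital of France."]
# llm.encode() is the pooling entry point wrapped by the test helpers;
# it returns one pooled embedding per prompt.
outputs = llm.encode([gritlm_instruction("") + doc for doc in documents])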

Changed file 2 of 4: GritLM model implementation

@@ -1,22 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 from array import array
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 import torch.nn as nn
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-from vllm.attention.backends.xformers import XFormersImpl
 from vllm.config import ModelConfig, VllmConfig
-from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.pooler import PoolerHead
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
-from vllm.sequence import (IntermediateTensors, PoolerOutput,
-                           PoolingSequenceGroupOutput)
+from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import SupportsV0Only
@@ -204,39 +200,21 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
         prefix: str = "",
         **kwargs,
     ) -> None:
+        # Use full attention for pooling
+        if vllm_config.model_config.runner_type == "pooling":
+            hf_config = vllm_config.model_config.hf_config
+            hf_config.is_causal = False
+            vllm_config.cache_config.sliding_window = None
+            for attr in ("sliding_window", "interleaved_sliding_window"):
+                if hasattr(hf_config, attr):
+                    delattr(hf_config, attr)
+
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-        self.runner_type = vllm_config.model_config.runner_type
-
         self._pooler = GritLMPooler(vllm_config.model_config)
 
-        for layer in self.model.layers:
-            if self.runner_type == "pooling" and hasattr(layer, "self_attn"):
-                assert isinstance(layer.self_attn.attn.impl, XFormersImpl), (
-                    "GritLM embedding is only supported by XFormers backend, "
-                    "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS")
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        # Change attention to non-causal for pooling tasks.
-        if self.runner_type == "pooling":
-            attn_metadata = get_forward_context().attn_metadata
-            assert attn_metadata.prefill_metadata.attn_bias is None
-            attn_metadata.prefill_metadata.attn_bias = [
-                BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens)
-            ]
-
-        return super().forward(
-            input_ids=input_ids,
-            positions=positions,
-            **kwargs,
-        )
-
     def pooler(
         self,
         hidden_states: torch.Tensor,
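
The change above replaces the old runtime patching (asserting on XFormersImpl and injecting an xFormers BlockDiagonalMask into the attention metadata on every forward pass) with a one-time config rewrite before the Llama backbone is built, so any attention backend can serve the bidirectional pooling path. A sketch of that constructor-time flip as a standalone helper, mirroring the diff; vllm_config is assumed to expose the same model_config and cache_config fields:

def make_pooling_backbone_bidirectional(vllm_config) -> None:
    # Only the pooling runner needs non-causal attention; the generate
    # runner keeps the usual causal decoder behaviour.
    if vllm_config.model_config.runner_type != "pooling":
        return
    hf_config = vllm_config.model_config.hf_config
    # Picked up by LlamaDecoderLayer, which maps is_causal=False to
    # AttentionType.ENCODER_ONLY (see the llama.py hunks below).
    hf_config.is_causal = False
    # Pooling needs full attention, so drop any sliding-window settings
    # before the KV cache and attention layers are configured.
    vllm_config.cache_config.sliding_window = None
    for attr in ("sliding_window", "interleaved_sliding_window"):
        if hasattr(hf_config, attr):
            delattr(hf_config, attr)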

Changed file 3 of 4: Llama model implementation

@@ -28,7 +28,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -96,19 +96,22 @@ class LlamaMLP(nn.Module):
 class LlamaAttention(nn.Module):
 
-    def __init__(self,
-                 config: LlamaConfig,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None,
-                 max_position_embeddings: int = 8192,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 bias: bool = False,
-                 bias_o_proj: bool = False,
-                 cache_config: Optional[CacheConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
@@ -194,6 +197,7 @@ class LlamaAttention(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             per_layer_sliding_window=sliding_window,
+            attn_type=attn_type,
             prefix=f"{prefix}.attn",
         )
 
@@ -238,6 +242,15 @@ class LlamaDecoderLayer(nn.Module):
         if hasattr(config, 'qkv_bias'):
             attention_bias = config.qkv_bias
 
+        # By default, Llama uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. parasail-ai/GritLM-7B-vllm)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -252,6 +265,7 @@
             bias_o_proj=bias_o_proj,
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
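
Because LlamaDecoderLayer now reads is_causal from the HF config (see the comment added in the hunk above), other Llama-family embedding checkpoints could opt into bidirectional attention without a dedicated model class. A hedged sketch using vLLM's hf_overrides engine argument; the model name below is a placeholder rather than a real checkpoint, and GritLM itself already applies this override internally:

from vllm import LLM

# Sketch: override the HF config so the decoder layers build their
# attention as ENCODER_ONLY instead of DECODER.
llm = LLM(
    model="your-org/llama-embedding-model",  # hypothetical checkpoint
    task="embed",
    hf_overrides={"is_causal": False},
)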

Changed file 4 of 4: Qwen2 model implementation

@@ -100,19 +100,19 @@ class Qwen2MLP(nn.Module):
 class Qwen2Attention(nn.Module):
 
     def __init__(
-            self,
-            hidden_size: int,
-            num_heads: int,
-            num_kv_heads: int,
-            max_position: int = 4096 * 32,
-            rope_theta: float = 10000,
-            cache_config: Optional[CacheConfig] = None,
-            quant_config: Optional[QuantizationConfig] = None,
-            rope_scaling: Optional[Tuple] = None,
-            prefix: str = "",
-            attn_type: str = AttentionType.DECODER,
-            dual_chunk_attention_config: Optional[dict[str,
-                                                       Any]] = None) -> None:
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[Tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()