[Frontend] Add chunked processing to handle long inputs in embedding models (#22280)

Signed-off-by: x22x22 <wadeking@qq.com>
Signed-off-by: Kdump <rootshellexp@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Maximilien de Bayser <maxdebayser@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Kdump
Date: 2025-08-13 19:14:24 +08:00
Committed by: GitHub
Parent: 0b1bdac6af
Commit: 653124bd46
6 changed files with 1603 additions and 3 deletions


@ -0,0 +1,186 @@
# Long Text Embedding with Chunked Processing
This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length.
## 🚀 Quick Start
### Start the Server
Use the provided script to start a vLLM server with chunked processing enabled:
```bash
# Basic usage (supports very long texts up to ~3M tokens)
./service.sh
# Custom configuration with different models
MODEL_NAME="jinaai/jina-embeddings-v3" \
MAX_EMBED_LEN=1048576 \
./service.sh
# For extremely long documents
MODEL_NAME="intfloat/multilingual-e5-large" \
MAX_EMBED_LEN=3072000 \
./service.sh
```
### Test Long Text Embedding
Run the comprehensive test client:
```bash
python client.py
```
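If you just want to send a single request rather than run the full client, a minimal call against the endpoint started by `service.sh` looks roughly like this (a sketch assuming that script's defaults: port `31090`, served model name `multilingual-e5-large`, and API key `your-api-key`):

```python
from openai import OpenAI

# Defaults taken from service.sh; adjust if you changed PORT, MODEL_CODE, or API_KEY.
client = OpenAI(api_key="your-api-key", base_url="http://localhost:31090/v1")

# Any input longer than the model's native context is chunked transparently.
long_text = "vLLM chunked processing handles very long inputs. " * 5000

response = client.embeddings.create(
    model="multilingual-e5-large",
    input=long_text,
    encoding_format="float",
)
print(f"Embedding dimension: {len(response.data[0].embedding)}")
```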
## 📁 Files
| File | Description |
|------|-------------|
| `service.sh` | Server startup script with chunked processing enabled |
| `client.py` | Comprehensive test client for long text embedding |
## ⚙️ Configuration
### Server Configuration
The key parameters for chunked processing are in the `--override-pooler-config`:
```json
{
"pooling_type": "auto",
"normalize": true,
"enable_chunked_processing": true,
"max_embed_len": 3072000
}
```
!!! note
`pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length.
#### Chunked Processing Behavior
Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length:
| Component | Behavior | Description |
|-----------|----------|-------------|
| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy |
| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts |
| **Performance** | Optimal | All chunks processed for complete semantic coverage |
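Conceptually, the cross-chunk step is nothing more than a token-count-weighted mean of the per-chunk vectors. A minimal sketch of that aggregation (illustrative only, not the server's actual implementation):

```python
import torch

def aggregate_chunks(
    chunk_embeddings: list[torch.Tensor],  # one pooled vector per chunk
    chunk_token_counts: list[int],         # number of tokens in each chunk
) -> torch.Tensor:
    """Cross-chunk MEAN aggregation: weight each chunk by its token count."""
    weighted_sum = torch.zeros_like(chunk_embeddings[0], dtype=torch.float32)
    total_weight = 0
    for embedding, weight in zip(chunk_embeddings, chunk_token_counts):
        weighted_sum += embedding.to(torch.float32) * weight
        total_weight += weight
    return weighted_sum / total_weight
```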
### Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) |
| `PORT` | `31090` | Server port |
| `GPU_COUNT` | `1` | Number of GPUs to use |
| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) |
| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (only affects within-chunk pooling, not cross-chunk aggregation) |
| `API_KEY` | `your-api-key` | API key for authentication |
## 🔧 How It Works
1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without setting environment variables
2. **Smart Chunking**: Text is split based on `max_position_embeddings` to preserve semantic integrity (see the sketch after this list)
3. **Unified Processing**: All chunks are processed separately through the model using its configured pooling strategy
4. **MEAN Aggregation**: When the input exceeds the model's native length, results are combined using token-count-based weighted averaging across all chunks
5. **Consistent Output**: Final embeddings keep the same dimensionality as standard processing
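A rough sketch of the splitting step from point 2 (the server uses its own utilities; chunk boundaries here are purely positional, and the aggregation step is the weighted mean shown earlier):

```python
def chunk_token_ids(token_ids: list[int], max_chunk_size: int) -> list[list[int]]:
    """Split a long token sequence into consecutive chunks of at most max_chunk_size tokens."""
    return [
        token_ids[i:i + max_chunk_size]
        for i in range(0, len(token_ids), max_chunk_size)
    ]

# Example: a 150,000-token input with max_position_embeddings=4096
# yields 37 chunks, matching the server log shown under "Debug Information".
assert len(chunk_token_ids(list(range(150_000)), 4096)) == 37
```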
### Input Length Handling
- **Within `max_embed_len`**: Input is accepted and processed (up to 3M+ tokens)
- **Exceeds `max_position_embeddings`**: Chunked processing is triggered automatically
- **Exceeds `max_embed_len`**: Input is rejected with a clear error message (see the sketch after this list)
- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN`
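The decision logic behind these tiers can be summarized in a few lines (a simplified sketch; the real validation lives in the serving layer and produces more detailed error messages):

```python
def classify_embedding_input(token_num: int,
                             max_embed_len: int,
                             max_position_embeddings: int) -> str:
    """Illustrative three-tier check for an incoming embedding request."""
    if token_num > max_embed_len:
        return "reject: exceeds max_embed_len"
    if token_num > max_position_embeddings:
        return "accept: chunked processing"
    return "accept: single-pass processing"

# With max_embed_len=3072000 and a 512-token native context (e.g. e5-small):
print(classify_embedding_input(400, 3_072_000, 512))        # single-pass
print(classify_embedding_input(150_000, 3_072_000, 512))    # chunked processing
print(classify_embedding_input(5_000_000, 3_072_000, 512))  # rejected
```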
### Extreme Long Text Support
With `MAX_EMBED_LEN=3072000`, you can process:
- **Academic papers**: Full research papers with references
- **Legal documents**: Complete contracts and legal texts
- **Books**: Entire chapters or small books
- **Code repositories**: Large codebases and documentation
## 📊 Performance Characteristics
### Chunked Processing Performance
| Aspect | Behavior | Performance |
|--------|----------|-------------|
| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length |
| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead |
| **Memory Usage** | Proportional to number of chunks | Moderate, scalable |
| **Semantic Quality** | Complete text coverage | Optimal for long documents |
## 🧪 Test Cases
The test client demonstrates:
- **Short text**: Normal processing (baseline)
- **Medium text**: Single-chunk processing
- **Long text**: Multi-chunk processing with aggregation
- **Very long text**: Many-chunk processing
- **Extreme long text**: Document-level processing (100K+ tokens)
- **Batch processing**: Mixed-length inputs in one request
- **Consistency**: Reproducible results across runs
## 🐛 Troubleshooting
### Common Issues
1. **Chunked processing not enabled**:
```log
ValueError: This model's maximum position embeddings length is 4096 tokens...
```
**Solution**: Ensure `enable_chunked_processing: true` in pooler config
2. **Input exceeds max_embed_len**:
```log
ValueError: This model's maximum embedding input length is 3072000 tokens...
```
**Solution**: Increase `max_embed_len` in pooler config or reduce input length
3. **Memory errors**:
```log
RuntimeError: CUDA out of memory
```
**Solution**: Reduce chunk size by adjusting model's `max_position_embeddings` or use fewer GPUs
4. **Slow processing**:
**Expected**: Long text takes more time due to multiple inference calls
### Debug Information
Server logs show chunked processing activity:
```log
INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing
INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096)
```
## 🤝 Contributing
To extend chunked processing support to other embedding models:
1. Check model compatibility with the pooling architecture
2. Test with various text lengths
3. Validate embedding quality compared to single-chunk processing
4. Submit PR with test cases and documentation updates
## 🆕 Enhanced Features
### max_embed_len Parameter
The new `max_embed_len` parameter provides:
- **Simplified Configuration**: No need for `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable
- **Flexible Input Validation**: Accept inputs longer than `max_model_len` up to `max_embed_len`
- **Extreme Length Support**: Process documents with millions of tokens
- **Clear Error Messages**: Better feedback when inputs exceed limits
- **Backward Compatibility**: Existing configurations continue to work


@ -0,0 +1,366 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example script demonstrating long text embedding with chunked processing in vLLM.
This example shows how to use vLLM's chunked processing feature to handle text
inputs that exceed the model's maximum token length. The feature automatically
splits long text into chunks and handles different pooling types optimally.
Prerequisites:
1. Start vLLM server with chunked processing enabled:
# MEAN pooling (processes all chunks, recommended for complete coverage)
vllm serve intfloat/multilingual-e5-large \
--override-pooler-config \
'{"pooling_type": "MEAN", "normalize": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 3072000}' \
--served-model-name multilingual-e5-large \
--trust-remote-code \
--port 31090 \
--api-key your-api-key
# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
vllm serve BAAI/bge-large-en-v1.5 \
--override-pooler-config \
'{"pooling_type": "CLS", "normalize": true, ' \
'"enable_chunked_processing": true, "max_embed_len": 1048576}' \
--served-model-name bge-large-en-v1.5 \
--trust-remote-code \
--port 31090 \
--api-key your-api-key
2. Install required dependencies:
pip install openai numpy
"""
import time
import numpy as np
from openai import OpenAI
# Configuration
API_KEY = "your-api-key" # Replace with your actual API key
BASE_URL = "http://localhost:31090/v1"
MODEL_NAME = "multilingual-e5-large"
def generate_long_text(base_text: str, repeat_count: int) -> str:
"""Generate long text by repeating base text."""
return base_text * repeat_count
def test_embedding_with_different_lengths():
"""Test embedding generation with different text lengths."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
# Test cases with different text lengths
test_cases = [
{
"name": "Short Text",
"text": "Hello, this is a short text for embedding.",
"expected_chunks": 1,
},
{
"name": "Medium Text",
"text": generate_long_text(
"This is a medium-length text that should fit within the "
"model's context window. " * 20,
2,
),
"expected_chunks": 1,
},
{
"name": "Long Text (2 chunks)",
"text": generate_long_text(
"This is a very long text that will exceed the model's "
"maximum context length and trigger chunked processing. " * 50,
5,
),
"expected_chunks": 2,
},
{
"name": "Very Long Text (3+ chunks)",
"text": generate_long_text(
"This text is extremely long and will definitely "
"require multiple chunks for processing. " * 100,
10,
),
"expected_chunks": 3,
},
]
print("🧪 Testing vLLM Long Text Embedding with Chunked Processing")
print("=" * 70)
for i, test_case in enumerate(test_cases, 1):
print(f"\n📝 Test {i}: {test_case['name']}")
print(f"Text length: {len(test_case['text'])} characters")
try:
start_time = time.time()
response = client.embeddings.create(
input=test_case["text"], model=MODEL_NAME, encoding_format="float"
)
end_time = time.time()
processing_time = end_time - start_time
# Extract embedding data
embedding = response.data[0].embedding
embedding_dim = len(embedding)
print("✅ Success!")
print(f" - Embedding dimension: {embedding_dim}")
print(f" - Processing time: {processing_time:.2f}s")
print(f" - Expected chunks: ~{test_case['expected_chunks']}")
print(f" - First 5 values: {embedding[:5]}")
except Exception as e:
print(f"❌ Failed: {str(e)}")
def test_batch_embedding():
"""Test batch embedding with mixed-length inputs."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
print("\n🔄 Testing Batch Embedding with Mixed Lengths")
print("=" * 50)
# Mix of short and long texts
batch_inputs = [
"Short text 1",
generate_long_text("Medium length text that fits in one chunk. " * 20, 1),
"Another short text",
generate_long_text("Long text requiring chunked processing. " * 100, 5),
]
try:
start_time = time.time()
response = client.embeddings.create(
input=batch_inputs, model=MODEL_NAME, encoding_format="float"
)
end_time = time.time()
processing_time = end_time - start_time
print("✅ Batch processing successful!")
print(f" - Number of inputs: {len(batch_inputs)}")
print(f" - Number of embeddings: {len(response.data)}")
print(f" - Total processing time: {processing_time:.2f}s")
print(
f" - Average time per input: {processing_time / len(batch_inputs):.2f}s"
)
for i, data in enumerate(response.data):
input_length = len(batch_inputs[i])
embedding_dim = len(data.embedding)
print(
f" - Input {i + 1}: {input_length} chars → {embedding_dim}D embedding"
)
except Exception as e:
print(f"❌ Batch processing failed: {str(e)}")
def test_multiple_long_texts_batch():
"""Test batch processing with multiple long texts to verify chunk ID uniqueness."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
print("\n🔧 Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)")
print("=" * 70)
# Create multiple distinct long texts that will all require chunking
# Note: All pooling types now use MEAN aggregation across chunks:
# - Native pooling (MEAN/CLS/LAST) is used within each chunk
# - MEAN aggregation combines results across all chunks
# - Full semantic coverage for all pooling types
long_texts = [
generate_long_text(
"First long document about artificial intelligence and machine learning. "
* 80,
6,
),
generate_long_text(
"Second long document about natural language processing and transformers. "
* 80,
6,
),
generate_long_text(
"Third long document about computer vision and neural networks. " * 80, 6
),
]
# Add some short texts to mix things up
batch_inputs = [
"Short text before long texts",
long_texts[0],
"Short text between long texts",
long_texts[1],
long_texts[2],
"Short text after long texts",
]
print("📊 Batch composition:")
for i, text in enumerate(batch_inputs):
length = len(text)
text_type = "Long (will be chunked)" if length > 5000 else "Short"
print(f" - Input {i + 1}: {length} chars ({text_type})")
try:
start_time = time.time()
response = client.embeddings.create(
input=batch_inputs, model=MODEL_NAME, encoding_format="float"
)
end_time = time.time()
processing_time = end_time - start_time
print("\n✅ Multiple long texts batch processing successful!")
print(f" - Number of inputs: {len(batch_inputs)}")
print(f" - Number of embeddings returned: {len(response.data)}")
print(f" - Total processing time: {processing_time:.2f}s")
# Verify each embedding is different (no incorrect aggregation)
embeddings = [data.embedding for data in response.data]
if len(embeddings) >= 3:
import numpy as np
# Compare embeddings of the long texts (indices 1, 3, 4)
long_embeddings = [
np.array(embeddings[1]), # First long text
np.array(embeddings[3]), # Second long text
np.array(embeddings[4]), # Third long text
]
print("\n🔍 Verifying embedding uniqueness:")
for i in range(len(long_embeddings)):
for j in range(i + 1, len(long_embeddings)):
cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / (
np.linalg.norm(long_embeddings[i])
* np.linalg.norm(long_embeddings[j])
)
print(
f" - Similarity between long text {i + 1} and {j + 1}: "
f"{cosine_sim:.4f}"
)
if (
cosine_sim < 0.9
): # Different content should have lower similarity
print(" ✅ Good: Embeddings are appropriately different")
else:
print(
" ⚠️ High similarity - may indicate chunk "
"aggregation issue"
)
print("\n📋 Per-input results:")
for i, data in enumerate(response.data):
input_length = len(batch_inputs[i])
embedding_dim = len(data.embedding)
embedding_norm = np.linalg.norm(data.embedding)
print(
f" - Input {i + 1}: {input_length} chars → {embedding_dim}D "
f"embedding (norm: {embedding_norm:.4f})"
)
print(
"\n✅ This test verifies the fix for chunk ID collisions in "
"batch processing"
)
print(" - Before fix: Multiple long texts would have conflicting chunk IDs")
print(" - After fix: Each prompt's chunks have unique IDs with prompt index")
except Exception as e:
print(f"❌ Multiple long texts batch test failed: {str(e)}")
print(" This might indicate the chunk ID collision bug is present!")
def test_embedding_consistency():
"""Test that chunked processing produces consistent results."""
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
print("\n🔍 Testing Embedding Consistency")
print("=" * 40)
# Use the same long text multiple times
long_text = generate_long_text(
"Consistency test text for chunked processing validation. " * 50, 3
)
embeddings = []
try:
for i in range(3):
response = client.embeddings.create(
input=long_text, model=MODEL_NAME, encoding_format="float"
)
embeddings.append(response.data[0].embedding)
print(f" - Generated embedding {i + 1}")
# Check consistency (embeddings should be identical)
if len(embeddings) >= 2:
# Calculate similarity between first two embeddings
emb1 = np.array(embeddings[0])
emb2 = np.array(embeddings[1])
# Cosine similarity
cosine_sim = np.dot(emb1, emb2) / (
np.linalg.norm(emb1) * np.linalg.norm(emb2)
)
print("✅ Consistency test completed!")
print(f" - Cosine similarity between runs: {cosine_sim:.6f}")
print(" - Expected: ~1.0 (identical embeddings)")
if cosine_sim > 0.999:
print(" - ✅ High consistency achieved!")
else:
print(" - ⚠️ Consistency may vary due to numerical precision")
except Exception as e:
print(f"❌ Consistency test failed: {str(e)}")
def main():
"""Main function to run all tests."""
print("🚀 vLLM Long Text Embedding Client")
print(f"📡 Connecting to: {BASE_URL}")
print(f"🤖 Model: {MODEL_NAME}")
masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****"
print(f"🔑 API Key: {masked_key}")
# Run all test cases
test_embedding_with_different_lengths()
test_batch_embedding()
test_multiple_long_texts_batch()
test_embedding_consistency()
print("\n" + "=" * 70)
print("🎉 All tests completed!")
print("\n💡 Key Features Demonstrated:")
print(" - ✅ Automatic chunked processing for long text")
print(" - ✅ Seamless handling of mixed-length batches")
print(" - ✅ Multiple long texts in single batch (chunk ID fix)")
print(" - ✅ Unified chunked processing:")
print(" • Native pooling used within each chunk")
print(" • MEAN aggregation across all chunks")
print(" • Complete semantic coverage for all pooling types")
print(" - ✅ Consistent embedding generation")
print(" - ✅ Backward compatibility with short text")
print("\n📚 For more information, see:")
print(
" - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html"
)
print(" - Chunked Processing Guide: openai_embedding_long_text.md")
if __name__ == "__main__":
main()


@ -0,0 +1,137 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vLLM Embedding Server with Enhanced Chunked Processing
# This script starts a vLLM server with chunked processing enabled for long text embedding.
# Now supports proper pooling type validation and model-specific configurations.
set -euo pipefail
# Configuration
MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"}
MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"}
PORT=${PORT:-31090}
GPU_COUNT=${GPU_COUNT:-1}
MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000}
API_KEY=${API_KEY:-"your-api-key"}
# Enhanced pooling configuration with model-specific defaults
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
# Optionally pin specific GPUs, e.g.:
# export CUDA_VISIBLE_DEVICES=2,3,4,5
# export VLLM_ATTENTION_BACKEND=XFORMERS
echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="
# Environment variables for optimization
export VLLM_WORKER_MULTIPROC_METHOD=spawn
# Function to determine optimal pooling type for known models
get_optimal_pooling_type() {
local model="$1"
case "$model" in
*"e5-"* | *"multilingual-e5"*)
echo "MEAN" # E5 series native pooling
;;
*"bge-"*)
echo "CLS" # BGE series native pooling
;;
*"gte-"*)
echo "LAST" # GTE series native pooling
;;
*"sentence-t5"* | *"st5"*)
echo "MEAN" # Sentence-T5 native pooling
;;
*"jina-embeddings"*)
echo "MEAN" # Jina embeddings native pooling
;;
*"Qwen"*"Embedding"*)
echo "LAST" # Qwen embeddings native pooling
;;
*)
echo "MEAN" # Default native pooling for unknown models
;;
esac
}
# Auto-detect pooling type if not explicitly set
if [ "$POOLING_TYPE" = "auto" ]; then
POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME")
echo "🔍 Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME"
fi
# Display configuration
echo "📋 Configuration:"
echo " - Model: $MODEL_NAME"
echo " - Port: $PORT"
echo " - GPU Count: $GPU_COUNT"
echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}"
echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens"
echo " - Native Pooling Type: $POOLING_TYPE + Normalization"
echo " - Cross-chunk Aggregation: MEAN (automatic)"
echo ""
# Validate GPU availability
if command -v nvidia-smi &> /dev/null; then
gpu_count=$(nvidia-smi --list-gpus | wc -l)
echo "🖥️ Available GPUs: $gpu_count"
if [ "$GPU_COUNT" -gt "$gpu_count" ]; then
echo "⚠️ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available"
echo " Adjusting to use $gpu_count GPUs"
GPU_COUNT=$gpu_count
fi
else
echo "⚠️ Warning: nvidia-smi not found. GPU detection skipped."
fi
# Chunked processing uses unified MEAN aggregation
echo " Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks"
echo " - All chunks processed for complete semantic coverage"
echo " - Weighted averaging based on chunk token counts"
echo ""
echo "🔧 Starting server with enhanced chunked processing configuration..."
# Build pooler config JSON
POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}"
# Start vLLM server with enhanced chunked processing
vllm serve "$MODEL_NAME" \
--tensor-parallel-size "$GPU_COUNT" \
--enforce-eager \
--override-pooler-config "$POOLER_CONFIG" \
--served-model-name ${MODEL_CODE} \
--api-key "$API_KEY" \
--trust-remote-code \
--port "$PORT" \
--host 0.0.0.0
echo ""
echo "✅ vLLM Embedding Server started successfully!"
echo ""
echo "📡 Server Information:"
echo " - Base URL: http://localhost:$PORT"
echo " - Model Code: ${MODEL_CODE}"
echo " - API Key: $API_KEY"
echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN"
echo ""
echo "🧪 Test the server with:"
echo " python examples/online_serving/openai_embedding_long_text_client.py"
echo ""
echo "📚 Enhanced features enabled:"
echo " ✅ Intelligent native pooling type detection"
echo " ✅ Unified MEAN aggregation for chunked processing"
echo " ✅ Model-specific native pooling optimization"
echo " ✅ Enhanced max embedding length (${MAX_EMBED_LEN} tokens)"
echo " ✅ Complete semantic coverage for all pooling types"
echo " ✅ OpenAI-compatible API"
echo " ✅ GPU acceleration"
echo ""
echo "🔧 Advanced usage:"
echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection"
echo " - Set MAX_EMBED_LEN to adjust maximum input length"
echo " - All pooling types use MEAN aggregation across chunks"


@ -0,0 +1,441 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test cases for long text embedding with automatic chunking mechanism.
This test suite validates vLLM's automatic chunking functionality for handling
text inputs that exceed the model's maximum token length, specifically targeting
the intfloat/multilingual-e5-small model (max token length: 512).
"""
import random
import openai
import pytest
import pytest_asyncio
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from ...utils import RemoteOpenAIServer
def _generate_random_text(word_count: int) -> str:
"""Generate random text with approximately the specified word count."""
# Common English words with focus on verbs and nouns for realistic text
common_words = [
# Essential articles and pronouns (minimal)
"the",
"and",
"you",
"they",
"this",
"that",
"these",
"those",
# Action verbs
"create",
"build",
"develop",
"design",
"implement",
"execute",
"analyze",
"process",
"generate",
"calculate",
"evaluate",
"optimize",
"transform",
"integrate",
"configure",
"deploy",
"monitor",
"manage",
"discover",
"explore",
"investigate",
"research",
"study",
"examine",
"improve",
"enhance",
"upgrade",
"modify",
"update",
"maintain",
"solve",
"resolve",
"handle",
"address",
"tackle",
"overcome",
"communicate",
"collaborate",
"coordinate",
"organize",
"plan",
"achieve",
"accomplish",
"complete",
"finish",
"deliver",
"provide",
# Technology and science nouns
"system",
"application",
"software",
"hardware",
"network",
"database",
"algorithm",
"model",
"framework",
"platform",
"interface",
"protocol",
"architecture",
"infrastructure",
"component",
"module",
"service",
"technology",
"innovation",
"solution",
"methodology",
"approach",
"artificial",
"intelligence",
"machine",
"learning",
"neural",
"network",
"computer",
"processor",
"memory",
"storage",
"computation",
"data",
"information",
"knowledge",
"insight",
"pattern",
"trend",
"analysis",
"research",
"development",
"engineering",
"science",
"mathematics",
"statistics",
"probability",
"optimization",
"performance",
"efficiency",
# General nouns
"project",
"team",
"organization",
"company",
"business",
"industry",
"market",
"customer",
"user",
"client",
"product",
"feature",
"function",
"requirement",
"specification",
"documentation",
"report",
"result",
"outcome",
"impact",
"benefit",
"advantage",
"challenge",
"problem",
"opportunity",
"strategy",
"goal",
"objective",
"target",
"milestone",
"process",
"procedure",
"workflow",
"pipeline",
"operation",
"task",
"activity",
"event",
"session",
"meeting",
"discussion",
"decision"
]
words = []
for _ in range(word_count):
words.append(random.choice(common_words))
# Add some punctuation for more realistic text
text = " ".join(words)
# Add periods every 10-20 words
words_list = text.split()
result = []
for i, word in enumerate(words_list):
result.append(word)
if ((i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1):
result[-1] += "."
return " ".join(result)
MODEL_NAME = "intfloat/multilingual-e5-small"
DTYPE = "bfloat16"
# Test text: Generate text with approximately 1500 words to exceed 1024 tokens
LONG_TEXT_1500_WORDS = _generate_random_text(1500)
# Test text: Generate text with approximately 2500 words to exceed 2048 tokens
LONG_TEXT_2500_WORDS = _generate_random_text(2500)
@pytest.fixture(scope="module")
def server_with_chunked_processing():
"""Start server with automatic chunking processing enabled."""
args = [
"--runner",
"pooling",
"--dtype",
DTYPE,
"--enforce-eager",
"--max-model-len",
"512", # Set smaller max_model_len to trigger chunking mechanism
'--override-pooler-config',
('{"pooling_type": "MEAN", "normalize": true, '
'"enable_chunked_processing": true, "max_embed_len": 10000}'),
"--gpu-memory-utilization",
"0.8",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_with_chunked_processing(server_with_chunked_processing):
"""Create async client with chunking processing support."""
async with server_with_chunked_processing.get_async_client(
) as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_long_text_embedding_1500_chars(
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
"""Test embedding processing for ~1500 character long text
(~1028 tokens, exceeding 512 token limit)."""
# Verify text length
# Verify text has sufficient word count (approximately 1500 words)
word_count = len(LONG_TEXT_1500_WORDS.split())
assert word_count >= 1400, (
f"Test text word count insufficient: {word_count} words")
# Send embedding request
embedding_response = await client_with_chunked_processing.embeddings.create(
model=model_name,
input=[LONG_TEXT_1500_WORDS],
encoding_format="float",
)
# Verify response structure
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding
) == 384 # multilingual-e5-small embedding dimension
assert embeddings.usage.completion_tokens == 0
# Due to chunked processing, token count should
# reflect actual processed tokens
# With ~1500 words, we expect roughly
# 1024+ tokens (exceeding 512 token limit)
# Should exceed single chunk limit of 512
assert embeddings.usage.prompt_tokens > 800
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
# Verify embedding vector validity
embedding_vector = embeddings.data[0].embedding
assert all(
isinstance(x, float)
for x in embedding_vector), "Embedding vector should contain floats"
assert not all(
x == 0
for x in embedding_vector), "Embedding vector should not be all zeros"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_long_text_embedding_2500_chars(
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
"""Test embedding processing for ~2500 character long text
(~2048 tokens, requiring multiple chunks)."""
# Verify text length
# Verify text has sufficient word count (approximately 2500 words)
word_count = len(LONG_TEXT_2500_WORDS.split())
assert word_count >= 2300, (
f"Test text word count insufficient: {word_count} words")
# Send embedding request
embedding_response = await client_with_chunked_processing.embeddings.create(
model=model_name,
input=[LONG_TEXT_2500_WORDS],
encoding_format="float",
)
# Verify response structure
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding
) == 384 # multilingual-e5-small embedding dimension
assert embeddings.usage.completion_tokens == 0
# Due to chunked processing, token count should
# reflect actual processed tokens
# With ~2500 words, we expect
# roughly 2048+ tokens (requiring multiple chunks)
# Should require multiple chunks for processing
assert embeddings.usage.prompt_tokens > 1500
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
# Verify embedding vector validity
embedding_vector = embeddings.data[0].embedding
assert all(
isinstance(x, float)
for x in embedding_vector), "Embedding vector should contain floats"
assert not all(
x == 0
for x in embedding_vector), "Embedding vector should not be all zeros"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_long_text_embedding(
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
"""Test batch long text embedding processing."""
input_texts = [
LONG_TEXT_1500_WORDS,
LONG_TEXT_2500_WORDS,
"This is a short text test.", # Short text for comparison
]
# Send batch embedding request
embedding_response = await client_with_chunked_processing.embeddings.create(
model=model_name,
input=input_texts,
encoding_format="float",
)
# Verify response structure
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 3 # Three input texts
# Verify each embedding dimension
for i, embedding_data in enumerate(embeddings.data):
assert len(embedding_data.embedding) == 384
assert embedding_data.index == i
# Verify embedding vector validity
embedding_vector = embedding_data.embedding
assert all(isinstance(x, float) for x in embedding_vector)
assert not all(x == 0 for x in embedding_vector)
# Verify token usage
assert embeddings.usage.completion_tokens == 0
# Total token count should be very substantial
assert embeddings.usage.prompt_tokens > 1000
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chunked_vs_normal_consistency(
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
"""Test consistency between chunked and
normal processing (using short text)."""
# Use a short text within the 512 token limit
short_text = ("Artificial intelligence technology is changing our world, "
"bringing unprecedented opportunities and challenges.")
# Send embedding request
embedding_response = await client_with_chunked_processing.embeddings.create(
model=model_name,
input=[short_text],
encoding_format="float",
)
# Verify response structure
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 384
assert embeddings.usage.completion_tokens == 0
# Short text should not require chunked processing
assert embeddings.usage.prompt_tokens < 512
assert embeddings.usage.total_tokens == embeddings.usage.prompt_tokens
    # Verify embedding vector validity
embedding_vector = embeddings.data[0].embedding
assert all(isinstance(x, float) for x in embedding_vector)
assert not all(x == 0 for x in embedding_vector)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chunked_processing_response_format(
client_with_chunked_processing: openai.AsyncOpenAI, model_name: str):
"""Test response format and structure during chunked processing."""
# Test with long text to trigger chunking
embedding_response = await client_with_chunked_processing.embeddings.create(
model=model_name,
input=[LONG_TEXT_1500_WORDS],
encoding_format="float",
)
# Verify response structure
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert embeddings.data[0].object == "embedding"
assert embeddings.data[0].index == 0
# Verify embedding vector properties
embedding_vector = embeddings.data[0].embedding
import math
vector_norm = math.sqrt(sum(x * x for x in embedding_vector))
# Check that the vector is normalized
# (default behavior for most embedding models)
assert 0.8 < vector_norm < 1.2, (
f"Vector norm should be reasonable, actual: {vector_norm}")


@ -2598,6 +2598,25 @@ class PoolerConfig:
``math-shepherd-mistral-7b-prm`` model.
"""
enable_chunked_processing: Optional[bool] = None
"""
Whether to enable chunked processing for long inputs that exceed the model's
maximum position embeddings. When enabled, long inputs will be split into
chunks, processed separately, and then aggregated using weighted averaging.
This allows embedding models to handle arbitrarily long text without CUDA
errors. Defaults to False.
"""
max_embed_len: Optional[int] = None
"""
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_embed_len to be accepted for embedding models.
This parameter enables accepting long inputs without requiring
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
max_embed_len, it will be handled according to the original max_model_len
validation logic. Defaults to None (i.e. set to max_model_len).
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,


@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from typing import Final, Literal, Optional, Union, cast
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, Literal, Optional, Union, cast
import numpy as np
import torch
from fastapi import Request
from typing_extensions import assert_never, override
@ -12,19 +14,28 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
OpenAIServing,
ServeContext)
RequestPrompt,
ServeContext,
TextTokensPrompt)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
PoolingOutput, PoolingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.utils import chunk_list
logger = init_logger(__name__)
@ -46,6 +57,17 @@ def _get_embedding(
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pooler_config = self.model_config.pooler_config
# Avoid repeated attribute lookups
self.supports_chunked_processing = bool(
pooler_config and pooler_config.enable_chunked_processing)
self.max_embed_len = (pooler_config.max_embed_len if pooler_config
and pooler_config.max_embed_len else None)
@override
async def _preprocess(
self,
@ -129,6 +151,435 @@ class EmbeddingMixin(OpenAIServing):
usage=usage,
)
def _get_max_position_embeddings(self) -> int:
"""Get the model's effective maximum sequence length for chunking."""
return self.model_config.max_model_len
def _should_use_chunked_processing(self, request) -> bool:
"""Check if chunked processing should be used for this request."""
return isinstance(
request,
(EmbeddingCompletionRequest,
EmbeddingChatRequest)) and self.supports_chunked_processing
async def _process_chunked_request(
self,
ctx: EmbeddingServeContext,
original_prompt: TextTokensPrompt,
pooling_params,
trace_headers,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
token_ids = original_prompt["prompt_token_ids"]
# Split into chunks using max_position_embeddings
max_pos_embeddings = self._get_max_position_embeddings()
# Process all chunks for MEAN aggregation
for chunk_idx, chunk_tokens in enumerate(
chunk_list(token_ids, max_pos_embeddings)):
# Create a request ID for this chunk
chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-"
f"chunk-{chunk_idx}")
# Create engine prompt for this chunk
chunk_engine_prompt = EngineTokensPrompt(
prompt_token_ids=chunk_tokens)
# Create chunk request prompt for logging
chunk_text = ""
chunk_request_prompt = TextTokensPrompt(
prompt=chunk_text, prompt_token_ids=chunk_tokens)
# Log the chunk
self._log_inputs(chunk_request_id,
chunk_request_prompt,
params=pooling_params,
lora_request=ctx.lora_request)
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(original_generator)
return generators
def _validate_input(
self,
request,
input_ids: list[int],
input_text: str,
) -> TextTokensPrompt:
"""Override to support chunked processing for embedding requests."""
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request,
(EmbeddingCompletionRequest, EmbeddingChatRequest)):
# Check if chunked processing is enabled for pooling models
enable_chunked = self._should_use_chunked_processing(request)
# Use max_position_embeddings for chunked processing decisions
max_pos_embeddings = self._get_max_position_embeddings()
# Determine the effective max length for validation
if self.max_embed_len is not None:
# Use max_embed_len for validation instead of max_model_len
length_type = "maximum embedding input length"
max_length_value = self.max_embed_len
else:
# Fall back to max_model_len validation (original behavior)
length_type = "maximum context length"
max_length_value = self.max_model_len
validation_error_msg = (
"This model's {length_type} is {max_length_value} tokens. "
"However, you requested {token_num} tokens in the input for "
"embedding generation. Please reduce the length of the input.")
chunked_processing_error_msg = (
"This model's {length_type} is {max_length_value} tokens. "
"However, you requested {token_num} tokens in the input for "
"embedding generation. Please reduce the length of the input "
"or enable chunked processing.")
# Check if input exceeds max length
if token_num > max_length_value:
raise ValueError(
validation_error_msg.format(
length_type=length_type,
max_length_value=max_length_value,
token_num=token_num))
# Check for chunked processing
# when exceeding max_position_embeddings
if token_num > max_pos_embeddings:
if enable_chunked:
# Allow long inputs when chunked processing is enabled
logger.info(
"Input length %s exceeds max_position_embeddings "
"%s, will use chunked processing", token_num,
max_pos_embeddings)
else:
raise ValueError(
chunked_processing_error_msg.format(
length_type="maximum position embeddings length",
max_length_value=max_pos_embeddings,
token_num=token_num))
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# For other request types, use the parent's implementation
return super()._validate_input(request, input_ids, input_text)
def _is_text_tokens_prompt(self, prompt) -> bool:
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt)
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
request_prompt: RequestPrompt,
pooling_params: PoolingParams,
trace_headers: Optional[Mapping[str, str]],
prompt_index: int,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
"""Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}"
self._log_inputs(request_id_item,
request_prompt,
params=pooling_params,
lora_request=ctx.lora_request)
# Mypy has an existing bug related to inferring the variance
# of TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
@override
async def _prepare_generators(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
# If no chunked processing needed, delegate to parent class
if not use_chunked:
return await super()._prepare_generators(ctx)
# Custom logic for chunked processing
generators: list[AsyncGenerator[Union[RequestOutput,
PoolingRequestOutput],
None]] = []
try:
trace_headers = (None if ctx.raw_request is None else await
self._get_trace_headers(ctx.raw_request.headers))
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
# Verify and set the task for pooling params
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
if ctx.request_prompts is None:
return self.create_error_response(
"Request prompts not available")
max_pos_embeddings = self._get_max_position_embeddings()
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_prompt = ctx.request_prompts[i]
# Check if this specific prompt needs chunked processing
if self._is_text_tokens_prompt(request_prompt):
# Cast to TextTokensPrompt since we've verified
# prompt_token_ids
text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
if (len(text_tokens_prompt["prompt_token_ids"])
> max_pos_embeddings):
# Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request(
ctx, text_tokens_prompt, pooling_params,
trace_headers, i)
generators.extend(chunk_generators)
continue
# Normal processing for short prompts or non-token prompts
# Cast engine_prompt to the expected type for mypy
engine_prompt_typed = cast(
Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
generator = await self._create_single_prompt_generator(
ctx, engine_prompt_typed, request_prompt, pooling_params,
trace_headers, i)
generators.append(generator)
from vllm.utils import merge_async_iterators
ctx.result_generator = merge_async_iterators(*generators)
return None
except Exception as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@override
async def _collect_batch(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Collect and aggregate batch results
with support for chunked processing.
For chunked requests, performs online aggregation to
minimize memory usage.
For regular requests, collects results normally.
"""
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
# Check if we used chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
if not use_chunked:
return await super()._collect_batch(ctx=ctx)
if ctx.request_prompts is None:
return self.create_error_response(
"Request prompts not available")
if ctx.result_generator is None:
return self.create_error_response(
"Result generator not available")
# Online aggregation for chunked requests to
# minimize memory usage
# Track aggregation state for each prompt
prompt_aggregators: dict[int, dict[str, Any]] = {}
short_prompts_results: dict[int, PoolingRequestOutput] = {}
async for result_idx, result in ctx.result_generator:
if "-chunk-" in result.request_id:
# Extract prompt_idx from chunked request_id
parts = result.request_id.split("-")
try:
prompt_idx = int(parts[parts.index("prompt") + 1])
except (ValueError, IndexError):
# Fallback: extract from result_idx if parsing fails
prompt_idx = result_idx
# Initialize aggregator for this prompt if needed
if prompt_idx not in prompt_aggregators:
prompt_aggregators[prompt_idx] = {
'weighted_sum': None,
'total_weight': 0,
'chunk_count': 0,
'request_id': result.request_id.split("-chunk-")[0]
}
aggregator = prompt_aggregators[prompt_idx]
# MEAN pooling with online weighted averaging
# Ensure result is PoolingRequestOutput
# for embedding processing
if not isinstance(result, PoolingRequestOutput):
return self.create_error_response(
f"Expected PoolingRequestOutput for "
f"chunked embedding, got "
f"{type(result).__name__}")
# Handle both PoolingOutput and
# EmbeddingOutput types
if hasattr(result.outputs, 'data'):
# PoolingOutput case
embedding_data = result.outputs.data
elif hasattr(result.outputs, 'embedding'):
# EmbeddingOutput case -
# convert embedding list to tensor
embedding_data = result.outputs.embedding
else:
return self.create_error_response(
f"Unsupported output type: "
f"{type(result.outputs).__name__}")
if not isinstance(embedding_data, torch.Tensor):
embedding_data = torch.tensor(embedding_data,
dtype=torch.float32)
if result.prompt_token_ids is None:
return self.create_error_response(
"prompt_token_ids cannot be None for "
"chunked processing")
weight = len(result.prompt_token_ids)
weighted_embedding = embedding_data.to(
dtype=torch.float32) * weight
if aggregator['weighted_sum'] is None:
# First chunk
aggregator['weighted_sum'] = weighted_embedding
else:
# Accumulate
aggregator['weighted_sum'] += weighted_embedding
aggregator['total_weight'] += weight
aggregator['chunk_count'] += 1
else:
# Non-chunked result - extract prompt_idx from request_id
parts = result.request_id.split("-")
try:
# Last part should be prompt index
prompt_idx = int(parts[-1])
except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result)
# Finalize aggregated results
final_res_batch: list[Union[PoolingRequestOutput,
EmbeddingRequestOutput]] = []
num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts):
if prompt_idx in prompt_aggregators:
# Finalize MEAN aggregation for this chunked prompt
aggregator = prompt_aggregators[prompt_idx]
weighted_sum = aggregator['weighted_sum']
total_weight = aggregator['total_weight']
if (weighted_sum is not None
and isinstance(weighted_sum, torch.Tensor)
and isinstance(total_weight,
(int, float)) and total_weight > 0):
# Compute final mean embedding
final_embedding = weighted_sum / total_weight
# Create a PoolingRequestOutput
# for the aggregated result
pooling_output_data = PoolingOutput(
data=final_embedding)
# Get original prompt token IDs for this prompt
original_prompt = ctx.request_prompts[prompt_idx]
if not self._is_text_tokens_prompt(original_prompt):
return self.create_error_response(
f"Chunked prompt {prompt_idx} is not a "
f"TextTokensPrompt")
original_token_ids = cast(
TextTokensPrompt,
original_prompt)["prompt_token_ids"]
pooling_request_output = PoolingRequestOutput(
request_id=aggregator['request_id'],
prompt_token_ids=original_token_ids,
outputs=pooling_output_data,
finished=True)
final_res_batch.append(pooling_request_output)
else:
return self.create_error_response(
f"Failed to aggregate chunks "
f"for prompt {prompt_idx}")
elif prompt_idx in short_prompts_results:
final_res_batch.append(
cast(PoolingRequestOutput,
short_prompts_results[prompt_idx]))
else:
return self.create_error_response(
f"Result not found for prompt {prompt_idx}")
ctx.final_res_batch = cast(
list[Union[RequestOutput, PoolingRequestOutput]],
final_res_batch)
return None
except Exception as e:
return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"