Mirror of https://github.com/vllm-project/vllm.git
[doc] Add RAG Integration example (#17692)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
@@ -11,6 +11,7 @@ helm
 lws
 modal
 open-webui
+retrieval_augmented_generation
 skypilot
 streamlit
 triton
@@ -0,0 +1,84 @@
(deployment-retrieval-augmented-generation)=

# Retrieval-Augmented Generation

[Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using this information to supplement its pre-existing training data. This allows LLMs to use domain-specific and/or updated information. Use cases include providing chatbots with access to internal company data or generating responses based on authoritative sources.

Here are the integrations:
- vLLM + [langchain](https://github.com/langchain-ai/langchain) + [milvus](https://github.com/milvus-io/milvus)
- vLLM + [llamaindex](https://github.com/run-llama/llama_index) + [milvus](https://github.com/milvus-io/milvus)
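
Both integrations implement the same basic flow: embed the user question, retrieve the most relevant document chunks from a vector store, and pass them to the chat model as context. The following is a minimal, framework-free sketch of that flow against vLLM's OpenAI-compatible endpoints. It is illustrative only: it assumes the `openai` Python package is installed, that the two `vllm serve` commands from the Deploy sections below are already running, and it uses a plain in-memory list in place of Milvus.

```python
# Illustrative sketch only: the integrations below do this with LangChain or
# LlamaIndex and a real vector store (Milvus) instead of a Python list.
import math

from openai import OpenAI

embedding_client = OpenAI(api_key="EMPTY",
                          base_url="http://localhost:8000/v1")
chat_client = OpenAI(api_key="EMPTY", base_url="http://localhost:8001/v1")

# Toy "document collection" standing in for chunks loaded from a web page
chunks = [
    "vLLM can be installed with: pip install vllm",
    "vLLM exposes an OpenAI-compatible API via the vllm serve command.",
]
question = "How to install vLLM?"


def embed(texts):
    # Embed a batch of texts with the vLLM embedding server
    resp = embedding_client.embeddings.create(
        model="ssmits/Qwen2-7B-Instruct-embed-base", input=texts)
    return [item.embedding for item in resp.data]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) *
                  math.sqrt(sum(y * y for y in b)))


# "Retrieve": pick the chunk most similar to the question
chunk_vectors = embed(chunks)
question_vector = embed([question])[0]
best = max(range(len(chunks)),
           key=lambda i: cosine(chunk_vectors[i], question_vector))

# "Generate": answer with the retrieved chunk supplied as context
completion = chat_client.chat.completions.create(
    model="qwen/Qwen1.5-0.5B-Chat",
    messages=[{
        "role": "user",
        "content": f"Context: {chunks[best]}\n\nQuestion: {question}\nAnswer:",
    }],
)
print(completion.choices[0].message.content)
```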

## vLLM + langchain

### Prerequisites

- Set up the vLLM and langchain environment:

```console
pip install -U vllm \
            langchain_milvus langchain_openai \
            langchain_community beautifulsoup4 \
            langchain-text-splitters
```

### Deploy

- Start the vLLM server with a supported embedding model, e.g.:

```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```

- Start the vLLM server with a supported chat completion model, e.g.:

```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```
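
Before running the example, you can optionally confirm that both OpenAI-compatible servers are reachable; each should list the model it serves:

```console
# Sanity check: list the models served on each endpoint
curl http://localhost:8000/v1/models
curl http://localhost:8001/v1/models
```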

- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_langchain.py>

- Run the script:

```console
python retrieval_augmented_generation_with_langchain.py
```
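
The script also accepts command-line options (defined in its argument parser, shown later on this page). For example, to choose the page to index, retrieve more chunks per question, and run an interactive Q&A session:

```console
python retrieval_augmented_generation_with_langchain.py \
    --url https://docs.vllm.ai/en/latest/getting_started/quickstart.html \
    --top-k 5 --interactive
```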

## vLLM + llamaindex

### Prerequisites

- Set up the vLLM and llamaindex environment:

```console
pip install vllm \
            llama-index llama-index-readers-web \
            llama-index-llms-openai-like \
            llama-index-embeddings-openai-like \
            llama-index-vector-stores-milvus
```

### Deploy

- Start the vLLM server with a supported embedding model, e.g.:

```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```

- Start the vLLM server with a supported chat completion model, e.g.:

```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```

- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_llamaindex.py>

- Run the script:

```console
python retrieval_augmented_generation_with_llamaindex.py
```
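
As with the langchain example, this script accepts command-line flags; for example, to run in interactive mode and retrieve more chunks per question:

```console
python retrieval_augmented_generation_with_llamaindex.py -i -k 5
```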
@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
"""
Retrieval Augmented Generation (RAG) Implementation with Langchain
==================================================================

This script demonstrates a RAG implementation using LangChain, Milvus
and vLLM. RAG enhances LLM responses by retrieving relevant context
from a document collection.

Features:
- Web content loading and chunking
- Vector storage with Milvus
- Embedding generation with vLLM
- Question answering with context

Prerequisites:
1. Install dependencies:
    pip install -U vllm \
                langchain_milvus langchain_openai \
                langchain_community beautifulsoup4 \
                langchain-text-splitters

2. Start services:
    # Start embedding service (port 8000)
    vllm serve ssmits/Qwen2-7B-Instruct-embed-base

    # Start chat service (port 8001)
    vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001

Usage:
    python retrieval_augmented_generation_with_langchain.py

Notes:
    - Ensure both vLLM services are running before executing
    - Default ports: 8000 (embedding), 8001 (chat)
    - First run may take time to download models
"""
import argparse
from argparse import Namespace
from typing import Any

from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_milvus import Milvus
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter


def load_and_split_documents(config: dict[str, Any]):
    """
    Load and split documents from web URL
    """
    try:
        loader = WebBaseLoader(web_paths=(config["url"], ))
        docs = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config["chunk_size"],
            chunk_overlap=config["chunk_overlap"],
        )
        return text_splitter.split_documents(docs)
    except Exception as e:
        print(f"Error loading document from {config['url']}: {str(e)}")
        raise


def init_vectorstore(config: dict[str, Any], documents: list[Document]):
    """
    Initialize vector store with documents
    """
    return Milvus.from_documents(
        documents=documents,
        embedding=OpenAIEmbeddings(
            model=config["embedding_model"],
            openai_api_key=config["vllm_api_key"],
            openai_api_base=config["vllm_embedding_endpoint"],
        ),
        connection_args={"uri": config["uri"]},
        drop_old=True,
    )


def init_llm(config: dict[str, Any]):
    """
    Initialize llm
    """
    return ChatOpenAI(
        model=config["chat_model"],
        openai_api_key=config["vllm_api_key"],
        openai_api_base=config["vllm_chat_endpoint"],
    )


def get_qa_prompt():
    """
    Get question answering prompt template
    """
    template = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:
"""
    return PromptTemplate.from_template(template)


def format_docs(docs: list[Document]):
    """
    Format documents for prompt
    """
    return "\n\n".join(doc.page_content for doc in docs)


def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate):
    """
    Set up question answering chain
    """
    # LCEL pipeline: retrieve and format context while passing the question
    # through, then fill the prompt, call the LLM, and parse to a string.
    return ({
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
            | prompt
            | llm
            | StrOutputParser())


def get_parser() -> argparse.ArgumentParser:
    """
    Parse command line arguments
    """
    parser = argparse.ArgumentParser(description='RAG with vLLM and langchain')

    # Add command line arguments
    parser.add_argument('--vllm-api-key',
                        default="EMPTY",
                        help='API key for vLLM compatible services')
    parser.add_argument('--vllm-embedding-endpoint',
                        default="http://localhost:8000/v1",
                        help='Base URL for embedding service')
    parser.add_argument('--vllm-chat-endpoint',
                        default="http://localhost:8001/v1",
                        help='Base URL for chat service')
    parser.add_argument('--uri',
                        default="./milvus.db",
                        help='URI for Milvus database')
    parser.add_argument(
        '--url',
        default=("https://docs.vllm.ai/en/latest/getting_started/"
                 "quickstart.html"),
        help='URL of the document to process')
    parser.add_argument('--embedding-model',
                        default="ssmits/Qwen2-7B-Instruct-embed-base",
                        help='Model name for embeddings')
    parser.add_argument('--chat-model',
                        default="qwen/Qwen1.5-0.5B-Chat",
                        help='Model name for chat')
    parser.add_argument('-i',
                        '--interactive',
                        action='store_true',
                        help='Enable interactive Q&A mode')
    parser.add_argument('-k',
                        '--top-k',
                        type=int,
                        default=3,
                        help='Number of top results to retrieve')
    parser.add_argument('-c',
                        '--chunk-size',
                        type=int,
                        default=1000,
                        help='Chunk size for document splitting')
    parser.add_argument('-o',
                        '--chunk-overlap',
                        type=int,
                        default=200,
                        help='Chunk overlap for document splitting')

    return parser


def init_config(args: Namespace):
    """
    Initialize configuration settings from command line arguments
    """

    return {
        "vllm_api_key": args.vllm_api_key,
        "vllm_embedding_endpoint": args.vllm_embedding_endpoint,
        "vllm_chat_endpoint": args.vllm_chat_endpoint,
        "uri": args.uri,
        "embedding_model": args.embedding_model,
        "chat_model": args.chat_model,
        "url": args.url,
        "chunk_size": args.chunk_size,
        "chunk_overlap": args.chunk_overlap,
        "top_k": args.top_k
    }


def main():
    # Parse command line arguments
    args = get_parser().parse_args()

    # Initialize configuration
    config = init_config(args)

    # Load and split documents
    documents = load_and_split_documents(config)

    # Initialize vector store and retriever
    vectorstore = init_vectorstore(config, documents)
    retriever = vectorstore.as_retriever(search_kwargs={"k": config["top_k"]})

    # Initialize llm and prompt
    llm = init_llm(config)
    prompt = get_qa_prompt()

    # Set up QA chain
    qa_chain = create_qa_chain(retriever, llm, prompt)

    # Interactive mode
    if args.interactive:
        print("\nWelcome to the interactive Q&A system!")
        print("Enter 'q' or 'quit' to exit.")

        while True:
            question = input("\nPlease enter your question: ")
            if question.lower() in ['q', 'quit']:
                print("\nThank you for using the Q&A system. Goodbye!")
                break

            output = qa_chain.invoke(question)
            print(output)
    else:
        # Default single question mode
        question = "How to install vLLM?"
        output = qa_chain.invoke(question)
        print("-" * 50)
        print(output)
        print("-" * 50)


if __name__ == "__main__":
    main()
@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
"""
RAG (Retrieval Augmented Generation) Implementation with LlamaIndex
====================================================================

This script demonstrates a RAG system using:
- LlamaIndex: For document indexing and retrieval
- Milvus: As vector store backend
- vLLM: For embedding and text generation

Features:
1. Document Loading & Processing
2. Embedding & Storage
3. Query Processing

Requirements:
1. Install dependencies:
    pip install llama-index llama-index-readers-web \
                llama-index-llms-openai-like \
                llama-index-embeddings-openai-like \
                llama-index-vector-stores-milvus

2. Start services:
    # Start embedding service (port 8000)
    vllm serve ssmits/Qwen2-7B-Instruct-embed-base

    # Start chat service (port 8001)
    vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001

Usage:
    python retrieval_augmented_generation_with_llamaindex.py

Notes:
    - Ensure both vLLM services are running before executing
    - Default ports: 8000 (embedding), 8001 (chat)
    - First run may take time to download models
"""
import argparse
from argparse import Namespace
from typing import Any

from llama_index.core import Settings, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
from llama_index.llms.openai_like import OpenAILike
from llama_index.readers.web import SimpleWebPageReader
from llama_index.vector_stores.milvus import MilvusVectorStore


def init_config(args: Namespace):
    """Initialize configuration with command line arguments"""
    return {
        "url": args.url,
        "embedding_model": args.embedding_model,
        "chat_model": args.chat_model,
        "vllm_api_key": args.vllm_api_key,
        "embedding_endpoint": args.embedding_endpoint,
        "chat_endpoint": args.chat_endpoint,
        "db_path": args.db_path,
        "chunk_size": args.chunk_size,
        "chunk_overlap": args.chunk_overlap,
        "top_k": args.top_k
    }


def load_documents(url: str) -> list:
    """Load and process web documents"""
    return SimpleWebPageReader(html_to_text=True).load_data([url])


def setup_models(config: dict[str, Any]):
    """Configure embedding and chat models"""
    # Register the models globally so index construction and querying use them
    Settings.embed_model = OpenAILikeEmbedding(
        api_base=config["embedding_endpoint"],
        api_key=config["vllm_api_key"],
        model_name=config["embedding_model"],
    )

    Settings.llm = OpenAILike(
        model=config["chat_model"],
        api_key=config["vllm_api_key"],
        api_base=config["chat_endpoint"],
        context_window=128000,
        is_chat_model=True,
        is_function_calling_model=False,
    )

    Settings.transformations = [
        SentenceSplitter(
            chunk_size=config["chunk_size"],
            chunk_overlap=config["chunk_overlap"],
        )
    ]


def setup_vector_store(db_path: str) -> MilvusVectorStore:
    """Initialize vector store"""
    # Probe the embedding model once to discover the vector dimension,
    # which Milvus needs when creating the collection
    sample_emb = Settings.embed_model.get_text_embedding("test")
    print(f"Embedding dimension: {len(sample_emb)}")
    return MilvusVectorStore(uri=db_path, dim=len(sample_emb), overwrite=True)


def create_index(documents: list, vector_store: MilvusVectorStore):
    """Create document index"""
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
    )


def query_document(index: VectorStoreIndex, question: str, top_k: int):
    """Query document with given question"""
    query_engine = index.as_query_engine(similarity_top_k=top_k)
    return query_engine.query(question)


def get_parser() -> argparse.ArgumentParser:
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='RAG with vLLM and LlamaIndex')

    # Add command line arguments
    parser.add_argument(
        '--url',
        default=("https://docs.vllm.ai/en/latest/getting_started/"
                 "quickstart.html"),
        help='URL of the document to process')
    parser.add_argument('--embedding-model',
                        default="ssmits/Qwen2-7B-Instruct-embed-base",
                        help='Model name for embeddings')
    parser.add_argument('--chat-model',
                        default="qwen/Qwen1.5-0.5B-Chat",
                        help='Model name for chat')
    parser.add_argument('--vllm-api-key',
                        default="EMPTY",
                        help='API key for vLLM compatible services')
    parser.add_argument('--embedding-endpoint',
                        default="http://localhost:8000/v1",
                        help='Base URL for embedding service')
    parser.add_argument('--chat-endpoint',
                        default="http://localhost:8001/v1",
                        help='Base URL for chat service')
    parser.add_argument('--db-path',
                        default="./milvus_demo.db",
                        help='Path to Milvus database')
    parser.add_argument('-i',
                        '--interactive',
                        action='store_true',
                        help='Enable interactive Q&A mode')
    parser.add_argument('-c',
                        '--chunk-size',
                        type=int,
                        default=1000,
                        help='Chunk size for document splitting')
    parser.add_argument('-o',
                        '--chunk-overlap',
                        type=int,
                        default=200,
                        help='Chunk overlap for document splitting')
    parser.add_argument('-k',
                        '--top-k',
                        type=int,
                        default=3,
                        help='Number of top results to retrieve')

    return parser


def main():
    # Parse command line arguments
    args = get_parser().parse_args()

    # Initialize configuration
    config = init_config(args)

    # Load documents
    documents = load_documents(config["url"])

    # Setup models
    setup_models(config)

    # Setup vector store
    vector_store = setup_vector_store(config["db_path"])

    # Create index
    index = create_index(documents, vector_store)

    if args.interactive:
        print("\nEntering interactive mode. Type 'quit' to exit.")
        while True:
            # Get user question
            question = input("\nEnter your question: ")

            # Check for exit command
            if question.lower() in ['quit', 'exit', 'q']:
                print("Exiting interactive mode...")
                break

            # Get and print response
            print("\n" + "-" * 50)
            print("Response:\n")
            response = query_document(index, question, config["top_k"])
            print(response)
            print("-" * 50)
    else:
        # Single query mode
        question = "How to install vLLM?"
        response = query_document(index, question, config["top_k"])
        print("-" * 50)
        print("Response:\n")
        print(response)
        print("-" * 50)


if __name__ == "__main__":
    main()