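"""Gradio chat frontend for a vLLM OpenAI-compatible API server.

Before launching this UI, start a vLLM server that exposes the
OpenAI-compatible API, for example (the model name is illustrative):

    vllm serve meta-llama/Llama-2-7b-chat-hf

The script below connects to that server and streams chat completions
into a Gradio interface.
"""
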
import argparse

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
                    type=str,
                    default='http://localhost:8000/v1',
                    help='Model URL')
parser.add_argument('-m',
                    '--model',
                    type=str,
                    required=True,
                    help='Model name for the chatbot')
parser.add_argument('--temp',
                    type=float,
                    default=0.8,
                    help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
                    type=str,
                    default='',
                    help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
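
# Optional sanity check (a minimal sketch, not part of the original example):
# listing the served models verifies that the server at --model-url is
# reachable. Uncomment to fail fast before launching the UI:
# print("Models served:", [m.id for m in client.models.list().data])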


def predict(message, history):
    # Convert chat history to OpenAI format
    history_openai_format = [{
        "role": "system",
        "content": "You are a helpful AI assistant."
    }]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({
            "role": "assistant",
            "content": assistant
        })
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            'repetition_penalty': 1,
            'stop_token_ids': [
                int(token_id.strip())
                for token_id in args.stop_token_ids.split(',')
                if token_id.strip()
            ] if args.stop_token_ids else []
        })

    # Read and return generated text from response stream
    partial_message = ""
    for chunk in stream:
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message


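# Because `predict` is a generator, gr.ChatInterface renders the reply as a
# stream: each yielded string replaces the partially shown bot message.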
# Create and launch a chat interface with Gradio
gr.ChatInterface(predict).queue().launch(server_name=args.host,
                                         server_port=args.port,
                                         share=True)
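
# Example invocation (a sketch; the script filename is an assumption and stop
# token IDs depend on the model's tokenizer):
#   python gradio_openai_chatbot_webserver.py \
#       -m meta-llama/Llama-2-7b-chat-hf --stop-token-ids 2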