| """
 | |
| This example shows how to use Ray Data for running offline batch inference
 | |
| distributively on a multi-nodes cluster.
 | |
| 
 | |
| Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
 | |
| """
 | |
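
# NOTE (setup sketch, not part of the original example): this script expects an
# existing multi-node Ray cluster. One common way to bring one up is to run
# `ray start --head` on the head node and `ray start --address=<head-node-ip>:6379`
# on each worker node, then run this script on the head node (or submit it with
# `ray job submit`); Ray should pick up the running cluster automatically.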
from typing import Any, Dict, List

import numpy as np
import ray
from packaging.version import Version
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM, SamplingParams

assert Version(ray.__version__) >= Version(
    "2.22.0"), "Ray version must be at least 2.22.0"

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Set tensor parallelism per instance.
tensor_parallel_size = 1

# Set number of instances. Each instance will use tensor_parallel_size GPUs.
num_instances = 1


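# NOTE: Because map_batches() below receives a callable class rather than a
# plain function, Ray Data runs it in long-lived actor processes: the class is
# instantiated once per worker, so the model weights are loaded once per actor
# and reused for every batch that actor processes.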
# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the prompt,
        # generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt: List[str] = []
        generated_text: List[str] = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
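# Other input formats can be read the same way (a sketch; the bucket paths are
# placeholders, and the loaded column must still be named "text", otherwise the
# batch["text"] lookup in LLMPredictor has to be adjusted):
# ds = ray.data.read_json("s3://<your-input-bucket>/prompts.jsonl")
# ds = ray.data.read_parquet("s3://<your-input-bucket>/prompts/")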


# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    # One bundle per tensor parallel worker
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))
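
# For example, with tensor_parallel_size=2 the function above requests two
# {"GPU": 1, "CPU": 1} bundles, and the STRICT_PACK strategy makes Ray place
# both bundles on the same node, so the tensor-parallel workers of one vLLM
# instance are co-located.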


resources_kwarg: Dict[str, Any] = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)
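
# NOTE: Ray Data builds this pipeline lazily, so the map_batches() call above
# does not start inference by itself; execution is triggered when the dataset
# is consumed below (e.g. by take() or write_parquet()).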

# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# one should write the full result out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
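#
# For quick local testing (a sketch, not part of the original example; the path
# below is a placeholder and must be visible to every worker node, e.g. on a
# shared filesystem), the same call can target a filesystem path:
# ds.write_parquet("/mnt/shared/<your-output-dir>")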