mirror of https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00

Compare commits: f6e7c200c0 ... vllm-spec- (2 commits: 49a5febd58, 2eef438699)
Changed file: trl/scripts/vllm_serve.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import argparse
+import json
 import logging
 import os
 from collections.abc import Sequence
@@ -67,6 +68,14 @@ logger = logging.getLogger(__name__)
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


+def json_dict_type(arg_string):
+    """Parse a JSON string into a dictionary."""
+    try:
+        return json.loads(arg_string)
+    except json.JSONDecodeError as e:
+        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
+
+
 class WeightSyncWorkerExtension:
     """
     A vLLM worker extension that enables weight synchronization between a client and multiple server workers.
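The helper added above is written so it can act as an argparse type function, rejecting malformed JSON at parse time instead of deep inside the server startup. A minimal, self-contained sketch of that wiring; the flag name and example value are only illustrative, and the script itself may attach the helper differently:

    import argparse
    import json

    def json_dict_type(arg_string):
        """Parse a JSON string into a dictionary."""
        try:
            return json.loads(arg_string)
        except json.JSONDecodeError as e:
            raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")

    parser = argparse.ArgumentParser()
    # Hypothetical wiring: malformed JSON fails fast with a clear argparse error.
    parser.add_argument("--speculative_config", type=json_dict_type, default=None)
    args = parser.parse_args(["--speculative_config", '{"model": "Qwen/Qwen3-0.6B"}'])
    print(args.speculative_config["model"])  # Qwen/Qwen3-0.6B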
@@ -174,10 +183,8 @@ class ScriptArguments:
         dtype (`str`, *optional*, defaults to `"auto"`):
             Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
             based on the model configuration. Find the supported values in the vLLM documentation.
-        max_model_len (`int` or `None`, *optional*, defaults to `None`):
-            If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
-            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
-            context size, which might be much larger than the KV cache, leading to inefficiencies.
+        speculative_config (`str`, *optional*, defaults to `None`):
+            Speculative decoding configuration for vLLM as a JSON string. If set, the speculative model will be used for generation. Example: '{"model": "Qwen/Qwen3-0.6B"}'
         enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
             Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
             this feature.
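The flag value is parsed into a plain dict before it reaches vLLM. A short sketch of that round trip, reusing the draft-model example from the test commands at the end of this diff (the exact set of accepted keys depends on the vLLM version):

    import json

    raw = '{"model": "microsoft/DialoGPT-small", "num_speculative_tokens": 5}'
    speculative_config = json.loads(raw)
    # {'model': 'microsoft/DialoGPT-small', 'num_speculative_tokens': 5}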
@@ -233,6 +240,12 @@ class ScriptArguments:
             "determined based on the model configuration. Find the supported values in the vLLM documentation."
         },
     )
+    speculative_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Speculative decoding configuration for vLLM as a JSON string. If set, the speculative model will be used for generation. Example: '{\"model\": \"Qwen/Qwen3-0.6B\"}'"
+        },
+    )
     max_model_len: Optional[int] = field(
         default=None,
         metadata={
@@ -277,6 +290,14 @@ class ScriptArguments:
         },
     )

+    def __post_init__(self):
+        """Parse JSON strings into appropriate types."""
+        if self.speculative_config is not None:
+            try:
+                self.speculative_config = json.loads(self.speculative_config)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON in speculative_config: {e}")
+

 def llm_worker(
     script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection
@@ -294,13 +315,14 @@ def llm_worker
         gpu_memory_utilization=script_args.gpu_memory_utilization,
         enforce_eager=script_args.enforce_eager,
         dtype=script_args.dtype,
+        speculative_config=script_args.speculative_config,
         # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
         # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
         # This is particularly useful here because we generate completions from the same prompts.
         enable_prefix_caching=script_args.enable_prefix_caching,
         kv_cache_dtype=script_args.kv_cache_dtype,
         max_model_len=script_args.max_model_len,
-        worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
+        #worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
         trust_remote_code=script_args.trust_remote_code,
     )
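The hunk above forwards the parsed dict straight to vLLM's LLM constructor (and this work-in-progress commit also comments out the weight-sync worker extension). For a quick sanity check of a draft/target pairing outside the server, here is a minimal offline sketch, assuming a vLLM version whose LLM constructor accepts a speculative_config dict as used above; both model names are only examples:

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen3-1.7B",  # example target model
        # Example draft model; it must share the target model's tokenizer/vocabulary.
        speculative_config={"model": "Qwen/Qwen3-0.6B", "num_speculative_tokens": 5},
        gpu_memory_utilization=0.9,
    )
    outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)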
New file: vllm_spec.py (37 lines)

@@ -0,0 +1,37 @@
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

dataset = load_dataset("trl-lib/tldr", split="train")


# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]


training_args = GRPOConfig(
    output_dir="spec",
    use_vllm=True,
    vllm_mode="server",
    vllm_server_port=8001,
    bf16=True,
    gradient_checkpointing=True,
    per_device_train_batch_size=2,
    num_generations=2,
    max_completion_length=32,
    max_steps=10,
)

trainer = GRPOTrainer(
    model="microsoft/DialoGPT-medium",
    args=training_args,
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)

trainer.train()

# Test commands:
# 1. Start vLLM server with speculative decoding:
#    python trl/scripts/vllm_serve.py --model microsoft/DialoGPT-medium --speculative_config '{"model": "microsoft/DialoGPT-small", "num_speculative_tokens": 5}' --port 8001
#
# 2. Run training with speculative decoding:
#    CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch vllm_spec.py
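One practical note on these test commands: in TRL's vLLM server mode, the server and the trainer must run on disjoint GPUs, so the server command above would typically be pinned to the devices the trainer is not using, for example (device indices are illustrative):

    # Hypothetical split: server on GPUs 0-3, training (as above) on GPUs 4-7.
    CUDA_VISIBLE_DEVICES=0,1,2,3 python trl/scripts/vllm_serve.py --model microsoft/DialoGPT-medium --speculative_config '{"model": "microsoft/DialoGPT-small", "num_speculative_tokens": 5}' --port 8001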