Compare commits

...

2 Commits

SHA1         Message            Date
49a5febd58   json format fixed  2025-06-25 12:43:55 +00:00
2eef438699   spec-dec-config    2025-06-24 16:00:36 +00:00
2 changed files with 64 additions and 5 deletions

trl/scripts/vllm_serve.py

@@ -13,6 +13,7 @@
# limitations under the License.
import argparse
import json
import logging
import os
from collections.abc import Sequence
@@ -67,6 +68,14 @@ logger = logging.getLogger(__name__)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def json_dict_type(arg_string):
"""Parse a JSON string into a dictionary."""
try:
return json.loads(arg_string)
except json.JSONDecodeError as e:
raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")
class WeightSyncWorkerExtension:
"""
A vLLM worker extension that enables weight synchronization between a client and multiple server workers.
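
The json_dict_type helper introduced above is a standard argparse pattern for accepting a dict-valued flag as a JSON string. The diff does not show where (or whether) it is registered on the parser, so the following is a minimal sketch of how such a helper is typically wired up; the parser and argument name are assumptions for illustration only.

import argparse
import json

def json_dict_type(arg_string):
    """Parse a JSON string into a dictionary, reporting failures as argparse errors."""
    try:
        return json.loads(arg_string)
    except json.JSONDecodeError as e:
        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")

# Hypothetical wiring (not shown in the diff): pass the helper as the argument's `type`
# so the flag value arrives already parsed as a dict.
parser = argparse.ArgumentParser()
parser.add_argument("--speculative_config", type=json_dict_type, default=None)
args = parser.parse_args(["--speculative_config", '{"model": "Qwen/Qwen3-0.6B"}'])
print(args.speculative_config["model"])  # Qwen/Qwen3-0.6B
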
@@ -174,10 +183,8 @@ class ScriptArguments:
dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
based on the model configuration. Find the supported values in the vLLM documentation.
max_model_len (`int` or `None`, *optional*, defaults to `None`):
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.
speculative_config (`str`, *optional*, defaults to `None`):
Speculative decoding configuration for vLLM, provided as a JSON string. If set, the speculative model
will be used for generation. Example: `'{"model": "Qwen/Qwen3-0.6B"}'`
enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
this feature.
@@ -233,6 +240,12 @@ class ScriptArguments:
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
speculative_config: Optional[str] = field(
default=None,
metadata={
"help": "Speculative decoding configuration for vLLM as JSON string. If set, the spec model will be used for generation. Example: '{\"model\": \"Qwen/Qwen3-0.6B\"}'"
},
)
max_model_len: Optional[int] = field(
default=None,
metadata={
@@ -277,6 +290,14 @@ class ScriptArguments:
},
)
def __post_init__(self):
"""Parse JSON strings into appropriate types."""
if self.speculative_config is not None:
try:
self.speculative_config = json.loads(self.speculative_config)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in speculative_config: {e}")
def llm_worker(
script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection
@@ -294,13 +315,14 @@ def llm_worker(
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
speculative_config=script_args.speculative_config,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
#worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
trust_remote_code=script_args.trust_remote_code,
)
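
Inside llm_worker, the parsed dict is forwarded to vLLM's LLM(...) constructor via the speculative_config keyword, as shown above. For a quick check outside the TRL server, the same setting can be exercised directly. A minimal sketch, assuming a recent vLLM release that accepts speculative_config as a dict on LLM, with the draft/target model pair taken from the test commands further down.

from vllm import LLM, SamplingParams

# The draft model proposes tokens and the target model verifies them, so fewer
# target-model decode steps are needed while outputs remain distribution-preserving.
llm = LLM(
    model="microsoft/DialoGPT-medium",
    speculative_config={"model": "microsoft/DialoGPT-small", "num_speculative_tokens": 5},
    gpu_memory_utilization=0.9,
)
outputs = llm.generate(["The quick brown fox"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
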

vllm_spec.py (new file, 37 lines)

@@ -0,0 +1,37 @@
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = GRPOConfig(
output_dir="spec",
use_vllm=True,
vllm_mode="server",
vllm_server_port=8001,
bf16=True,
gradient_checkpointing=True,
per_device_train_batch_size=2,
num_generations=2,
max_completion_length=32,
max_steps=10,
)
trainer = GRPOTrainer(
model="microsoft/DialoGPT-medium",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
# Test commands:
# 1. Start vLLM server with speculative decoding:
# python trl/scripts/vllm_serve.py --model microsoft/DialoGPT-medium --speculative_config '{"model": "microsoft/DialoGPT-small", "num_speculative_tokens": 5}' --port 8001
#
# 2. Run training with speculative decoding:
# CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch vllm_spec.py