[Docs] Improve docs for RLHF co-location example (#20599)

Signed-off-by: Ricardo Decal <rdecal@anyscale.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Ricardo Decal authored 2025-07-09 08:06:43 -07:00, committed by GitHub
parent 9ff2af6d2b
commit 853487bc1b


@@ -1,14 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to co-locate a vLLM inference worker and training
actors on the same set of GPUs for reinforcement learning from human feedback
(RLHF) workloads.

Ray serves as the distributed execution framework in this example. Ray
placement groups allocate both training actors and vLLM workers to the
same GPU bundles, enabling fast, in-GPU communication between the two
components.

The script shows how to do the following:

* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and
  `VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired
  devices.
* Exchange tensors between processes by means of CUDA inter-process
  communication (IPC). CUDA IPC sidesteps NCCL limitations that occur
  when multiple processes share a single GPU.

Note that this example assumes a single-node cluster with four GPUs, but Ray
supports multi-node clusters. vLLM expects exclusive use of the GPUs during
its initialization for memory profiling; residual GPU activity interferes
with the profiling and causes unexpected behavior.

Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
import os
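
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original example and never called by
# it: how a CUDA tensor can be handed to another process through CUDA IPC,
# the mechanism the example relies on below. `reduce_tensor` turns a CUDA
# tensor into a picklable (rebuild_fn, args) pair; calling rebuild_fn(*args)
# in the receiving process maps the producer's GPU memory instead of copying
# it. The helper names are assumptions made for this sketch; to try it, call
# `_cuda_ipc_roundtrip_demo()` from a separate, `__main__`-guarded script.
# ---------------------------------------------------------------------------
def _cuda_ipc_consumer(handle):
    import torch

    rebuild_fn, rebuild_args = handle
    shared = rebuild_fn(*rebuild_args)  # maps the producer's GPU memory
    shared.add_(1.0)  # the producer sees this write
    torch.cuda.synchronize()


def _cuda_ipc_roundtrip_demo():
    import torch
    import torch.multiprocessing as mp
    from torch.multiprocessing.reductions import reduce_tensor

    mp.set_start_method("spawn", force=True)
    t = torch.zeros(4, device="cuda:0")
    handle = reduce_tensor(t)  # picklable (rebuild_fn, args) pair
    p = mp.Process(target=_cuda_ipc_consumer, args=(handle,))
    p.start()
    p.join()
    torch.cuda.synchronize()
    print(t)  # should print tensor([1., 1., 1., 1.], device='cuda:0')
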
@@ -22,13 +39,24 @@ from vllm import LLM
class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution.

    The constructor sets environment variables that allow multiple vLLM
    workers to share a single physical GPU and that encode the bundle
    indices assigned by the placement group.

    Args:
        *args: Positional arguments forwarded to `vllm.LLM`.
        bundle_indices (list[int]): Placement-group bundle indices
            assigned to this worker.
        **kwargs: Keyword arguments forwarded to `vllm.LLM`.
    """

    def __init__(self, *args, bundle_indices: list[int], **kwargs):
        # Prevent Ray from manipulating the top-level CUDA_VISIBLE_DEVICES
        # variable so that vLLM can manage its own device placement inside
        # the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
# Each worker uses 0.4 GPU so that two instances fit on the same GPUs.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
print(f"creating LLM with bundle_indices={bundle_indices}")
@@ -36,17 +64,25 @@ class MyLLM(LLM):
class RayTrainingActor:
"""Training actor that hosts a Facebook OPT-125M model from Hugging Face.
The model is loaded onto the first GPU assigned to this actor, and expose
the CUDA IPC handles so that colocated vLLM workers can map tensors
directly.
"""
def __init__(self):
# Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
from transformers import AutoModelForCausalLM
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model.to("cuda:0")
# Zero out all the parameters.
for name, p in self.model.named_parameters():
p.data.zero_()
torch.cuda.synchronize()
# The argument for `get_device_uuid` is the index of the GPU in the
# list of visible devices.
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(0)
@@ -59,23 +95,23 @@ class RayTrainingActor:
data = {}
for name, p in self.model.named_parameters():
# A training actor might hold only a subset of the weights and may
# need to gather weights from other actors. For demonstration
# purposes, each training actor owns the full weight set.
data[name] = reduce_tensor(p.detach())
return {self.device_uuid: data}
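
# A hedged sketch (illustration only) of the consuming side of these IPC
# handles. The example triggers it further down via
# `collective_rpc("update_weights_from_ipc_handles", ...)`; the actual worker
# extension ships alongside the example and is not part of this diff, so the
# attribute names below (`self.device_uuid`, `self.model_runner`) are
# assumptions about roughly what such a method can look like:
#
#     def update_weights_from_ipc_handles(self, ipc_handles):
#         # Select the handles exported on this worker's physical GPU.
#         handles = ipc_handles[self.device_uuid]
#         weights = []
#         for name, (rebuild_fn, rebuild_args) in handles.items():
#             # Rebuilding from the IPC handle maps the training actor's
#             # GPU memory; no cross-process copy is involved.
#             weights.append((name, rebuild_fn(*rebuild_args)))
#         self.model_runner.model.load_weights(weights=weights)
#         torch.cuda.synchronize()
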
# Ray manages four GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.init()
# Co-locate vLLM instances and training actors on the same set of GPUs:
# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
# (tensor parallelism = 2).
# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
# (tensor parallelism = 2).
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
ray.get(pg.ready())
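
# Hedged sketch: the loop that creates the four training actors is unchanged
# by this commit, so the diff does not show it. Pinning an actor to a single
# bundle of the placement group follows the same `ray.remote(...)` pattern
# used for the vLLM engines below; the fractional `num_gpus` value here is an
# assumption for illustration, not the example's exact code.
#
#     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
#
#     for bundle_index in range(4):
#         actor = ray.remote(
#             num_cpus=0,
#             num_gpus=0.4,  # leave room for a vLLM worker on the same GPU
#             scheduling_strategy=PlacementGroupSchedulingStrategy(
#                 placement_group=pg,
#                 placement_group_bundle_index=bundle_index,
#             ),
#         )(RayTrainingActor).remote()
#         training_actors.append(actor)
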
@@ -104,10 +140,8 @@ for bundle_index, training_actor in enumerate(training_actors):
training_actor_device_ids.append(device_id)
for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
# Use the following syntax instead of the @ray.remote decorator so that
# the placement group is customized for each bundle.
llm = ray.remote(
num_cpus=0,
num_gpus=0,
@@ -125,8 +159,8 @@ for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
bundle_indices=bundle_indices,
)
inference_engines.append(llm)
# Do not call any method on the inference engine at this point; the call
# blocks until the vLLM instance finishes initialization.
for i, llm in enumerate(inference_engines):
inference_engine_device_ids.append(
@@ -134,26 +168,25 @@ for i, llm in enumerate(inference_engines):
)
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
# Verify placement: the first two training actors share the same GPUs as
# the first inference engine.
assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
# Verify placement: the last two training actors share the same GPUs as
# the second inference engine.
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
print("gather all the IPC handles from the training actors")
print("Gather all the IPC handles from the training actors.")
ipc_handles = {}
for actor in training_actors:
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
print("update the weights of the inference engines")
print("Update the weights of the inference engines.")
for llm in inference_engines:
ray.get(
llm.collective_rpc.remote(
"update_weights_from_ipc_handles", args=(ipc_handles,)
)
)
print("check if the weights are updated")
print("Check if the weights are updated.")
for llm in inference_engines:
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
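
# A hedged sketch (illustration only) of what the `check_weights_changed` RPC
# target can look like on the worker side. Because the training actors zero
# out every parameter before exporting the IPC handles, "changed" here means
# "all weights are now zero". As above, the worker extension is not part of
# this diff, so the attribute names are assumptions:
#
#     def check_weights_changed(self) -> bool:
#         return all(
#             torch.allclose(p, torch.zeros_like(p))
#             for _, p in self.model_runner.model.named_parameters()
#         )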