mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Docs] Improve docs for RLHF co-location example (#20599)
Signed-off-by: Ricardo Decal <rdecal@anyscale.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@ -1,14 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
a simple demonstration to show how to co-locate
|
||||
vLLM worker with training actors on the same GPUs,
|
||||
for RLHF-like applications.
|
||||
The key points:
|
||||
- Control the placement of the vLLM workers with Ray, by setting
|
||||
VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly.
|
||||
- Use cuda-ipc to pass tensors, since NCCL does not work when we have
|
||||
multiple processes on the same GPU.
|
||||
Demonstrates how to co-locate a vLLM inference worker and training
|
||||
actors on the same set of GPUs for reinforcement learning from human feedback
|
||||
(RLHF) workloads.
|
||||
|
||||
Ray serves as the distributed execution framework in this example. Ray
|
||||
placement groups allocate both training actors and vLLM workers to the
|
||||
same GPU bundles, enabling fast, in-GPU communication between the two
|
||||
components.
|
||||
|
||||
The script shows how to do the following:
|
||||
|
||||
* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and
|
||||
`VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired
|
||||
devices.
|
||||
* Exchange tensors between processes by means of CUDA inter-process
|
||||
communication (IPC). CUDA IPC sidesteps NCCL limitations that occur
|
||||
when multiple processes share a single GPU.
|
||||
|
||||
Note that this example assumes a single-node cluster with four GPUs, but Ray
|
||||
supports multi-node clusters. vLLM expects exclusive use of the GPUs during
|
||||
its initialization for memory profiling. Residual GPU activity interferes
|
||||
with vLLM memory profiling and causes unexpected behavior.
|
||||
|
||||
Learn more about Ray placement groups:
|
||||
https://docs.ray.io/en/latest/placement-groups.html
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -22,13 +39,24 @@ from vllm import LLM
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
def __init__(self, *args, bundle_indices: list, **kwargs):
|
||||
# a hack to make the script work.
|
||||
# stop ray from manipulating CUDA_VISIBLE_DEVICES
|
||||
# at the top-level
|
||||
"""Configure the vLLM worker for Ray placement group execution.
|
||||
|
||||
The constructor sets environment variables that allow multiple vLLM
|
||||
workers to share a single physical GPU and that encode the bundle
|
||||
indices assigned by the placement group.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments forwarded to `vllm.LLM`.
|
||||
bundle_indices (list[int]): Placement-group bundle indices
|
||||
assigned to this worker.
|
||||
**kwargs: Keyword arguments forwarded to `vllm.LLM`.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, bundle_indices: list[int], **kwargs):
|
||||
# Prevent Ray from manipulating the top-level CUDA_VISIBLE_DEVICES variable
|
||||
# so that vLLM can its own device placement inside the worker.
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
# every worker will use 0.4 GPU, so that we can schedule
|
||||
# 2 instances on the same GPUs.
|
||||
# Each worker uses 0.4 GPU so that two instances fit on the same GPUs.
|
||||
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
|
||||
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
|
||||
print(f"creating LLM with bundle_indices={bundle_indices}")
|
||||
@ -36,17 +64,25 @@ class MyLLM(LLM):
|
||||
|
||||
|
||||
class RayTrainingActor:
|
||||
"""Training actor that hosts a Facebook OPT-125M model from Hugging Face.
|
||||
|
||||
The model is loaded onto the first GPU assigned to this actor, and expose
|
||||
the CUDA IPC handles so that colocated vLLM workers can map tensors
|
||||
directly.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
|
||||
# Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
self.model.to("cuda:0")
|
||||
# Zero out all the parameters.
|
||||
for name, p in self.model.named_parameters():
|
||||
p.data.zero_()
|
||||
torch.cuda.synchronize()
|
||||
# the argument for get_device_uuid is the index
|
||||
# of the GPU in the visible devices.
|
||||
# The argument for `get_device_uuid` is the index of the GPU in the
|
||||
# list of visible devices.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(0)
|
||||
@ -59,23 +95,23 @@ class RayTrainingActor:
|
||||
|
||||
data = {}
|
||||
for name, p in self.model.named_parameters():
|
||||
# the training actor might only have a subset of the weights
|
||||
# and need to all-gather the weights from all the actors.
|
||||
# for demonstration, here we assume all training actors have
|
||||
# the full weights.
|
||||
# A training actor might hold only a subset of the weights and may
|
||||
# need to gather weights from other actors. For demonstration
|
||||
# purposes, each training actor owns the full weight set.
|
||||
data[name] = reduce_tensor(p.detach())
|
||||
return {self.device_uuid: data}
|
||||
|
||||
|
||||
# ray manages 4 GPUs
|
||||
# Ray manages four GPUs.
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||
ray.init()
|
||||
|
||||
# we want to co-locate vLLM instance and the training actor
|
||||
# on the same set of GPUs.
|
||||
# the placement plan is as follows:
|
||||
# GPU 0 and 1: training actor 0, 1, and vLLM instance 0 (with TP=2)
|
||||
# GPU 2 and 3: training actor 2, 3, and vLLM instance 1 (with TP=2)
|
||||
# Co-locate vLLM instances and training actors on the same set of GPUs:
|
||||
# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
|
||||
# (tensor parallelism = 2).
|
||||
# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
|
||||
# (tensor parallelism = 2).
|
||||
|
||||
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
|
||||
ray.get(pg.ready())
|
||||
@ -104,10 +140,8 @@ for bundle_index, training_actor in enumerate(training_actors):
|
||||
training_actor_device_ids.append(device_id)
|
||||
|
||||
for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
|
||||
# IMPORTANT: when creating vLLM instances, we need to
|
||||
# make sure there are no GPU activities on the target GPUs,
|
||||
# otherwise, they will interfere with the vLLM memory profiling,
|
||||
# and cause unexpected behaviors.
|
||||
# Use the following syntax instead of the @ray.remote decorator so that
|
||||
# the placement group is customized for each bundle.
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
@ -125,8 +159,8 @@ for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
|
||||
bundle_indices=bundle_indices,
|
||||
)
|
||||
inference_engines.append(llm)
|
||||
# don't call any method on the inference engine here,
|
||||
# otherwise it will block until the vLLM instance is created.
|
||||
# Do not call any method on the inference engine at this point; the call
|
||||
# blocks until the vLLM instance finishes initialization.
|
||||
|
||||
for i, llm in enumerate(inference_engines):
|
||||
inference_engine_device_ids.append(
|
||||
@ -134,26 +168,25 @@ for i, llm in enumerate(inference_engines):
|
||||
)
|
||||
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
|
||||
|
||||
# check the placement
|
||||
# the first two training actors should be
|
||||
# on the same GPUs as the first inference engine
|
||||
# Verify placement: the first two training actors share the same GPUs as
|
||||
# the first inference engine.
|
||||
assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
|
||||
# the last two training actors should be
|
||||
# on the same GPUs as the second inference engine
|
||||
# Verify placement: the last two training actors share the same GPUs as
|
||||
# the second inference engine.
|
||||
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
|
||||
|
||||
print("gather all the IPC handles from the training actors")
|
||||
print("Gather all the IPC handles from the training actors.")
|
||||
ipc_handles = {}
|
||||
for actor in training_actors:
|
||||
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
|
||||
|
||||
print("update the weights of the inference engines")
|
||||
print("Update the weights of the inference engines.")
|
||||
for llm in inference_engines:
|
||||
ray.get(
|
||||
llm.collective_rpc.remote(
|
||||
"update_weights_from_ipc_handles", args=(ipc_handles,)
|
||||
)
|
||||
)
|
||||
print("check if the weights are updated")
|
||||
print("Check if the weights are updated.")
|
||||
for llm in inference_engines:
|
||||
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
|
||||
|
Reference in New Issue
Block a user