verl/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py
arron b25bb7d4f3 [trainer, recipe] feat: fully async training recipe (#2981)
### What does this PR do?

To implement a fully asynchronous training workflow, we further split
the training process into a Trainer and a Rollouter, building on the
existing one-step-off policy code; samples are passed from the
Rollouter to the Trainer through a message queue.
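
As a rough illustration of the decoupling (names such as `Rollouter`,
`Trainer`, and `sample_queue` are placeholders, not the recipe's actual
classes), the data flow looks like a producer/consumer pair connected by
a queue:

```python
# Illustrative sketch only; the real recipe builds on verl's worker abstractions.
import ray
from ray.util.queue import Queue  # stands in for the message queue


@ray.remote
class Rollouter:
    def __init__(self, sample_queue: Queue):
        self.sample_queue = sample_queue

    def run(self, num_steps: int):
        for step in range(num_steps):
            batch = {"step": step, "samples": []}  # generate rollouts here
            self.sample_queue.put(batch)  # hand the samples to the Trainer


@ray.remote
class Trainer:
    def __init__(self, sample_queue: Queue):
        self.sample_queue = sample_queue

    def run(self, num_steps: int):
        for _ in range(num_steps):
            batch = self.sample_queue.get()  # blocks until rollouts arrive
            # update the policy with `batch`, then sync weights to the Rollouter


if __name__ == "__main__":
    ray.init()
    queue = Queue(maxsize=8)
    rollouter, trainer = Rollouter.remote(queue), Trainer.remote(queue)
    ray.get([rollouter.run.remote(10), trainer.run.remote(10)])
```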

We will continue to integrate partial rollout to mitigate the impact of
long-tail generations on training.
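
Concretely, the hook this file exposes for partial rollout is
`generate_for_partial`, which returns `(token_ids, log_probs, is_cancel)`.
Below is a hedged sketch of how a caller might carry a cancelled
(partial) generation over to the next submission; the resubmission
bookkeeping and the `wait_until_resumed` helper are assumptions, not the
recipe's actual logic:

```python
async def rollout_one(server, prompt_ids, sampling_params, request_id):
    """Illustrative loop: keep partial tokens across cancel/resume cycles."""
    response_ids, response_logprobs = [], []
    while True:
        token_ids, log_probs, is_cancel = await server.generate_for_partial.remote(
            prompt_ids + response_ids, sampling_params, request_id
        )
        response_ids.extend(token_ids)
        response_logprobs.extend(log_probs)
        if not is_cancel:
            return response_ids, response_logprobs  # finished normally
        # Cancelled mid-generation: wait for resume(), then resubmit the prompt
        # plus the partial response so decoding continues where it left off.
        await wait_until_resumed(server)  # hypothetical helper
```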

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

Related PRs: https://github.com/volcengine/verl/pull/2231,
https://github.com/volcengine/verl/pull/2200

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```
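
One plausible pattern (a sketch, assuming `replica` is an already
constructed `FullyAsyncvLLMReplica` and `push_new_weights` is a
placeholder for the weight-sync step) is to pause rollout, update
weights, drop the stale cache, and resume:

```python
async def sync_weights_then_resume(replica, push_new_weights):
    await replica.cancel()              # every server pauses and fires pending cancel events
    await push_new_weights()            # placeholder: push the latest policy weights to rollout workers
    await replica.reset_prefix_cache()  # drop KV/prefix cache entries built with stale weights
    await replica.resume()              # allow new generate_for_partial calls to proceed
```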

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
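
At the level of `vllm_async_server.py`, this PR adds two classes on top
of the existing vLLM async server:

- `vLLMHttpServerForPartial` extends `vLLMHttpServerBase` with
`generate_for_partial`, which races generation against a per-request
cancel event and returns `(token_ids, log_probs, is_cancel)`, plus
`cancel`/`resume`/`reset_prefix_cache` to pause and restart serving
around weight updates.
- `FullyAsyncvLLMReplica` extends `vLLMReplica`, switches the server
class to `vLLMHttpServerForPartial`, and fans
`cancel`/`resume`/`reset_prefix_cache` out to every rollout server in
the replica.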

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: meituan-search <machi04@meituan.com>
Co-authored-by: wangshulin02 <wangshulin02@meituan.com>
Co-authored-by: arron <arron@MBP-2G17FXQ05P-2332.local>
Co-authored-by: wangshulin02 <953550366@qq.com>
Co-authored-by: hadoop-ai-search <hadoop-ai-search@set-zw04-mlp-codelab-pc1189.mt>
Co-authored-by: sl-1314 <82856253+sl-1314@users.noreply.github.com>
Co-authored-by: arron <arron@MBP-VH9RV7LTJC-1907.local>
Co-authored-by: arron <arron@MBP-JFQXPWR11F-1943.local>
2025-10-17 22:29:18 +08:00


# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import Any, Optional, Sequence

import ray
from ray.actor import ActorHandle
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput

from verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode
from verl.workers.rollout.vllm_rollout.vllm_async_server import (
    _qwen2_5_vl_dedup_image_tokens,
    vLLMHttpServerBase,
    vLLMReplica,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


@ray.remote(num_cpus=1)
class vLLMHttpServerForPartial(vLLMHttpServerBase):
    def __init__(
        self,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        rollout_mode: RolloutMode,
        workers: list[ActorHandle],
        replica_rank: int,
        node_rank: int,
        gpus_per_node: int,
        nnodes: int,
    ):
        super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)

        # State used to pause the server and cancel in-flight requests.
        self.paused = False
        self.lock = asyncio.Lock()
        self.cancel_event: dict[str, asyncio.Event] = {}
        self.req_output: dict[str, Optional[RequestOutput]] = {}

    async def _generate_step(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ):
        max_tokens = self.config.max_model_len - len(prompt_ids)
        sampling_params["logprobs"] = 1
        sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
        sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
        prompt = TokensPrompt(
            prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None
        )
        generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id)

        # Keep the latest (possibly partial) output in req_output so that
        # generate_for_partial can still read it if this task is cancelled.
        self.req_output[request_id]: Optional[RequestOutput] = None
        async for output in generator:
            self.req_output[request_id] = output
        assert self.req_output[request_id] is not None

    async def generate_for_partial(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]:
        async with self.lock:
            if self.paused:
                # After cancel, all tasks return immediately and wait for the next submission.
                return [], [], True
            self.cancel_event[request_id] = asyncio.Event()
            cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait())
            generation_handle = asyncio.create_task(
                self._generate_step(prompt_ids, sampling_params, request_id, image_data)
            )

        # Race generation against the cancel event; whichever finishes first wins.
        done, pend = await asyncio.wait([generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            await task
        for task in pend:
            task.cancel()

        async with self.lock:
            token_ids = self.req_output[request_id].outputs[0].token_ids
            log_probs: list[float] = []
            for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs):
                # Although sampling_params requests logprobs=1, vLLM may return several
                # candidates per position; take the logprob of the sampled token_id.
                token_id = self.req_output[request_id].outputs[0].token_ids[i]
                log_probs.append(x[token_id].logprob)
            is_cancel = generation_handle not in done
            self.cancel_event.pop(request_id, None)
            self.req_output.pop(request_id, None)
        return token_ids, log_probs, is_cancel

    async def cancel(self):
        async with self.lock:
            self.paused = True
            for request_id in self.cancel_event:
                self.cancel_event[request_id].set()

    async def resume(self):
        async with self.lock:
            self.paused = False

    async def reset_prefix_cache(self):
        async with self.lock:
            await self.engine.reset_prefix_cache()


class FullyAsyncvLLMReplica(vLLMReplica):
    def __init__(
        self,
        replica_rank: int,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        gpus_per_node: int = 8,
    ):
        super().__init__(replica_rank, config, model_config, gpus_per_node)
        self.server_class = vLLMHttpServerForPartial

    async def cancel(self):
        """Cancel each rollout server."""
        await asyncio.gather(*[server.cancel.remote() for server in self.servers])

    async def resume(self):
        """Resume each rollout server."""
        await asyncio.gather(*[server.resume.remote() for server in self.servers])

    async def reset_prefix_cache(self):
        """Reset the prefix (KV) cache in each rollout server."""
        await asyncio.gather(*[server.reset_prefix_cache.remote() for server in self.servers])