mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 21:53:50 +08:00
### What does this PR do? To implement a purely asynchronous training workflow, we further split the training process into a Trainer and a Rollouter, building on the existing one-step-off-policy code; samples are transmitted between them via a message queue. We will continue integrating partial rollout to mitigate the impact of long-tail samples on training. > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. https://github.com/volcengine/verl/pull/2231 https://github.com/volcengine/verl/pull/2200 ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. 
### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: meituan-search <machi04@meituan.com> Co-authored-by: wangshulin02 <wangshulin02@meituan.com> Co-authored-by: arron <arron@MBP-2G17FXQ05P-2332.local> Co-authored-by: wangshulin02 <953550366@qq.com> Co-authored-by: hadoop-ai-search <hadoop-ai-search@set-zw04-mlp-codelab-pc1189.mt> Co-authored-by: sl-1314 <82856253+sl-1314@users.noreply.github.com> Co-authored-by: arron <arron@MBP-VH9RV7LTJC-1907.local> Co-authored-by: arron <arron@MBP-JFQXPWR11F-1943.local>
155 lines
5.8 KiB
Python
155 lines
5.8 KiB
Python
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Optional, Sequence
|
|
|
|
import ray
|
|
from ray.actor import ActorHandle
|
|
from vllm import SamplingParams
|
|
from vllm.inputs import TokensPrompt
|
|
from vllm.outputs import RequestOutput
|
|
|
|
from verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig
|
|
from verl.workers.rollout.replica import RolloutMode
|
|
from verl.workers.rollout.vllm_rollout.vllm_async_server import (
|
|
_qwen2_5_vl_dedup_image_tokens,
|
|
vLLMHttpServerBase,
|
|
vLLMReplica,
|
|
)
|
|
|
|
# Module logger, named after this source file's path, emitting INFO and above.
logger = logging.getLogger(__file__)
logger.setLevel("INFO")  # the level name resolves to logging.INFO
|
|
|
|
|
|
@ray.remote(num_cpus=1)
class vLLMHttpServerForPartial(vLLMHttpServerBase):
    """vLLM server actor whose requests can be interrupted ("partial" rollout).

    Requests started through :meth:`generate_for_partial` can be interrupted with
    :meth:`cancel`. The tokens generated so far are then returned together with
    ``is_cancel=True``. After :meth:`cancel`, new requests are rejected immediately
    until :meth:`resume` is called.
    """

    def __init__(
        self,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        rollout_mode: RolloutMode,
        workers: list[ActorHandle],
        replica_rank: int,
        node_rank: int,
        gpus_per_node: int,
        nnodes: int,
    ):
        super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)

        # State used to interrupt in-flight requests.
        # While True, generate_for_partial rejects new requests (set by cancel(), cleared by resume()).
        self.paused = False
        # Taken by generate_for_partial, cancel, resume and reset_prefix_cache to serialize them.
        self.lock = asyncio.Lock()
        # request_id -> event that cancel() sets to interrupt that request.
        self.cancel_event: dict[str, asyncio.Event] = {}
        # request_id -> latest RequestOutput streamed by the engine (None until the first one arrives).
        self.req_output: dict[str, Optional[RequestOutput]] = {}

    async def _generate_step(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ):
        """Run one request on the engine, recording its latest output in ``self.req_output``.

        The output is stored after every streamed step. If this task is cancelled,
        :meth:`generate_for_partial` can therefore still read what was produced so far.

        Args:
            prompt_ids: Prompt token ids.
            sampling_params: Keyword arguments for ``vllm.SamplingParams``. ``max_tokens`` is
                derived from ``config.max_model_len`` and must not be included. The dict is
                not modified.
            request_id: Engine request id; also the key into ``self.req_output``.
            image_data: Optional images, passed to the engine as multi-modal ``"image"`` data.
        """
        # Whatever is left of the context window is the generation budget.
        max_tokens = self.config.max_model_len - len(prompt_ids)
        # Copy so the caller's dict is left untouched.
        params = dict(sampling_params)
        # Per-token logprobs are required; generate_for_partial reads the sampled token's entry.
        params["logprobs"] = 1
        params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
        vllm_params = SamplingParams(max_tokens=max_tokens, **params)
        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
        prompt = TokensPrompt(
            prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None
        )
        generator = self.engine.generate(prompt=prompt, sampling_params=vllm_params, request_id=request_id)

        # Register the request as started but without output yet.
        self.req_output[request_id] = None
        async for output in generator:
            # Presumably cumulative (vLLM's default output kind), so the latest output
            # holds every token generated so far — confirm if output_kind is changed.
            self.req_output[request_id] = output
        assert self.req_output[request_id] is not None

    async def generate_for_partial(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]:
        """Generate a response that :meth:`cancel` can interrupt.

        Args:
            prompt_ids: Prompt token ids.
            sampling_params: Keyword arguments for ``vllm.SamplingParams`` (see ``_generate_step``).
            request_id: Unique id of this request.
            image_data: Optional images for multi-modal prompts.

        Returns:
            ``(token_ids, log_probs, is_cancel)``. ``is_cancel`` is True when the request was
            interrupted, or rejected because the server is paused. ``token_ids`` and
            ``log_probs`` then hold whatever was generated before the interruption, possibly
            nothing.
        """
        async with self.lock:
            if self.paused:
                # After cancel, every request returns directly and waits for the next submission.
                return [], [], True
            cancel_event = asyncio.Event()
            self.cancel_event[request_id] = cancel_event
            cancel_handle = asyncio.create_task(cancel_event.wait())
            generation_handle = asyncio.create_task(
                self._generate_step(prompt_ids, sampling_params, request_id, image_data)
            )

        try:
            done, pending = await asyncio.wait(
                [generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED
            )

            for task in pending:
                task.cancel()
            # Wait for the cancelled task(s) to unwind so the engine has finished handling the
            # interruption before we return. CancelledError is expected there; anything else is logged.
            for result in await asyncio.gather(*pending, return_exceptions=True):
                if isinstance(result, Exception):
                    logger.warning("Request %s raised while being cancelled: %r", request_id, result)

            # Re-raise an exception from a finished task (e.g. an engine error).
            for task in done:
                await task

            is_cancel = generation_handle not in done
            async with self.lock:
                final_output = self.req_output.get(request_id)
                if final_output is None:
                    # Interrupted before the engine streamed any output: nothing to return.
                    return [], [], True
                completion = final_output.outputs[0]
                token_ids = completion.token_ids
                # logprobs=1 returns the top-1 candidate *and* the sampled token, so a position can
                # hold several entries; pick the log-prob of the token actually sampled.
                log_probs: list[float] = [
                    candidates[token_id].logprob for token_id, candidates in zip(token_ids, completion.logprobs)
                ]
            return token_ids, log_probs, is_cancel
        finally:
            # Runs on success, on error and on outer cancellation. It guarantees that no helper
            # task outlives this call and that the per-request bookkeeping is released.
            generation_handle.cancel()  # no-op if already finished
            cancel_handle.cancel()
            self.cancel_event.pop(request_id, None)
            self.req_output.pop(request_id, None)

    async def cancel(self):
        """Pause the server and interrupt every in-flight partial request."""
        async with self.lock:
            self.paused = True
            for event in self.cancel_event.values():
                event.set()

    async def resume(self):
        """Accept new requests again after :meth:`cancel`."""
        async with self.lock:
            self.paused = False

    async def reset_prefix_cache(self):
        """Reset the engine's prefix cache while holding the server lock."""
        async with self.lock:
            await self.engine.reset_prefix_cache()
|
|
|
|
|
|
class FullyAsyncvLLMReplica(vLLMReplica):
    """Rollout replica whose server actors are ``vLLMHttpServerForPartial`` instances.

    Provides replica-wide ``cancel`` / ``resume`` / ``reset_prefix_cache``. Each one
    forwards the call to every server actor in ``self.servers`` and waits for all of them.
    """

    def __init__(
        self,
        replica_rank: int,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        gpus_per_node: int = 8,
    ):
        super().__init__(replica_rank, config, model_config, gpus_per_node)
        # Server actor class for this replica; presumably read by vLLMReplica when it
        # launches its servers — confirm against the base class.
        self.server_class = vLLMHttpServerForPartial

    async def cancel(self):
        """Interrupt in-flight requests on every rollout server of this replica."""
        cancel_refs = [handle.cancel.remote() for handle in self.servers]
        await asyncio.gather(*cancel_refs)

    async def resume(self):
        """Let every rollout server of this replica accept requests again."""
        resume_refs = [handle.resume.remote() for handle in self.servers]
        await asyncio.gather(*resume_refs)

    async def reset_prefix_cache(self):
        """Reset the prefix (KV) cache on every rollout server of this replica."""
        reset_refs = [handle.reset_prefix_cache.remote() for handle in self.servers]
        await asyncio.gather(*reset_refs)
|