verl/recipe/fully_async_policy/fully_async_rollouter.py
arron b25bb7d4f3 [trainer, recipe] feat: fully async training recipe (#2981)
### What does this PR do?

To implement a fully asynchronous training workflow, we further split
the training process into a Trainer and a Rollouter, building on the
existing one-step-off policy code, with samples transmitted via a
message queue.

We will continue to integrate partial rollout to mitigate the impact of
long-tail samples on training.

Related PRs: https://github.com/volcengine/verl/pull/2231 and
https://github.com/volcengine/verl/pull/2200.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

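The sketch below shows how a launcher might drive the `FullyAsyncRollouter` actor directly, based on the public methods in `fully_async_rollouter.py`. The `config`, `tokenizer`, `role_worker_mapping`, `resource_pool_manager`, and `mq_client` objects are assumed to be built by the recipe's own entry point; this is an illustrative sketch, not the tested entry point.

```python
import ray

from recipe.fully_async_policy.fully_async_rollouter import FullyAsyncRollouter

# Illustrative sketch: config, tokenizer, role_worker_mapping,
# resource_pool_manager and mq_client are assumed to come from the
# recipe's own launcher and are not constructed here.
rollouter = FullyAsyncRollouter.remote(
    config=config,
    tokenizer=tokenizer,
    role_worker_mapping=role_worker_mapping,
    resource_pool_manager=resource_pool_manager,
)
ray.get(rollouter.set_message_queue_client.remote(mq_client))
ray.get(rollouter.init_workers.remote())              # build rollout worker group + async manager
ray.get(rollouter.set_max_required_samples.remote())  # derive staleness / concurrency limits
rollouter.fit.remote()                                # stream samples into the message queue
```

Note the config constraints asserted by the Rollouter: `data.train_batch_size=0`, `data.gen_batch_size=1`, `actor_rollout_ref.rollout.mode="async"`, and `actor_rollout_ref.rollout.calculate_log_probs=True`.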

### Design & Code Changes

- The Rollouter (`FullyAsyncRollouter`, a Ray actor) streams samples one at a time: a feed coroutine pulls from the dataloader into `pending_queue`, a processor coroutine submits samples to the async rollout manager (bounded by `max_concurrent_samples`), and a consumer coroutine pushes finished samples into the MessageQueue for the Trainer.
- Staleness control: generation pauses when the MessageQueue is full or when `staleness_samples` exceeds `required_samples * (staleness_threshold + 1) * trigger_parameter_sync_step`; a monitor loop resumes it once the Trainer catches up.
- Parameter sync: the Trainer calls `pause()`, syncs weights, then `update_param_version()` (which resets staleness accounting and optionally runs validation) and `resume()`. With `async_training.partial_rollout` enabled, cancelled generations are re-queued via `cancel_queue` and resume from their partial outputs.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: meituan-search <machi04@meituan.com>
Co-authored-by: wangshulin02 <wangshulin02@meituan.com>
Co-authored-by: arron <arron@MBP-2G17FXQ05P-2332.local>
Co-authored-by: wangshulin02 <953550366@qq.com>
Co-authored-by: hadoop-ai-search <hadoop-ai-search@set-zw04-mlp-codelab-pc1189.mt>
Co-authored-by: sl-1314 <82856253+sl-1314@users.noreply.github.com>
Co-authored-by: arron <arron@MBP-VH9RV7LTJC-1907.local>
Co-authored-by: arron <arron@MBP-JFQXPWR11F-1943.local>
2025-10-17 22:29:18 +08:00


# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time
from pprint import pformat

import ray
from ray import ObjectRef

from recipe.fully_async_policy.detach_utils import (
    RolloutSample,
    ValidateMetrics,
    merge_rollout_sample,
    prepare_single_generation_data,
)
from recipe.fully_async_policy.message_queue import MessageQueueClient
from recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
from verl.trainer.ppo.utils import Role, WorkerType
from verl.utils.profiler import marked_timer
from verl.utils.tracking import ValidationGenerationsLogger


@ray.remote(num_cpus=10, max_concurrency=100)
class FullyAsyncRollouter(FullyAsyncRayPPOTrainer):
"""
Asynchronous sample generator, responsible for continuously generating training samples
and putting them into MessageQueue
Based on the mature implementation improvements of OneStepOffRayTrainer
"""
    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict[Role, WorkerType],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
        processor=None,
        reward_fn=None,
        val_reward_fn=None,
        device_name=None,
    ):
        # Store the tokenizer for text processing
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert not self.hybrid_engine
        assert self.config.data.train_batch_size == 0, "train_batch_size must be zero"
        assert self.config.data.gen_batch_size == 1, "gen_batch_size must be one"
        assert self.config.async_training.staleness_threshold >= 0, "staleness_threshold must be >= 0"
        assert self.config.async_training.trigger_parameter_sync_step >= 1, (
            "trigger_parameter_sync_step must be >= 1"
        )

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.ray_worker_group_cls = ray_worker_group_cls
        self.device_name = device_name if device_name else self.config.trainer.device
        self.validation_generations_logger = ValidationGenerationsLogger(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
        )
        self.ref_in_actor = False
        self.kl_ctrl_in_reward = False
        self.use_critic = False
        self.use_reference_policy = False
        self.use_rm = False

        print("[FullyAsyncRollouter] Creating datasets...")
        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
        from verl.utils.dataset.rl_dataset import collate_fn

        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
        train_sampler = create_rl_sampler(config.data, train_dataset)
        self._validate_config()
        print(f"[FullyAsyncRollouter] Rollouter _create_dataloader...\n{train_dataset}\n{val_dataset}")
        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

        # ==================== fully async config ====================
        self.total_rollout_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
        if self.config.rollout.total_rollout_steps is not None:
            self.total_rollout_steps = min(self.config.rollout.total_rollout_steps, self.total_rollout_steps)
        print(f"[FullyAsyncRollouter] Total rollout steps: {self.total_rollout_steps}")
        self.total_train_steps = None

        # Rollouter parameter configuration
        self.message_queue_client = None

        # Worker groups: rollout_wg is the same as actor_rollout_wg
        self.rollout_wg = None
        self.actor_rollout_wg = None
        self.async_rollout_manager = None

        # Config
        self.staleness_threshold: float = config.async_training.get("staleness_threshold", 1)
        # required_samples uses ppo_mini_batch_size * require_batches as the minimum number of samples.
        self.require_batches = config.async_training.require_batches
        self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches
        self.max_required_samples = None
        self.max_concurrent_samples = None
        # queue size
        self.max_queue_size = None

        # Statistics
        self.current_param_version = 0
        self.total_generated_samples = 0
        self.staleness_samples = 0
        self.dropped_stale_samples = 0
        self.processed_sample_count = 0
        self.global_steps = 0
        self.idle_start_time = None
        self.version_start_time = None

        # Concurrency control
        # Modified by self.pause() or self._should_pause_generation()
        self.paused = False
        self.running = True
        self.monitor_loop_trigger = True
        # Initialize async locks directly
        self.lock = asyncio.Lock()
        self.condition = asyncio.Condition(self.lock)
        # Initialize async queues
        self.pending_queue = asyncio.Queue(maxsize=128)
        self.active_tasks = set()
        self.result_queue = asyncio.Queue()
        self.cancel_queue = asyncio.Queue()
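        # Sample lifecycle (summary, for orientation): a sample flows
        # pending_queue -> active_tasks (generation) -> result_queue -> MessageQueue;
        # if partial rollout cancels generation, the sample loops back via
        # cancel_queue and resumes from its partially filled outputs.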

    async def set_message_queue_client(self, message_queue_client: MessageQueueClient):
        """Set the message queue client"""
        async with self.lock:
            self.message_queue_client = message_queue_client

    async def set_max_required_samples(self):
        async with self.lock:
            self.max_required_samples = int(
                self.required_samples
                * (self.staleness_threshold + 1)
                * self.config.async_training.trigger_parameter_sync_step
            )
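            # Illustrative numbers (assumed, not from any shipped config): with
            # ppo_mini_batch_size=32, require_batches=2, staleness_threshold=1 and
            # trigger_parameter_sync_step=4, required_samples = 64 and
            # max_required_samples = 64 * (1 + 1) * 4 = 512, i.e. generation may
            # run at most one full sync interval ahead of training.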
            self.total_train_steps = int(
                self.total_rollout_steps
                / (self.required_samples * self.config.async_training.trigger_parameter_sync_step)
            )
            self.max_concurrent_samples = len(self.async_rollout_manager.server_handles) * 16
            self.max_concurrent_samples = min(self.max_concurrent_samples, self.max_required_samples)
            self.max_queue_size = self.max_required_samples
            print(
                f"[FullyAsyncRollouter] required_samples: {self.required_samples} "
                f"max_required_samples: {self.max_required_samples} "
                f"max_queue_size: {self.max_queue_size} "
                f"total_train_steps: {self.total_train_steps} "
                f"total_rollout_steps: {self.total_rollout_steps} "
                f"max_concurrent_samples: {self.max_concurrent_samples} "
            )

    def get_rollout_wg(self):
        """Get the rollout worker group"""
        return self.rollout_wg

    def get_max_queue_size(self):
        return self.max_queue_size

    def get_total_train_steps(self):
        return self.total_train_steps

    async def update_param_version(self, version: int, validate: bool = False, global_steps: int = 0):
        """Update the current parameter version"""
        async with self.lock:
            old_version = self.current_param_version
            self.current_param_version = version
            # Every time the parameters change, reset staleness_samples to the number
            # of samples still in flight or queued (locally and in the message queue)
            self.staleness_samples = (
                len(self.active_tasks)
                + self.result_queue.qsize()
                + self.cancel_queue.qsize()
                + await self.message_queue_client.get_queue_size()
            )
            timing_raw = {}
            idle_ratio = None
            if self.idle_start_time is not None and self.version_start_time is not None:
                rollout_active_time = self.idle_start_time - self.version_start_time
                rollout_version_time = time.time() - self.version_start_time
                idle_ratio = 1 - rollout_active_time / rollout_version_time
                timing_raw["rollouter/active_time"] = rollout_active_time
                timing_raw["rollouter/version_time"] = rollout_version_time
                timing_raw["rollouter/idle_ratio"] = idle_ratio
                self.idle_start_time = None
            print(
                f"[FullyAsyncRollouter][Public][update_param_version] "
                f"Parameter version updated from {old_version} to {version}, "
                f"reset staleness_samples to: {self.staleness_samples}, "
                f"idle_ratio: {idle_ratio}"
            )
            val_metrics = None
            if (
                self.val_reward_fn is not None
                and self.config.rollout.test_freq > 0
                and self.current_param_version % self.config.rollout.test_freq == 0
                and self.current_param_version > 0  # don't validate during the initial parameter sync
            ) or (validate and self.val_reward_fn is not None):
                with marked_timer("rollouter/validate_time", timing_raw, color="green"):
                    val_metrics: dict = self._validate()
            data = ValidateMetrics(
                timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version
            )
            await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data))
            self.version_start_time = time.time()

    def _validate_config(self):
        # Validate the asynchronous training configuration
        if not hasattr(self.config, "async_training"):
            raise ValueError("[FullyAsyncRollouter] Missing async_training configuration")
        assert self.config.actor_rollout_ref.rollout.calculate_log_probs, "rollout must calculate log_probs"

    async def init_workers(self):
        """Initialize distributed training workers using the Ray backend.

        Creates:
        1. Ray resource pools from the configuration
        2. Worker groups for each role (actor, critic, etc.)
        """
        self._init_resource_pools()
        self._create_worker_classes()
        self._init_worker_groups()
        self._init_models()
        await self._init_async_rollout_manager()

    def _create_actor_rollout_classes(self):
        # only create the rollout worker
        for role in [Role.Rollout]:
            resource_pool = self.resource_pool_manager.get_resource_pool(role)
            role_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[role],
                config=self.config.actor_rollout_ref,
                role=str(role),
            )
            self.resource_pool_to_cls[resource_pool][str(role)] = role_cls

    def _init_models(self):
        self.rollout_wg = self.all_wg[str(Role.Rollout)]
        self.rollout_wg.init_model()
        self.actor_rollout_wg = self.rollout_wg

    def _create_continuous_iterator(self):
        """
        Create a continuous data iterator across epochs
        """
        for epoch in range(self.config.rollout.total_epochs):
            iterator = iter(self.train_dataloader)
            for batch_dict in iterator:
                yield epoch, batch_dict

    async def _init_async_rollout_manager(self):
        # create the async rollout manager and request scheduler
        assert self.config.actor_rollout_ref.rollout.mode == "async"
        from recipe.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager

        self.async_rollout_mode = True
        self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(
            config=self.config,
            worker_group=self.rollout_wg,
        )

    # Add samples to the pending_queue
    async def _feed_samples(self):
        continuous_iterator = self._create_continuous_iterator()
        for epoch, batch_dict in continuous_iterator:
            # Similar to _prepare_generate_batch: separate the data
            full_batch = prepare_single_generation_data(
                batch_dict, self.global_steps, self.config.actor_rollout_ref.rollout.n
            )
            sample_id = f"sample_{epoch}_{self.global_steps}"
            rollout_sample = RolloutSample(
                full_batch=full_batch,
                agent_loop_output_list=[None] * self.config.actor_rollout_ref.rollout.n,
                sample_id=sample_id,
                epoch=epoch,
                param_version=0,
                param_version_start=[],
                param_version_end=[],
                processing_times=[],
                rollout_status={},
            )
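            # Note: agent_loop_output_list is pre-filled with one None slot per
            # response (rollout.n); generation fills the slots, and a cancelled
            # (partial) rollout keeps whatever is already filled so it can resume.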
            await self.pending_queue.put(rollout_sample)

            # Check whether we have reached the last step
            if self.global_steps >= self.total_rollout_steps:
                print(
                    f"[FullyAsyncRollouter][Feed] "
                    f"Maximum count has been reached, stop adding new samples: "
                    f"{self.global_steps} >= {self.total_rollout_steps}"
                )
                break
            self.global_steps += 1

        # End signal
        await self.pending_queue.put("DONE")
        print(f"[FullyAsyncRollouter][Feed] Sample feeding is complete, {self.global_steps} samples have been added")

    async def _processor_worker(self):
        """
        Streaming worker coroutine: each sample is submitted for processing
        immediately, without waiting for a batch to accumulate
        """
        while True:
            if self.paused or await self._should_pause_generation():
                print(
                    "[FullyAsyncRollouter][Processor] Received pause signal, waiting for remaining tasks to return..."
                )
                async with self.lock:
                    self.paused = True
                while self.active_tasks:
                    async with self.lock:
                        # active_tasks may have changed while acquiring the lock, so re-check
                        if self.active_tasks:
                            done_tasks, self.active_tasks = await asyncio.wait(
                                self.active_tasks, return_when=asyncio.FIRST_COMPLETED
                            )
                            for task in done_tasks:
                                await task
                async with self.lock:
                    while self.paused:
                        self.idle_start_time = time.time()
                        await self.condition.wait()
                continue

            sample_from_cancel_queue = False
            if not self.cancel_queue.empty():
                rollout_sample = await self.cancel_queue.get()
                sample_from_cancel_queue = True
            else:
                rollout_sample = await self.pending_queue.get()
                self.staleness_samples += 1

            if rollout_sample == "DONE":
                print(
                    "[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete..."
                )
                while self.active_tasks:
                    async with self.lock:
                        if self.active_tasks:
                            done_tasks, self.active_tasks = await asyncio.wait(
                                self.active_tasks, return_when=asyncio.FIRST_COMPLETED
                            )
                            for task in done_tasks:
                                await task
                break

            # Check whether the number of concurrent tasks exceeds the limit
            while len(self.active_tasks) >= self.max_concurrent_samples:
                async with self.lock:
                    if self.active_tasks:
                        done_tasks, self.active_tasks = await asyncio.wait(
                            self.active_tasks, return_when=asyncio.FIRST_COMPLETED
                        )
                        for task in done_tasks:
                            await task

            # Submit the single sample for processing
            async with self.lock:
                # After the pause ends and the lock is acquired, we still need to
                # check whether we are in the pause phase; if so, keep waiting
                while self.paused:
                    await self.condition.wait()
                task = asyncio.create_task(
                    self._process_single_sample_streaming(rollout_sample),
                    name=rollout_sample.sample_id,
                )
                self.active_tasks.add(task)

            if sample_from_cancel_queue:
                self.cancel_queue.task_done()
            else:
                self.pending_queue.task_done()

    async def _process_single_sample_streaming(self, rollout_sample: RolloutSample):
        """Process a single sample in streaming fashion"""
        # Call the asynchronous generation method
        rollout_sample.full_batch.non_tensor_batch["param_version"] = [self.current_param_version] * len(
            rollout_sample.full_batch
        )
        agent_loop_output_list = await self.async_rollout_manager.generate_single_sample_async(
            rollout_sample.full_batch, rollout_sample.agent_loop_output_list
        )
        rollout_sample.agent_loop_output_list = agent_loop_output_list

        is_cancel = any(agent_loop.is_cancel for agent_loop in agent_loop_output_list)
        if is_cancel:
            # Put it in the cancel queue and wait for generation to resume
            await self.cancel_queue.put(rollout_sample)
        else:
            # Put it into the result_queue
            rollout_sample.param_version = self.current_param_version
            rollout_sample.rollout_status = await self.get_statistics()
            await self.result_queue.put(rollout_sample)
            self.processed_sample_count += 1

    async def _consumer_worker(self):
        """
        Consumer coroutine: takes processed results from the result queue
        and puts them into the message queue
        """
        while True:
            rollout_sample = await self.result_queue.get()
            rollout_sample = merge_rollout_sample(self.config, self.tokenizer, rollout_sample)
            # Put the RolloutSample into the message queue
            success = await self.message_queue_client.put_sample(
                sample=ray.cloudpickle.dumps(rollout_sample),
                param_version=rollout_sample.param_version,
            )
            if success:
                self.total_generated_samples += 1
            else:
                self.dropped_stale_samples += 1
            self.result_queue.task_done()

    async def _streaming_generation_main(self):
        """Main entry point for streaming generation"""
        # we start from step 1
        self.global_steps += 1
        if self.async_rollout_manager is None:
            await self._init_async_rollout_manager()

        # Start the streaming loop
        print(f"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}")
        # Start the sample feed, streaming processor and consumer coroutines
        self.feed_task = asyncio.create_task(self._feed_samples())
        self.processor_task = asyncio.create_task(self._processor_worker())
        self.consumer_task = asyncio.create_task(self._consumer_worker())

        try:
            # Wait for the sample feed to complete
            await self.feed_task
            print("[FullyAsyncRollouter] Sample feed completed")
            # Wait for streaming processing to complete
            await self.processor_task
            print("[FullyAsyncRollouter] Streaming processing completed")
            # Wait for the result queue to drain
            await self.result_queue.join()
            print("[FullyAsyncRollouter] Result queue cleared")
        except Exception as e:
            print(f"[FullyAsyncRollouter] Streaming process exception: {e}")
        finally:
            if self.processor_task:
                self.processor_task.cancel()
            if self.consumer_task:
                self.consumer_task.cancel()
            await asyncio.gather(self.processor_task, self.consumer_task, return_exceptions=True)
            # Send a finish signal
            await self.message_queue_client.put_sample(
                sample=None,
                param_version=self.current_param_version,
            )
            async with self.lock:
                self.running = False

    async def fit(self):
        """
        Start the async Rollouter: the entry point that sets up and runs the
        async tasks, coordinating all coroutines
        """
        print("[FullyAsyncRollouter] Starting FullyAsyncRollouter...")
        if self.message_queue_client is None:
            raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.")

        # Set the running status flags
        async with self.lock:
            self.paused = False
            self.running = True

        # Create the main asynchronous tasks
        generation_task = asyncio.create_task(self._streaming_generation_main())
        monitor_task = asyncio.create_task(self._async_monitor_loop())
        try:
            # Run the generation and monitoring tasks concurrently
            await asyncio.gather(generation_task, monitor_task, return_exceptions=True)
        except Exception as e:
            print(f"[FullyAsyncRollouter] Asynchronous task execution error: {e}")
        finally:
            if not generation_task.done():
                generation_task.cancel()
            if not monitor_task.done():
                monitor_task.cancel()
            # Wait for the tasks to complete
            await asyncio.gather(generation_task, monitor_task, return_exceptions=True)
        print("[FullyAsyncRollouter] Rollouter fit completed")

    async def _async_monitor_loop(self):
        """
        Monitoring coroutine with two jobs:
        1. Periodically log statistics
        2. Trigger rollout recovery
        """
        last_stats_time = time.time()
        stats_interval = 60.0
        check_interval = 10.0
        while True:
            async with self.lock:
                if not self.running:
                    break
            await asyncio.sleep(check_interval)

            # Print statistics periodically
            current_time = time.time()
            if current_time - last_stats_time >= stats_interval:
                stats = await self.get_statistics()
                print(f"[FullyAsyncRollouter][MonitorLoop][Statistics] {pformat(stats)}")
                last_stats_time = current_time

            # Trigger rollout recovery
            if self.monitor_loop_trigger:
                if not await self._should_pause_generation():
                    async with self.lock:
                        self.paused = False
                        self.condition.notify_all()

    async def _should_pause_generation(self) -> bool:
        """Determine whether generation should be paused"""
        queue_stats = self.message_queue_client.get_statistics_sync()
        queue_size = queue_stats["queue_size"]
        if queue_size >= self.max_queue_size:
            if not self.paused:
                print(
                    f"[FullyAsyncRollouter][ShouldPause] "
                    f"due to full queue: size={queue_size}, max={self.max_queue_size}"
                )
            return True
        if self.staleness_samples >= self.max_required_samples:
            if not self.paused:
                print(
                    "[FullyAsyncRollouter][ShouldPause] "
                    f"due to "
                    f"staleness_samples {self.staleness_samples} >= max_required_samples {self.max_required_samples}"
                )
            return True
        return False

    async def pause(self):
        """Pause rollout"""
        print("[FullyAsyncRollouter][Public][Pause]")
        async with self.lock:
            self.paused = True
        # Cancel all rollout tasks
        if self.config.async_training.partial_rollout:
            await self.async_rollout_manager.cancel()
        if self.active_tasks:
            await asyncio.gather(*self.active_tasks, return_exceptions=True)
            self.active_tasks.clear()
        print("[FullyAsyncRollouter][Public][Pause] All active tasks completed")
        await self.async_rollout_manager.reset_prefix_cache()
        self.monitor_loop_trigger = False

    async def resume(self, dependency_ref: ObjectRef = None):
        if dependency_ref is not None:
            ray.get(dependency_ref)
        print("[FullyAsyncRollouter][Public][Resume]")
        async with self.lock:
            self.paused = False
            self.monitor_loop_trigger = True
            self.condition.notify_all()
        if self.config.async_training.partial_rollout:
            await self.async_rollout_manager.resume()

    async def get_statistics(self) -> dict:
        queue_stats = self.message_queue_client.get_statistics_sync()
        stats = {
            # monitor stats
            "monitor/active_tasks_size": len(self.active_tasks),
            "monitor/queue/pending_queue_size": self.pending_queue.qsize(),
            "monitor/queue/cancel_queue_size": self.cancel_queue.qsize(),
            "monitor/queue/result_queue_size": self.result_queue.qsize(),
            "monitor/queue/mq_queue_size": queue_stats["queue_size"],
            # counting stats
            "count/current_param_version": self.current_param_version,
            "count/total_generated_samples": self.total_generated_samples,
            "count/staleness_samples": self.staleness_samples,
            "count/dropped_stale_samples": self.dropped_stale_samples,
            # static stats
            "static/max_required_samples": self.max_required_samples,
            "static/required_samples": self.required_samples,
            "static/staleness_threshold": self.staleness_threshold,
            "static/max_queue_size": self.max_queue_size,
            "static/max_concurrent_samples": self.max_concurrent_samples,
        }
        return stats
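
# Hypothetical trainer-side handshake (illustrative sketch, not part of this
# file): the Trainer is assumed to hold a handle to this actor and, at each
# parameter sync, to do roughly:
#
#     ray.get(rollouter.pause.remote())                  # drain or cancel in-flight rollouts
#     ...sync updated weights to the rollout workers...
#     ray.get(rollouter.update_param_version.remote(v))  # reset staleness accounting, maybe validate
#     ray.get(rollouter.resume.remote())                 # restart generation (and partial rollouts)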