Compare commits

...

1 Commits

Author SHA1 Message Date
205ef41f3b [rollout] feat: pass all dataset fields to agent loop run 2025-07-30 13:21:43 +08:00
3 changed files with 16 additions and 18 deletions

View File

@ -183,12 +183,12 @@ class AgentLoopBase(ABC):
cls._class_initialized = True
@abstractmethod
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
"""Run agent loop to interact with LLM server and environment.
Args:
messages (List[Dict[str, Any]]): Input messages.
sampling_params (Dict[str, Any]): LLM sampling params.
**kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.
Returns:
AgentLoopOutput: Agent loop output.
@ -285,25 +285,19 @@ class AgentLoopWorker:
if "agent_name" not in batch.non_tensor_batch:
batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object)
tasks = []
agent_names = batch.non_tensor_batch["agent_name"]
raw_prompts = batch.non_tensor_batch["raw_prompt"]
if "index" in batch.non_tensor_batch:
index = batch.non_tensor_batch["index"]
else:
index = np.arange(len(raw_prompts))
index = np.arange(len(batch))
trajectory_info = await get_trajectory_info(
batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False)
)
for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
if not isinstance(messages, list | np.ndarray):
raise TypeError(f"messages must be a list or numpy array, got {type(messages)}")
tasks.append(
asyncio.create_task(self._run_agent_loop(agent_name, list(messages), sampling_params, trajectory))
)
tasks = []
for i in range(len(batch)):
kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs)))
outputs = await asyncio.gather(*tasks)
output = self._postprocess(outputs)
@ -311,10 +305,11 @@ class AgentLoopWorker:
async def _run_agent_loop(
self,
agent_name: str,
messages: list[dict[str, Any]],
sampling_params: dict[str, Any],
trajectory: dict[str, Any],
*,
agent_name: str,
**kwargs,
) -> _InternalAgentLoopOutput:
with rollout_trace_attr(
step=trajectory["step"],
@ -334,7 +329,7 @@ class AgentLoopWorker:
server_manager=self.server_manager,
tokenizer=self.tokenizer,
)
output = await agent_loop.run(messages, sampling_params)
output = await agent_loop.run(sampling_params, **kwargs)
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])

View File

@ -32,7 +32,9 @@ class SingleTurnAgentLoop(AgentLoopBase):
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
self.response_length = self.config.actor_rollout_ref.rollout.response_length
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
metrics = {}
request_id = uuid4().hex
prompt_ids = await self.loop.run_in_executor(

View File

@ -56,7 +56,8 @@ class ToolAgentLoop(AgentLoopBase):
cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
@rollout_trace_op
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
metrics = {}
request_id = uuid4().hex
prompt_ids = await self.loop.run_in_executor(