Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)

Compare commits: 1 commit, 33eb86f54f ... wuxibin/fi
Commit: 205ef41f3b
@@ -183,12 +183,12 @@ class AgentLoopBase(ABC):
         cls._class_initialized = True
 
     @abstractmethod
-    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
+    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
         """Run agent loop to interact with LLM server and environment.
 
         Args:
-            messages (List[Dict[str, Any]]): Input messages.
             sampling_params (Dict[str, Any]): LLM sampling params.
+            **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.
 
         Returns:
             AgentLoopOutput: Agent loop output.
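To make the new contract concrete, here is a minimal sketch of a custom agent loop written against the changed `run` signature. The class name `EchoAgentLoop` is made up, the import path, the `server_manager.generate` call, and the `AgentLoopOutput` field names are assumptions drawn from the surrounding diff, and loop registration is omitted; treat it as an illustration of where `messages` now comes from, not as verl's official API.

from typing import Any
from uuid import uuid4

# Import path and exported names assumed from the verl agent-loop module shown in this diff.
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput


class EchoAgentLoop(AgentLoopBase):
    """Hypothetical single-response loop illustrating the new run() contract."""

    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        # Messages are no longer a positional argument; every non-tensor dataset
        # field from RLHFDataset arrives as a keyword argument instead.
        messages = list(kwargs["raw_prompt"])

        # Tokenize off the event loop, mirroring the pattern used by SingleTurnAgentLoop.
        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True),
        )

        # Ask the rollout server for a completion (generate() signature assumed).
        response_ids = await self.server_manager.generate(
            request_id=uuid4().hex, prompt_ids=prompt_ids, sampling_params=sampling_params
        )

        # AgentLoopOutput field names assumed; a single-turn response is fully model-generated.
        return AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids,
            response_mask=[1] * len(response_ids),
            metrics={},
        )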
@@ -285,25 +285,19 @@ class AgentLoopWorker:
         if "agent_name" not in batch.non_tensor_batch:
             batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object)
 
-        tasks = []
-        agent_names = batch.non_tensor_batch["agent_name"]
-        raw_prompts = batch.non_tensor_batch["raw_prompt"]
         if "index" in batch.non_tensor_batch:
             index = batch.non_tensor_batch["index"]
         else:
-            index = np.arange(len(raw_prompts))
+            index = np.arange(len(batch))
 
         trajectory_info = await get_trajectory_info(
             batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False)
         )
 
-        for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
-            if not isinstance(messages, list | np.ndarray):
-                raise TypeError(f"messages must be a list or numpy array, got {type(messages)}")
-
-            tasks.append(
-                asyncio.create_task(self._run_agent_loop(agent_name, list(messages), sampling_params, trajectory))
-            )
+        tasks = []
+        for i in range(len(batch)):
+            kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
+            tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params, trajectory_info[i], **kwargs)))
         outputs = await asyncio.gather(*tasks)
 
         output = self._postprocess(outputs)
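The worker-side change is easiest to see in isolation: instead of zipping a fixed set of columns (`agent_name`, `raw_prompt`), the worker now slices every `non_tensor_batch` column at index `i` and forwards the whole dataset row as keyword arguments. Below is a self-contained sketch of that dispatch with a toy dict standing in for verl's `DataProto`; apart from the slicing pattern itself, all names and values are illustrative.

import asyncio

import numpy as np

# Toy stand-in for batch.non_tensor_batch: one object-dtype column per dataset field.
non_tensor_batch = {
    "agent_name": np.array(["single_turn_agent", "tool_agent"], dtype=object),
    "raw_prompt": np.array(
        [[{"role": "user", "content": "hi"}], [{"role": "user", "content": "2 + 2?"}]], dtype=object
    ),
    "data_source": np.array(["gsm8k", "gsm8k"], dtype=object),
}


async def run_agent_loop(sampling_params: dict, *, agent_name: str, **kwargs) -> str:
    # Mirrors the new _run_agent_loop shape: agent_name is keyword-only, and the rest
    # of the dataset row (raw_prompt, data_source, ...) rides along in **kwargs.
    return f"{agent_name}: {len(kwargs['raw_prompt'])} message(s) from {kwargs['data_source']}"


async def main() -> None:
    tasks = []
    for i in range(len(non_tensor_batch["agent_name"])):
        # Slice row i out of every column, exactly as the worker does above.
        row_kwargs = {k: v[i] for k, v in non_tensor_batch.items()}
        tasks.append(asyncio.create_task(run_agent_loop({"temperature": 1.0}, **row_kwargs)))
    print(await asyncio.gather(*tasks))


asyncio.run(main())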
@@ -311,10 +305,11 @@ class AgentLoopWorker:
 
     async def _run_agent_loop(
         self,
-        agent_name: str,
-        messages: list[dict[str, Any]],
         sampling_params: dict[str, Any],
         trajectory: dict[str, Any],
+        *,
+        agent_name: str,
+        **kwargs,
     ) -> _InternalAgentLoopOutput:
         with rollout_trace_attr(
             step=trajectory["step"],
@@ -334,7 +329,7 @@ class AgentLoopWorker:
                 server_manager=self.server_manager,
                 tokenizer=self.tokenizer,
             )
-            output = await agent_loop.run(messages, sampling_params)
+            output = await agent_loop.run(sampling_params, **kwargs)
 
             # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
             # prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
@@ -32,7 +32,9 @@ class SingleTurnAgentLoop(AgentLoopBase):
         self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
         self.response_length = self.config.actor_rollout_ref.rollout.response_length
 
-    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
+    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
+        messages = list(kwargs["raw_prompt"])
+
         metrics = {}
         request_id = uuid4().hex
         prompt_ids = await self.loop.run_in_executor(
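The new `list(kwargs["raw_prompt"])` line exists because the non-tensor batch may hand the chat history over as an object-dtype numpy array rather than a plain list, while `apply_chat_template` expects a list of message dicts (the removed `list | np.ndarray` check and `list(messages)` conversion served the same purpose). A quick standalone illustration using a HuggingFace tokenizer; the model name is only an example.

import numpy as np
from transformers import AutoTokenizer

# raw_prompt as it may arrive from the non-tensor batch: an object-dtype array of message dicts.
raw_prompt = np.array([{"role": "user", "content": "What is 2 + 2?"}], dtype=object)

messages = list(raw_prompt)  # -> [{'role': 'user', 'content': 'What is 2 + 2?'}]

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
print(prompt_ids[:8])  # token ids ready to send to the rollout server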
@@ -56,7 +56,8 @@ class ToolAgentLoop(AgentLoopBase):
         cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
 
     @rollout_trace_op
-    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
+    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
+        messages = list(kwargs["raw_prompt"])
         metrics = {}
         request_id = uuid4().hex
         prompt_ids = await self.loop.run_in_executor(