mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[rollout] feat: Add gpt-oss tool parser to enable agent loop training for gpt-oss models (#3705)
### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. Add gpt-oss tool parser to enable agent loop training for gpt-oss models ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test Manually test offline. Let me know if we want to add unit tests. > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). 
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Hejian Sang <hsang@linkedin.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
34
tests/experimental/agent_loop/test_gpt_oss_tool_parser.py
Normal file
34
tests/experimental/agent_loop/test_gpt_oss_tool_parser.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from verl.experimental.agent_loop.tool_parser import GptOssToolParser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skip(reason="local test only")
async def test_gpt_oss_tool_parser():
    """Round-trip a harmony-format tool call through GptOssToolParser.

    Encodes an example assistant turn with the gpt-oss tokenizer, parses it
    back, and checks that exactly one function call is extracted with the
    expected name and raw JSON arguments.
    """
    harmony_text = """
<|start|>assistant<|channel|>commentary to=functions.get_current_weather \
<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>
<|start|>functions.get_current_weather to=assistant<|channel|>commentary<|message|>\
{ "temperature": 20, "sunny": true }<|end|>"""
    tok = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    parser = GptOssToolParser(tok)
    _, calls = await parser.extract_tool_calls(tok.encode(harmony_text))
    assert len(calls) == 1
    call = calls[0]
    assert call.name == "get_current_weather"
    assert call.arguments == '{"location": "Tokyo"}'
|
@ -17,7 +17,7 @@ import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import regex as re
|
||||
import regex
|
||||
from pydantic import BaseModel
|
||||
|
||||
from verl.utils.rollout_trace import rollout_trace_op
|
||||
@ -81,7 +81,7 @@ class HermesToolParser(ToolParser):
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||
self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
|
||||
|
||||
@rollout_trace_op
|
||||
async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
|
||||
@ -104,3 +104,58 @@ class HermesToolParser(ToolParser):
|
||||
content = self.tool_call_regex.sub("", text)
|
||||
|
||||
return content, function_calls
|
||||
|
||||
|
||||
@ToolParser.register("gpt-oss")
class GptOssToolParser(ToolParser):
    """Tool parser for gpt-oss (harmony-format) models.

    Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call/gpt_oss_detector.py

    Args:
        tokenizer: The tokenizer to use.
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)
        # Check https://cookbook.openai.com/articles/openai-harmony for more details
        # on the harmony channel/message token format matched below.
        self.cot_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>analysis<\|message\|>.*?<\|end\|>", regex.DOTALL
        )
        # <|start|>assistant may be pre-appended in prompts, so we need a second
        # pattern that matches the analysis channel without that prefix.
        self.partial_cot_pattern = regex.compile(r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>", regex.DOTALL)
        self.tool_call_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>[^<]* to=functions\.([^<]+) "
            r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>",
            regex.DOTALL,
        )

    @rollout_trace_op
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        """Extract tool calls from a gpt-oss model response.

        Args:
            responses_ids: Token ids of the model response.

        Returns:
            Tuple of (content with tool-call spans removed, list of extracted
            ``FunctionCall`` objects). Arguments are the raw JSON text between
            the harmony message markers; they are not validated here.
        """
        loop = asyncio.get_running_loop()
        # We need to keep special tokens for gpt-oss: the harmony tool-call
        # markers themselves are special tokens.
        text = await loop.run_in_executor(None, lambda: self.tokenizer.decode(responses_ids, skip_special_tokens=False))
        # Remove padding tokens for better tool call extraction.
        # Guard against tokenizers that define no pad token (pad_token is None),
        # where str.replace would raise TypeError.
        if self.tokenizer.pad_token:
            text = text.replace(self.tokenizer.pad_token, "")
        # Remove chain-of-thought first: the analysis channel may contain
        # tool-call-like tokens that are not valid tool calls.
        text = self.cot_pattern.sub("", text)
        text = self.partial_cot_pattern.sub("", text)

        # Check whether the text contains any tool calls.
        matches = self.tool_call_pattern.findall(text)
        if not matches:
            return text, []

        function_calls = []
        for name, arguments in matches:
            try:
                # Don't check whether `arguments` is valid JSON; leave that to the client.
                function_calls.append(FunctionCall(name=name, arguments=arguments))
            except Exception as e:
                logger.error(f"Failed to decode tool call: {e}")

        # Remaining text, excluding the tool-call spans.
        content = self.tool_call_pattern.sub("", text)

        return content, function_calls
|
||||
|
Reference in New Issue
Block a user