[rollout] feat: Add gpt-oss tool parser to enable agent loop training for gpt-oss models (#3705)

### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

Add gpt-oss tool parser to enable agent loop training for gpt-oss models
### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test
Manually tested offline. Let me know if we want to add unit tests.
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: Hejian Sang <hsang@linkedin.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
HEJIAN SANG
2025-10-10 20:53:10 -07:00
committed by GitHub
parent d87602432c
commit e960fbaeab
2 changed files with 91 additions and 2 deletions

View File

@ -0,0 +1,34 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from transformers import AutoTokenizer
from verl.experimental.agent_loop.tool_parser import GptOssToolParser
@pytest.mark.asyncio
@pytest.mark.skip(reason="local test only")
async def test_gpt_oss_tool_parser():
    """Check that GptOssToolParser extracts exactly one function call from a
    harmony-formatted gpt-oss completion (requires downloading the gpt-oss-20b
    tokenizer, hence local-only).
    """
    harmony_completion = """
<|start|>assistant<|channel|>commentary to=functions.get_current_weather \
<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>
<|start|>functions.get_current_weather to=assistant<|channel|>commentary<|message|>\
{ "temperature": 20, "sunny": true }<|end|>"""
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    token_ids = tokenizer.encode(harmony_completion)
    parser = GptOssToolParser(tokenizer)
    _, calls = await parser.extract_tool_calls(token_ids)
    assert len(calls) == 1
    only_call = calls[0]
    assert only_call.name == "get_current_weather"
    assert only_call.arguments == '{"location": "Tokyo"}'

View File

@ -17,7 +17,7 @@ import logging
import os
from abc import ABC, abstractmethod
import regex as re
import regex
from pydantic import BaseModel
from verl.utils.rollout_trace import rollout_trace_op
@ -81,7 +81,7 @@ class HermesToolParser(ToolParser):
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
@rollout_trace_op
async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
@ -104,3 +104,58 @@ class HermesToolParser(ToolParser):
content = self.tool_call_regex.sub("", text)
return content, function_calls
@ToolParser.register("gpt-oss")
class GptOssToolParser(ToolParser):
    """Tool parser for gpt-oss models (OpenAI "harmony" chat format).

    Adapted from
    https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call/gpt_oss_detector.py

    Args:
        tokenizer: The tokenizer used to decode response token ids.
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)
        # check https://cookbook.openai.com/articles/openai-harmony for more details.
        # Full chain-of-thought segment: <|start|>assistant<|channel|>analysis ... <|end|>
        self.cot_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>analysis<\|message\|>.*?<\|end\|>", regex.DOTALL
        )
        # <|start|>assistant may be pre-appended in prompts, so we need to remove it.
        self.partial_cot_pattern = regex.compile(r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>", regex.DOTALL)
        # Tool-call segment: captures the function name (group 1) and the JSON
        # argument payload (group 2).
        self.tool_call_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>[^<]* to=functions\.([^<]+) "
            r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>",
            regex.DOTALL,
        )

    @rollout_trace_op
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        """Extract tool calls from the response token ids.

        Args:
            responses_ids: Token ids of the model response.

        Returns:
            A tuple ``(content, function_calls)`` where ``content`` is the decoded
            text with tool-call segments stripped, and ``function_calls`` is the
            list of parsed ``FunctionCall`` objects (empty when none are found).
        """
        loop = asyncio.get_running_loop()
        # We need to keep special tokens for gpt-oss model for better tool call
        # extraction: the harmony markers (<|start|>, <|channel|>, ...) are special tokens.
        text = await loop.run_in_executor(None, lambda: self.tokenizer.decode(responses_ids, skip_special_tokens=False))
        # Need to remove padding tokens for better tool call extraction.
        # Guard against tokenizers that have no pad token configured (pad_token is None).
        if self.tokenizer.pad_token:
            text = text.replace(self.tokenizer.pad_token, "")
        # Need to remove COT since COT may contain tool call tokens, but they are
        # not valid tool calls.
        text = self.cot_pattern.sub("", text)
        text = self.partial_cot_pattern.sub("", text)
        # check if there are tool calls in the text
        matches = self.tool_call_pattern.findall(text)
        if not matches:
            return text, []
        function_calls = []
        for name, arguments in matches:
            try:
                # don't check if arguments is valid JSON and leave it to client
                function_calls.append(FunctionCall(name=name, arguments=arguments))
            except Exception as e:
                logger.error(f"Failed to decode tool call: {e}")
        # remaining text excludes tool call tokens
        content = self.tool_call_pattern.sub("", text)
        return content, function_calls