[doc] fix: Fix the role assignment error in the interaction demo file and doc. (#2476)

### What does this PR do?

Fix the role assignment error in the interaction demo file
verl/interactions/gsm8k_interaction.py and doc. The assistant is
expected to solve problems, while users provide problems and feedback
within the messages list.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

Update tests/interactions/test_gsm8k_interaction.py.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Co-authored-by: H <linhaibin.eric@gmail.com>
This commit is contained in:
Qiao
2025-08-03 08:04:15 +08:00
committed by GitHub
parent a24241092d
commit 2fdfbdcba6
3 changed files with 21 additions and 19 deletions

View File

@ -180,7 +180,7 @@ The GSM8K interaction demonstrates a complete implementation for math problem-so
return instance_id return instance_id
async def generate_response(self, instance_id, messages, **kwargs): async def generate_response(self, instance_id, messages, **kwargs):
# Extract last user message content # Extract last assistant message content
content = "" content = ""
for item in reversed(messages): for item in reversed(messages):
if item.get("role") == "assistant": if item.get("role") == "assistant":
@ -299,7 +299,8 @@ Comprehensive testing is essential for interaction systems:
# Test complete workflow # Test complete workflow
instance_id = await interaction.start_interaction(ground_truth="expected_answer") instance_id = await interaction.start_interaction(ground_truth="expected_answer")
messages = [{"role": "user", "content": "user_content"}, {"role": "assistant", "content": "assistant_response"}]
messages = [{"role": "user", "content": "user_content"}, {"role": "assistant", "content": "assistant_content"}]
should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages) should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages)
assert should_terminate in [True, False] assert should_terminate in [True, False]

View File

@ -80,7 +80,7 @@ class TestGsm8kInteraction:
# Setup instance # Setup instance
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
messages = [{"role": "user", "content": "#### 42"}] messages = [{"role": "assistant", "content": "#### 42"}]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
should_terminate, response, reward, metadata = await self.interaction.generate_response( should_terminate, response, reward, metadata = await self.interaction.generate_response(
@ -102,7 +102,7 @@ class TestGsm8kInteraction:
# Setup instance # Setup instance
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
messages = [{"role": "user", "content": "42"}] messages = [{"role": "assistant", "content": "42"}]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
should_terminate, response, reward, metadata = await self.interaction.generate_response( should_terminate, response, reward, metadata = await self.interaction.generate_response(
@ -123,7 +123,7 @@ class TestGsm8kInteraction:
# Setup instance # Setup instance
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
messages = [{"role": "user", "content": "24"}] messages = [{"role": "assistant", "content": "24"}]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
should_terminate, response, reward, metadata = await self.interaction.generate_response( should_terminate, response, reward, metadata = await self.interaction.generate_response(
@ -137,7 +137,7 @@ class TestGsm8kInteraction:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_response_multiple_messages(self): async def test_generate_response_multiple_messages(self):
"""Test generate_response with multiple messages (should use last user message).""" """Test generate_response with multiple messages (should use last assistant message)."""
instance_id = "test_instance" instance_id = "test_instance"
ground_truth = "42" ground_truth = "42"
@ -146,8 +146,9 @@ class TestGsm8kInteraction:
messages = [ messages = [
{"role": "user", "content": "What is 2+2?"}, {"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "Let me think about this..."}, {"role": "assistant", "content": "### 4"},
{"role": "user", "content": "#### 42"}, {"role": "user", "content": "What is 40+2?"},
{"role": "assistant", "content": "#### 42"},
] ]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
@ -160,15 +161,15 @@ class TestGsm8kInteraction:
assert self.interaction._instance_dict[instance_id]["response"] == "#### 42" assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_response_no_user_message(self): async def test_generate_response_no_assistant_message(self):
"""Test generate_response with no user messages.""" """Test generate_response with no assistant messages."""
instance_id = "test_instance" instance_id = "test_instance"
ground_truth = "42" ground_truth = "42"
# Setup instance # Setup instance
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
messages = [{"role": "assistant", "content": "Hello!"}] messages = [{"role": "user", "content": "Hello!"}]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
should_terminate, response, reward, metadata = await self.interaction.generate_response( should_terminate, response, reward, metadata = await self.interaction.generate_response(
@ -262,7 +263,7 @@ class TestGsm8kInteraction:
instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
# Generate response with correct answer # Generate response with correct answer
messages = [{"role": "user", "content": "42"}] messages = [{"role": "assistant", "content": "42"}]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
should_terminate, response, reward, metadata = await self.interaction.generate_response( should_terminate, response, reward, metadata = await self.interaction.generate_response(
@ -285,7 +286,7 @@ class TestGsm8kInteraction:
instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
# Generate response with incorrect answer # Generate response with incorrect answer
messages = [{"role": "user", "content": "24"}] messages = [{"role": "assistant", "content": "24"}]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
should_terminate, response, reward, metadata = await self.interaction.generate_response( should_terminate, response, reward, metadata = await self.interaction.generate_response(
@ -296,8 +297,8 @@ class TestGsm8kInteraction:
assert reward == 0.0 assert reward == 0.0
# Continue with another attempt # Continue with another attempt
messages.append({"role": "assistant", "content": response}) messages.append({"role": "user", "content": response})
messages.append({"role": "user", "content": "42"}) messages.append({"role": "assistant", "content": "42"})
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
should_terminate, response, reward, metadata = await self.interaction.generate_response( should_terminate, response, reward, metadata = await self.interaction.generate_response(
@ -326,8 +327,8 @@ class TestGsm8kInteraction:
assert instance_id_2 in self.interaction._instance_dict assert instance_id_2 in self.interaction._instance_dict
# Test responses for both instances # Test responses for both instances
messages_1 = [{"role": "user", "content": "42"}] messages_1 = [{"role": "assistant", "content": "42"}]
messages_2 = [{"role": "user", "content": "24"}] messages_2 = [{"role": "assistant", "content": "24"}]
with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]): with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]):
should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1) should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1)
@ -374,7 +375,7 @@ class TestGsm8kInteraction:
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
messages = [ messages = [
{"role": "user"} # Missing content field {"role": "assistant"} # Missing content field
] ]
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):

View File

@ -31,7 +31,7 @@ class Gsm8kInteraction(BaseInteraction):
"""A demo interaction for calculating the reward of gsm8k. """A demo interaction for calculating the reward of gsm8k.
- `start_interaction`: start a interaction instance for a trajectory. - `start_interaction`: start a interaction instance for a trajectory.
- `generate_response`: generate the response of the user. - `generate_response`: generate the response of the assistant.
- `calculate_score`: calculate the score of the interaction. - `calculate_score`: calculate the score of the interaction.
- `finalize_interaction`: finalize the interaction instance. - `finalize_interaction`: finalize the interaction instance.
""" """