mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[doc] fix: Fix the role assignment error in the interaction demo file and doc. (#2476)
### What does this PR do? Fix the role assignment error in the interaction demo file verl/interactions/gsm8k_interaction.py and doc. The assistant is expected to solve problems, while users provide problems and feedback within the messages list. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test Update tests/interactions/test_gsm8k_interaction.py. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). --------- Co-authored-by: H <linhaibin.eric@gmail.com>
This commit is contained in:
@ -180,7 +180,7 @@ The GSM8K interaction demonstrates a complete implementation for math problem-so
|
|||||||
return instance_id
|
return instance_id
|
||||||
|
|
||||||
async def generate_response(self, instance_id, messages, **kwargs):
|
async def generate_response(self, instance_id, messages, **kwargs):
|
||||||
# Extract last user message content
|
# Extract last assistant message content
|
||||||
content = ""
|
content = ""
|
||||||
for item in reversed(messages):
|
for item in reversed(messages):
|
||||||
if item.get("role") == "assistant":
|
if item.get("role") == "assistant":
|
||||||
@ -299,7 +299,8 @@ Comprehensive testing is essential for interaction systems:
|
|||||||
# Test complete workflow
|
# Test complete workflow
|
||||||
instance_id = await interaction.start_interaction(ground_truth="expected_answer")
|
instance_id = await interaction.start_interaction(ground_truth="expected_answer")
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "user_content"}, {"role": "assistant", "content": "assistant_response"}]
|
|
||||||
|
messages = [{"role": "user", "content": "user_content"}, {"role": "assistant", "content": "assistant_content"}]
|
||||||
should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages)
|
should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages)
|
||||||
|
|
||||||
assert should_terminate in [True, False]
|
assert should_terminate in [True, False]
|
||||||
|
@ -80,7 +80,7 @@ class TestGsm8kInteraction:
|
|||||||
# Setup instance
|
# Setup instance
|
||||||
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "#### 42"}]
|
messages = [{"role": "assistant", "content": "#### 42"}]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
||||||
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
||||||
@ -102,7 +102,7 @@ class TestGsm8kInteraction:
|
|||||||
# Setup instance
|
# Setup instance
|
||||||
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "42"}]
|
messages = [{"role": "assistant", "content": "42"}]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
||||||
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
||||||
@ -123,7 +123,7 @@ class TestGsm8kInteraction:
|
|||||||
# Setup instance
|
# Setup instance
|
||||||
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "24"}]
|
messages = [{"role": "assistant", "content": "24"}]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
||||||
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
||||||
@ -137,7 +137,7 @@ class TestGsm8kInteraction:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate_response_multiple_messages(self):
|
async def test_generate_response_multiple_messages(self):
|
||||||
"""Test generate_response with multiple messages (should use last user message)."""
|
"""Test generate_response with multiple messages (should use last assistant message)."""
|
||||||
instance_id = "test_instance"
|
instance_id = "test_instance"
|
||||||
ground_truth = "42"
|
ground_truth = "42"
|
||||||
|
|
||||||
@ -146,8 +146,9 @@ class TestGsm8kInteraction:
|
|||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
{"role": "assistant", "content": "Let me think about this..."},
|
{"role": "assistant", "content": "### 4"},
|
||||||
{"role": "user", "content": "#### 42"},
|
{"role": "user", "content": "What is 40+2?"},
|
||||||
|
{"role": "assistant", "content": "#### 42"},
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
||||||
@ -160,15 +161,15 @@ class TestGsm8kInteraction:
|
|||||||
assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"
|
assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate_response_no_user_message(self):
|
async def test_generate_response_no_assistant_message(self):
|
||||||
"""Test generate_response with no user messages."""
|
"""Test generate_response with no assistant messages."""
|
||||||
instance_id = "test_instance"
|
instance_id = "test_instance"
|
||||||
ground_truth = "42"
|
ground_truth = "42"
|
||||||
|
|
||||||
# Setup instance
|
# Setup instance
|
||||||
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
messages = [{"role": "assistant", "content": "Hello!"}]
|
messages = [{"role": "user", "content": "Hello!"}]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
||||||
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
||||||
@ -262,7 +263,7 @@ class TestGsm8kInteraction:
|
|||||||
instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
|
instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
|
||||||
|
|
||||||
# Generate response with correct answer
|
# Generate response with correct answer
|
||||||
messages = [{"role": "user", "content": "42"}]
|
messages = [{"role": "assistant", "content": "42"}]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
||||||
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
||||||
@ -285,7 +286,7 @@ class TestGsm8kInteraction:
|
|||||||
instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
|
instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
|
||||||
|
|
||||||
# Generate response with incorrect answer
|
# Generate response with incorrect answer
|
||||||
messages = [{"role": "user", "content": "24"}]
|
messages = [{"role": "assistant", "content": "24"}]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
||||||
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
||||||
@ -296,8 +297,8 @@ class TestGsm8kInteraction:
|
|||||||
assert reward == 0.0
|
assert reward == 0.0
|
||||||
|
|
||||||
# Continue with another attempt
|
# Continue with another attempt
|
||||||
messages.append({"role": "assistant", "content": response})
|
messages.append({"role": "user", "content": response})
|
||||||
messages.append({"role": "user", "content": "42"})
|
messages.append({"role": "assistant", "content": "42"})
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
|
||||||
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
should_terminate, response, reward, metadata = await self.interaction.generate_response(
|
||||||
@ -326,8 +327,8 @@ class TestGsm8kInteraction:
|
|||||||
assert instance_id_2 in self.interaction._instance_dict
|
assert instance_id_2 in self.interaction._instance_dict
|
||||||
|
|
||||||
# Test responses for both instances
|
# Test responses for both instances
|
||||||
messages_1 = [{"role": "user", "content": "42"}]
|
messages_1 = [{"role": "assistant", "content": "42"}]
|
||||||
messages_2 = [{"role": "user", "content": "24"}]
|
messages_2 = [{"role": "assistant", "content": "24"}]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]):
|
||||||
should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1)
|
should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1)
|
||||||
@ -374,7 +375,7 @@ class TestGsm8kInteraction:
|
|||||||
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user"} # Missing content field
|
{"role": "assistant"} # Missing content field
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
|
||||||
|
@ -31,7 +31,7 @@ class Gsm8kInteraction(BaseInteraction):
|
|||||||
"""A demo interaction for calculating the reward of gsm8k.
|
"""A demo interaction for calculating the reward of gsm8k.
|
||||||
|
|
||||||
- `start_interaction`: start a interaction instance for a trajectory.
|
- `start_interaction`: start a interaction instance for a trajectory.
|
||||||
- `generate_response`: generate the response of the user.
|
- `generate_response`: generate the response of the assistant.
|
||||||
- `calculate_score`: calculate the score of the interaction.
|
- `calculate_score`: calculate the score of the interaction.
|
||||||
- `finalize_interaction`: finalize the interaction instance.
|
- `finalize_interaction`: finalize the interaction instance.
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user