mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 10:03:51 +08:00)

Add accuracy reward (#4270)

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Committed by GitHub

parent 94aac4a101
commit 8e2d5516ca
@@ -2,14 +2,14 @@

 This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].

-## Format rewards
+## accuracy_reward

-### think_format_reward
+[[autodoc]] rewards.accuracy_reward
+
+## think_format_reward

 [[autodoc]] rewards.think_format_reward

-## Other rewards
-
-### get_soft_overlong_punishment
+## get_soft_overlong_punishment

 [[autodoc]] rewards.get_soft_overlong_punishment
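For readers of the updated docs page, a quick illustration of what `accuracy_reward` does at call time; the inputs below are made up for the example, and the expected output follows the behaviour described in the function's docstring added later in this commit.

```python
from trl.rewards import accuracy_reward

completions = [
    [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
    [{"role": "assistant", "content": "Answer is forty two."}],
]
solutions = [r"\frac{1}{3}", "Answer is forty two."]

# The first pair is checked with math-verify; the second gold solution is not
# parseable as math, so the reward falls back to a normalized text comparison.
print(accuracy_reward(completions, solutions))  # expected: [1.0, 1.0]
```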
@@ -70,8 +70,6 @@ import os

 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify

 from trl import (
     GRPOConfig,
@@ -83,7 +81,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward


 # Enable logging in a Hugging Face Space
@@ -149,54 +147,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
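Since the local helper is now gone from this example script, the imported function is passed straight to the trainer. A minimal sketch of that wiring, assuming a dataset with `prompt` and `solution` columns; the dataset and model names here are placeholders, not part of this commit.

```python
from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer
from trl.rewards import accuracy_reward, think_format_reward

# Placeholder dataset; extra columns such as "solution" are forwarded to the
# reward functions as keyword arguments by the trainer.
train_dataset = load_dataset("my-org/math-prompts", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    reward_funcs=[think_format_reward, accuracy_reward],
    args=GRPOConfig(output_dir="Qwen2.5-0.5B-GRPO"),
    train_dataset=train_dataset,
)
trainer.train()
```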
@@ -57,8 +57,6 @@ import os

 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify

 from trl import (
     GRPOConfig,
@@ -70,7 +68,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward


 # Enable logging in a Hugging Face Space
@@ -120,54 +118,6 @@ if __name__ == "__main__":
     train_dataset = train_dataset.remove_columns(["messages", "problem"])
     eval_dataset = eval_dataset.remove_columns(["messages", "problem"])

-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -57,8 +57,6 @@ import os

 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify

 from trl import (
     GRPOConfig,
@@ -70,7 +68,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward


 # Enable logging in a Hugging Face Space
@@ -136,54 +134,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -87,8 +87,6 @@ import os
 import torch
 import transformers
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 from transformers import AutoConfig, AutoProcessor, GenerationConfig

 from trl import (
@@ -102,7 +100,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward


 # Enable logging in a Hugging Face Space
@@ -192,54 +190,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

-    ################
-    # Reward Function for Training (same as GRPO VLM)
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
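When several reward functions are combined, as in this VLM example, their contributions can be weighted. A sketch under the assumption that the installed TRL version exposes a `reward_weights` field on `GRPOConfig`; verify against your version before relying on it.

```python
from trl import GRPOConfig
from trl.rewards import accuracy_reward, think_format_reward

# Assumed field: reward_weights aligns positionally with reward_funcs, so the
# format reward counts for 0.2 and the accuracy reward for 1.0.
training_args = GRPOConfig(
    output_dir="grpo-vlm-accuracy",
    reward_weights=[0.2, 1.0],
)
reward_funcs = [think_format_reward, accuracy_reward]
```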
@@ -33,12 +33,10 @@ import os

 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 from peft import LoraConfig

 from trl import RLOOConfig, RLOOTrainer
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward


 # Enable logging in a Hugging Face Space
@@ -67,52 +65,6 @@ def main():
     train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
     eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])

-    # Reward function for training
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     # Training
     training_args = RLOOConfig(
         output_dir="Qwen3-0.6B-RLOO",
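The RLOO example follows the same pattern. A condensed sketch of how the imported rewards end up in `RLOOTrainer`, assuming the RLOO API mirrors GRPO's `reward_funcs` interface and using a placeholder dataset name:

```python
from datasets import load_dataset
from peft import LoraConfig

from trl import RLOOConfig, RLOOTrainer
from trl.rewards import accuracy_reward, think_format_reward

# Placeholder dataset with "prompt" and "solution" columns.
train_dataset = load_dataset("my-org/math-prompts", split="train")

trainer = RLOOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=[think_format_reward, accuracy_reward],
    args=RLOOConfig(output_dir="Qwen3-0.6B-RLOO"),
    train_dataset=train_dataset,
    peft_config=LoraConfig(),
)
trainer.train()
```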
@@ -70,8 +70,6 @@ import os

 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify

 from trl import (
     ModelConfig,
@@ -83,7 +81,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward


 # Enable logging in a Hugging Face Space
@@ -149,54 +147,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -89,6 +89,9 @@ vlm = [
     "torchvision",
     "num2words==0.5.14"
 ]
+math_verify = [
+    "math-verify>=0.5.2",
+]
 dev = [
     # bco
     "scikit-learn",
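The new optional extra keeps `math-verify` out of the default install. A small sketch of how a user script might check for it before calling `accuracy_reward`; the `find_spec` check roughly mirrors what TRL's `_is_package_available` does internally, which is an assumption about its implementation rather than a guarantee.

```python
import importlib.util

if importlib.util.find_spec("math_verify") is None:
    raise SystemExit('math-verify is missing; install the extra, e.g. pip install "trl[math_verify]"')

from trl.rewards import accuracy_reward  # safe to import and call now
```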
@@ -13,9 +13,9 @@
 # limitations under the License.


-from trl.rewards import get_soft_overlong_punishment, think_format_reward
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward

-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_math_latex


 class TestThinkFormatReward(TrlTestCase):
@@ -85,3 +85,60 @@ class TestSoftOverlongPunishmentReward:
         completion_ids = [[1] * 90]  # 90 is between 80 and 100
         rewards = reward_fn(completion_ids)
         assert round(abs(rewards[0] - -0.5), 4) == 0
+
+
+class TestAccuracyReward:
+    @require_math_latex
+    def test_accuracy_reward_correct_answer(self):
+        """Test accuracy_reward with a correct answer."""
+        completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]]
+        solution = [r"\frac{63}{400}", "63/400"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 1.0
+        assert rewards[1] == 1.0
+
+    @require_math_latex
+    def test_accuracy_reward_wrong_answer(self):
+        """Test accuracy_reward with an incorrect answer."""
+        completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
+        solution = [r"\frac{63}{400}"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_accuracy_reward_wrong_answer_no_latex(self):
+        """Test accuracy_reward with an incorrect answer and gold solution with no latex."""
+        completion = [[{"content": r"\boxed{3}"}]]
+        solution = ["6"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_accuracy_reward_unparseable_gold(self):
+        """Test accuracy_reward with an unparseable gold solution."""
+        completion = [
+            [{"content": "Answer is forty two."}],
+            [{"content": "Some other content."}],
+            [{"content": r"Answer is \boxed{42}."}],
+            [{"content": r"Answer is \boxed{\mathbf{42}}."}],  # Make response bold
+            [{"content": r"Answer is \boxed{\textbf{42}}."}],  # Different latex command for bold
+            [{"content": r"Answer is \boxed{42}."}],
+            [{"content": r"Answer is \boxed{42.3456}."}],
+        ]
+        solution = [
+            "Answer is forty two.",
+            "Answer is forty three.",
+            "Answer is 42.0",  # Decimal point
+            "Answer is 42 43 okay?",  # Extra space
+            "Answer is 42",
+            r"Answer is \n\boxed{42}",  # Newline in gold solution
+            "Answer is 42.34560",  # Extra trailing zero
+        ]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 1.0  # Should revert to exact text match
+        assert rewards[1] == 0.0
+        assert rewards[2] == 1.0
+        assert rewards[3] == 1.0
+        assert rewards[4] == 1.0
+        assert rewards[5] == 1.0
+        assert rewards[6] == 1.0  # Should ignore trailing zeros
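The first assertion in `test_accuracy_reward_unparseable_gold` relies on the plain-text fallback: when the gold solution cannot be parsed as math, the reward is simply a case-insensitive comparison of the stripped strings, which the snippet below replays.

```python
content, sol = "Answer is forty two.", "Answer is forty two."
reward = float(content.strip().lower() == sol.strip().lower())
print(reward)  # 1.0, which is why rewards[0] == 1.0 above
```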
@@ -26,12 +26,19 @@ from transformers.testing_utils import torch_device
 from transformers.utils import is_peft_available, is_rich_available, is_vision_available

 from trl import BaseBinaryJudge, BasePairwiseJudge
-from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
+from trl.import_utils import (
+    is_joblib_available,
+    is_llm_blender_available,
+    is_math_verify_available,
+    is_mergekit_available,
+    is_vllm_available,
+)


 require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
 require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
 require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
+require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
 require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
 require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
 require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")
@@ -31,6 +31,7 @@ _fastapi_available = _is_package_available("fastapi")
 _joblib_available = _is_package_available("joblib")
 _liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
 _llm_blender_available = _is_package_available("llm_blender")
+_math_verify_available = _is_package_available("math_verify")
 _mergekit_available = _is_package_available("mergekit")
 _pydantic_available = _is_package_available("pydantic")
 _requests_available = _is_package_available("requests")
@@ -61,6 +62,10 @@ def is_llm_blender_available() -> bool:
     return _llm_blender_available


+def is_math_verify_available() -> bool:
+    return _math_verify_available
+
+
 def is_mergekit_available() -> bool:
     return _mergekit_available

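The new helper lets callers probe for the optional dependency without importing it. A minimal usage sketch:

```python
from trl.import_utils import is_math_verify_available

if is_math_verify_available():
    from trl.rewards import accuracy_reward
else:
    accuracy_reward = None  # callers can fall back to other reward functions
```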
@@ -20,12 +20,14 @@ from ..import_utils import _LazyModule


 _import_structure = {
+    "accuracy_rewards": ["accuracy_reward"],
     "format_rewards": ["think_format_reward"],
     "other_rewards": ["get_soft_overlong_punishment"],
 }


 if TYPE_CHECKING:
+    from .accuracy_rewards import accuracy_reward
     from .format_rewards import think_format_reward
     from .other_rewards import get_soft_overlong_punishment

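For context, `_import_structure` maps submodule names to the symbols they export, so `from trl.rewards import accuracy_reward` only loads `trl.rewards.accuracy_rewards` on first access. A rough, generic illustration of that lazy pattern, not TRL's actual `_LazyModule` code:

```python
import importlib

_import_structure = {"accuracy_rewards": ["accuracy_reward"]}

def __getattr__(name):
    # PEP 562 module-level __getattr__: resolve the submodule only when needed.
    for submodule, symbols in _import_structure.items():
        if name in symbols:
            module = importlib.import_module(f".{submodule}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```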
trl/rewards/accuracy_rewards.py (new file, +93)
@@ -0,0 +1,93 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from trl.import_utils import is_math_verify_available


if is_math_verify_available():
    from latex2sympy2_extended import NormalizationConfig
    from math_verify import LatexExtractionConfig, parse, verify


def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
    r"""
    Reward function that checks if the completion is the same as the ground truth.
    - If both gold and prediction are parseable → use math verification.
    - If not parseable → compare as normalized text.

    Args:
        completions (`list[list[dict[str, str]]]`):
            List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
            containing the key `"content"` with the value being the text of the completion.
        solution (`list[str]`):
            List of the raw-text solutions to the questions/problems/prompts.
        **kwargs:
            Additional keyword arguments. This function does not use them, but they are required in the function
            signature to ensure compatibility with trainers like [`GRPOTrainer`].

    Example:
    ```python
    >>> from trl.rewards import accuracy_reward

    >>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"]
    >>> completion = [
    ...     [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
    ...     [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
    ... ]
    >>> accuracy_reward(completion, solution)
    [1.0, 0.0]
    ```
    """
    if not is_math_verify_available():
        raise ImportError("Please install the `math_verify` package to use accuracy_reward")

    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Compute binary rewards if verifiable, `None` otherwise to skip this example
            try:
                reward = float(verify(gold_parsed, answer_parsed))
            except Exception:
                reward = None
        else:
            # If the gold solution is not parseable, fall back to a normalized text comparison
            reward = float(content.strip().lower() == sol.strip().lower())
        rewards.append(reward)

    return rewards
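Note that the returned list can contain `None` entries (a parseable gold answer whose verification raised). A short sketch of aggregating the rewards manually while skipping those entries; trainers that accept reward functions are expected to handle this case themselves.

```python
from trl.rewards import accuracy_reward

completions = [[{"content": r"\boxed{42}"}], [{"content": "no idea"}]]
solutions = ["42", "42"]

rewards = accuracy_reward(completions, solutions)
valid = [r for r in rewards if r is not None]  # drop entries that could not be verified
mean_accuracy = sum(valid) / len(valid) if valid else 0.0
print(rewards, mean_accuracy)
```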
@@ -41,7 +41,7 @@ from trl import (
     get_dataset,
     get_peft_config,
 )
-from trl.rewards import get_soft_overlong_punishment, think_format_reward
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward


 logger = logging.get_logger(__name__)
@@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


 reward_funcs_registry = {
+    "accuracy_reward": accuracy_reward,
     "think_format_reward": think_format_reward,
     "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
 }
@@ -68,6 +69,7 @@ class GRPOScriptArguments(ScriptArguments):
         reward_funcs (`list[str]`, *optional*):
             Reward functions to use. Supported values are:

+            - `"accuracy_reward"`
             - `"think_format_reward"`
             - `"get_soft_overlong_punishment"` (used values are `max_completion_len=1280`, `soft_punish_cache=256`)
             - any dotted import path (e.g., `'my_lib.rewards.custom_reward'`).
@@ -83,7 +85,7 @@ class GRPOScriptArguments(ScriptArguments):
     reward_funcs: Optional[list[str]] = field(
         default=None,
         metadata={
-            "help": "Reward functions to use. Supported values are: `think_format_reward`, "
+            "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
             "`get_soft_overlong_punishment` (used values are `max_completion_len=1280`, `soft_punish_cache=256`), or "
             "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
         },
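For completeness, a self-contained sketch of how the registry above turns the `--reward_funcs` strings into callables inside the script; the real implementation also resolves dotted import paths, which is omitted here.

```python
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward

# Mirrors the registry in the script; get_soft_overlong_punishment is stored
# pre-configured because it is a factory returning the actual reward callable.
reward_funcs_registry = {
    "accuracy_reward": accuracy_reward,
    "think_format_reward": think_format_reward,
    "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
}

selected = ["accuracy_reward", "think_format_reward"]  # e.g. --reward_funcs accuracy_reward think_format_reward
reward_funcs = [reward_funcs_registry[name] for name in selected]
```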
@@ -41,7 +41,7 @@ from trl import (
     get_dataset,
     get_peft_config,
 )
-from trl.rewards import get_soft_overlong_punishment, think_format_reward
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward


 logger = logging.get_logger(__name__)
@@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


 reward_funcs_registry = {
+    "accuracy_reward": accuracy_reward,
     "think_format_reward": think_format_reward,
     "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
 }
@@ -68,6 +69,7 @@ class RLOOScriptArguments(ScriptArguments):
         reward_funcs (`list[str]`, *optional*):
             Reward functions to use. Supported values are:

+            - `"accuracy_reward"`
             - `"think_format_reward"`
             - `"get_soft_overlong_punishment"` (used values are `max_completion_len=1280`, `soft_punish_cache=256`)
             - any dotted import path (e.g., `'my_lib.rewards.custom_reward'`).
@@ -83,7 +85,7 @@ class RLOOScriptArguments(ScriptArguments):
     reward_funcs: Optional[list[str]] = field(
         default=None,
         metadata={
-            "help": "Reward functions to use. Supported values are: `think_format_reward`, "
+            "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
             "`get_soft_overlong_punishment` (used values are `max_completion_len=1280`, `soft_punish_cache=256`), or "
             "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
         },