Add accuracy reward (#4270)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
commit 8e2d5516ca (parent 94aac4a101)
@@ -2,14 +2,14 @@
 
 This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
 
-## Format rewards
+## accuracy_reward
 
-### think_format_reward
+[[autodoc]] rewards.accuracy_reward
+
+## think_format_reward
 
 [[autodoc]] rewards.think_format_reward
 
-## Other rewards
-
-### get_soft_overlong_punishment
+## get_soft_overlong_punishment
 
 [[autodoc]] rewards.get_soft_overlong_punishment
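For orientation, here is a minimal sketch (not part of this commit) of combining the newly documented accuracy_reward with a trainer. The model id and the toy dataset are placeholders, and it assumes the trainer forwards the dataset's "solution" column to reward functions as a keyword argument:

    # Minimal sketch, not from this commit: model and dataset below are placeholders.
    from datasets import Dataset

    from trl import GRPOConfig, GRPOTrainer
    from trl.rewards import accuracy_reward, think_format_reward

    # accuracy_reward compares each completion against the dataset's "solution" column,
    # which the trainer is expected to pass through to reward functions as a kwarg.
    train_dataset = Dataset.from_dict(
        {
            "prompt": ["What is 1/2 + 1/6? Put your answer in \\boxed{}."],
            "solution": [r"\frac{2}{3}"],
        }
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
        reward_funcs=[think_format_reward, accuracy_reward],
        args=GRPOConfig(output_dir="GRPO-accuracy-demo"),
        train_dataset=train_dataset,
    )
    trainer.train()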
@@ -70,8 +70,6 @@ import os
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 
 from trl import (
     GRPOConfig,
@@ -83,7 +81,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -149,54 +147,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
 
-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -57,8 +57,6 @@ import os
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 
 from trl import (
     GRPOConfig,
@@ -70,7 +68,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -120,54 +118,6 @@ if __name__ == "__main__":
     train_dataset = train_dataset.remove_columns(["messages", "problem"])
     eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
 
-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -57,8 +57,6 @@ import os
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 
 from trl import (
     GRPOConfig,
@@ -70,7 +68,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -136,54 +134,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
 
-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -87,8 +87,6 @@ import os
 import torch
 import transformers
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 from transformers import AutoConfig, AutoProcessor, GenerationConfig
 
 from trl import (
@@ -102,7 +100,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -192,54 +190,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
 
-    ################
-    # Reward Function for Training (same as GRPO VLM)
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -33,12 +33,10 @@ import os
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 from peft import LoraConfig
 
 from trl import RLOOConfig, RLOOTrainer
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -67,52 +65,6 @@ def main():
     train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
     eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])
 
-    # Reward function for training
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     # Training
     training_args = RLOOConfig(
         output_dir="Qwen3-0.6B-RLOO",
@@ -70,8 +70,6 @@ import os
 
 import torch
 from datasets import load_dataset
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import LatexExtractionConfig, parse, verify
 
 from trl import (
     ModelConfig,
@@ -83,7 +81,7 @@ from trl import (
     get_peft_config,
     get_quantization_config,
 )
-from trl.rewards import think_format_reward
+from trl.rewards import accuracy_reward, think_format_reward
 
 
 # Enable logging in a Hugging Face Space
@@ -149,54 +147,6 @@ if __name__ == "__main__":
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
 
-    ################
-    # Reward Function for Training
-    ################
-    def accuracy_reward(completions, solution: list[str], **kwargs):
-        """Reward function that checks if the completion matches the ground truth.
-        - If both gold and prediction are parseable → use math verification.
-        - If not parseable → compare as normalized text.
-        """
-        rewards = []
-        contents = [completion[0]["content"] for completion in completions]
-        for content, sol in zip(contents, solution):
-            try:
-                gold_parsed = parse(sol, extraction_mode="first_match")
-            except Exception:
-                gold_parsed = []
-
-            if len(gold_parsed) != 0:
-                # Try parsing predicted answer too
-                try:
-                    answer_parsed = parse(
-                        content,
-                        extraction_config=[
-                            LatexExtractionConfig(
-                                normalization_config=NormalizationConfig(
-                                    nits=False,
-                                    malformed_operators=False,
-                                    basic_latex=True,
-                                    boxed="all",
-                                    units=True,
-                                ),
-                                boxed_match_priority=0,
-                                try_extract_without_anchor=False,
-                            )
-                        ],
-                        extraction_mode="first_match",
-                    )
-                    reward = float(verify(gold_parsed, answer_parsed))
-                except Exception as e:
-                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
-                    reward = None
-            else:
-                # fallback to text match
-                reward = float(content.strip().lower() == sol.strip().lower())
-
-            rewards.append(reward)
-
-        return rewards
-
     ################
     # Training
     ################
@@ -89,6 +89,9 @@ vlm = [
     "torchvision",
     "num2words==0.5.14"
 ]
+math_verify = [
+    "math-verify>=0.5.2",
+]
 dev = [
     # bco
     "scikit-learn",
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 
-from trl.rewards import get_soft_overlong_punishment, think_format_reward
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
 
-from .testing_utils import TrlTestCase
+from .testing_utils import TrlTestCase, require_math_latex
 
 
 class TestThinkFormatReward(TrlTestCase):
@@ -85,3 +85,60 @@ class TestSoftOverlongPunishmentReward:
         completion_ids = [[1] * 90]  # 90 is between 80 and 100
         rewards = reward_fn(completion_ids)
         assert round(abs(rewards[0] - -0.5), 4) == 0
+
+
+class TestAccuracyReward:
+    @require_math_latex
+    def test_accuracy_reward_correct_answer(self):
+        """Test accuracy_reward with a correct answer."""
+        completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]]
+        solution = [r"\frac{63}{400}", "63/400"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 1.0
+        assert rewards[1] == 1.0
+
+    @require_math_latex
+    def test_accuracy_reward_wrong_answer(self):
+        """Test accuracy_reward with an incorrect answer."""
+        completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
+        solution = [r"\frac{63}{400}"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_accuracy_reward_wrong_answer_no_latex(self):
+        """Test accuracy_reward with an incorrect answer and gold solution with no latex."""
+        completion = [[{"content": r"\boxed{3}"}]]
+        solution = ["6"]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_accuracy_reward_unparseable_gold(self):
+        """Test accuracy_reward with an unparseable gold solution."""
+        completion = [
+            [{"content": "Answer is forty two."}],
+            [{"content": "Some other content."}],
+            [{"content": r"Answer is \boxed{42}."}],
+            [{"content": r"Answer is \boxed{\mathbf{42}}."}],  # Make response bold
+            [{"content": r"Answer is \boxed{\textbf{42}}."}],  # Different latex command for bold
+            [{"content": r"Answer is \boxed{42}."}],
+            [{"content": r"Answer is \boxed{42.3456}."}],
+        ]
+        solution = [
+            "Answer is forty two.",
+            "Answer is forty three.",
+            "Answer is 42.0",  # Decimal point
+            "Answer is 42 43 okay?",  # Extra space
+            "Answer is 42",
+            r"Answer is \n\boxed{42}",  # Newline in gold solution
+            "Answer is 42.34560",  # Extra trailing zero
+        ]
+        rewards = accuracy_reward(completion, solution)
+        assert rewards[0] == 1.0  # Should revert to exact text match
+        assert rewards[1] == 0.0
+        assert rewards[2] == 1.0
+        assert rewards[3] == 1.0
+        assert rewards[4] == 1.0
+        assert rewards[5] == 1.0
+        assert rewards[6] == 1.0  # Should ignore trailing zeros
@@ -26,12 +26,19 @@ from transformers.testing_utils import torch_device
 from transformers.utils import is_peft_available, is_rich_available, is_vision_available
 
 from trl import BaseBinaryJudge, BasePairwiseJudge
-from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
+from trl.import_utils import (
+    is_joblib_available,
+    is_llm_blender_available,
+    is_math_verify_available,
+    is_mergekit_available,
+    is_vllm_available,
+)
 
 
 require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
 require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
 require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
+require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
 require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
 require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
 require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")
@@ -31,6 +31,7 @@ _fastapi_available = _is_package_available("fastapi")
 _joblib_available = _is_package_available("joblib")
 _liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
 _llm_blender_available = _is_package_available("llm_blender")
+_math_verify_available = _is_package_available("math_verify")
 _mergekit_available = _is_package_available("mergekit")
 _pydantic_available = _is_package_available("pydantic")
 _requests_available = _is_package_available("requests")
@@ -61,6 +62,10 @@ def is_llm_blender_available() -> bool:
     return _llm_blender_available
 
 
+def is_math_verify_available() -> bool:
+    return _math_verify_available
+
+
 def is_mergekit_available() -> bool:
     return _mergekit_available
 
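Because math_verify stays an optional dependency (installable via the new pyproject extra), downstream code can use the new is_math_verify_available() check to decide whether to register the reward at all. A small illustrative sketch of caller-side code, not part of this commit, mirroring how the reward module gates its own imports:

    # Sketch: only include accuracy_reward when its optional backend is installed.
    from trl.import_utils import is_math_verify_available
    from trl.rewards import think_format_reward

    reward_funcs = [think_format_reward]
    if is_math_verify_available():
        from trl.rewards import accuracy_reward

        reward_funcs.append(accuracy_reward)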
@@ -20,12 +20,14 @@ from ..import_utils import _LazyModule
 
 
 _import_structure = {
+    "accuracy_rewards": ["accuracy_reward"],
     "format_rewards": ["think_format_reward"],
     "other_rewards": ["get_soft_overlong_punishment"],
 }
 
 
 if TYPE_CHECKING:
+    from .accuracy_rewards import accuracy_reward
     from .format_rewards import think_format_reward
     from .other_rewards import get_soft_overlong_punishment
 
trl/rewards/accuracy_rewards.py (new file, 93 lines)
@@ -0,0 +1,93 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from trl.import_utils import is_math_verify_available
+
+
+if is_math_verify_available():
+    from latex2sympy2_extended import NormalizationConfig
+    from math_verify import LatexExtractionConfig, parse, verify
+
+
+def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
+    r"""
+    Reward function that checks if the completion is the same as the ground truth.
+    - If both gold and prediction are parseable → use math verification.
+    - If not parseable → compare as normalized text.
+
+    Args:
+        completions (`list[list[dict[str, str]]]`):
+            List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
+            containing the key `"content"` with the value being the text of the completion.
+        solution: (`list[str]`):
+            List of the raw-text solutions to the questions/problems/prompts.
+        **kwargs:
+            Additional keyword arguments. This function does not use them, but they are required in the function
+            signature to ensure compatibility with trainers like [`GRPOTrainer`].
+    Example:
+    ```python
+    >>> from trl.rewards import accuracy_reward
+
+    >>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"]
+    >>> completion = [
+    ...     [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
+    ...     [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
+    ... ]
+    >>> accuracy_reward(completion, solution)
+    [1.0, 0.0]
+    ```
+    """
+    if not is_math_verify_available():
+        raise ImportError("Please install the `math_verify` package to use accuracy_reward")
+
+    contents = [completion[0]["content"] for completion in completions]
+    rewards = []
+    for content, sol in zip(contents, solution):
+        gold_parsed = parse(
+            sol,
+            extraction_mode="first_match",
+        )
+        if len(gold_parsed) != 0:
+            # We require the answer to be provided in correct latex (no malformed operators)
+            answer_parsed = parse(
+                content,
+                extraction_config=[
+                    LatexExtractionConfig(
+                        normalization_config=NormalizationConfig(
+                            nits=False,
+                            malformed_operators=False,
+                            basic_latex=True,
+                            boxed="all",
+                            units=True,
+                        ),
+                        # Ensures that boxed is tried first
+                        boxed_match_priority=0,
+                        try_extract_without_anchor=False,
+                    )
+                ],
+                extraction_mode="first_match",
+            )
+            # Compute binary rewards if verifiable, `None` otherwise to skip this example
+            try:
+                reward = float(verify(gold_parsed, answer_parsed))
+            except Exception:
+                reward = None
+        else:
+            # If the gold solution is not parseable, we assign `None` to skip this example
+            reward = float(content.strip().lower() == sol.strip().lower())
+        rewards.append(reward)
+
+    return rewards
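Besides the math-verification path shown in the docstring above, the function falls back to a case-insensitive exact-text comparison when the gold solution cannot be parsed as math. A quick sketch of that branch, reusing strings exercised by the new tests:

    # Sketch of the plain-text fallback branch (gold solution not parseable as math).
    from trl.rewards import accuracy_reward

    completions = [
        [{"role": "assistant", "content": "Answer is forty two."}],
        [{"role": "assistant", "content": "Answer is forty three."}],
    ]
    solution = ["Answer is forty two.", "Answer is forty two."]

    print(accuracy_reward(completions, solution))  # expected: [1.0, 0.0]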
@@ -41,7 +41,7 @@ from trl import (
     get_dataset,
     get_peft_config,
 )
-from trl.rewards import get_soft_overlong_punishment, think_format_reward
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
 
 
 logger = logging.get_logger(__name__)
@@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
 
 
 reward_funcs_registry = {
+    "accuracy_reward": accuracy_reward,
     "think_format_reward": think_format_reward,
     "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
 }
@@ -68,6 +69,7 @@ class GRPOScriptArguments(ScriptArguments):
         reward_funcs (`list[str]`, *optional*):
             Reward functions to use. Supported values are:
 
+            - `"accuracy_reward"`
             - `"think_format_reward"`
             - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
            - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
@@ -83,7 +85,7 @@ class GRPOScriptArguments(ScriptArguments):
     reward_funcs: Optional[list[str]] = field(
         default=None,
         metadata={
-            "help": "Reward functions to use. Supported values are: `think_format_reward`, "
+            "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
             "`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
             "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
         },
@@ -41,7 +41,7 @@ from trl import (
     get_dataset,
     get_peft_config,
 )
-from trl.rewards import get_soft_overlong_punishment, think_format_reward
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
 
 
 logger = logging.get_logger(__name__)
@@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
 
 
 reward_funcs_registry = {
+    "accuracy_reward": accuracy_reward,
     "think_format_reward": think_format_reward,
     "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
 }
@@ -68,6 +69,7 @@ class RLOOScriptArguments(ScriptArguments):
         reward_funcs (`list[str]`, *optional*):
             Reward functions to use. Supported values are:
 
+            - `"accuracy_reward"`
            - `"think_format_reward"`
            - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
            - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
@@ -83,7 +85,7 @@ class RLOOScriptArguments(ScriptArguments):
     reward_funcs: Optional[list[str]] = field(
         default=None,
         metadata={
-            "help": "Reward functions to use. Supported values are: `think_format_reward`, "
+            "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
            "`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
            "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
        },