Add accuracy reward (#4270)

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
This commit is contained in:
Pramodith Ballapuram
2025-10-16 01:01:07 +01:00
committed by GitHub
parent 94aac4a101
commit 8e2d5516ca
15 changed files with 189 additions and 316 deletions

View File

@ -2,14 +2,14 @@
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`]. This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
## Format rewards ## accuracy_reward
### think_format_reward [[autodoc]] rewards.accuracy_reward
## think_format_reward
[[autodoc]] rewards.think_format_reward [[autodoc]] rewards.think_format_reward
## Other rewards ## get_soft_overlong_punishment
### get_soft_overlong_punishment
[[autodoc]] rewards.get_soft_overlong_punishment [[autodoc]] rewards.get_soft_overlong_punishment

View File

@ -70,8 +70,6 @@ import os
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import ( from trl import (
GRPOConfig, GRPOConfig,
@ -83,7 +81,7 @@ from trl import (
get_peft_config, get_peft_config,
get_quantization_config, get_quantization_config,
) )
from trl.rewards import think_format_reward from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space # Enable logging in a Hugging Face Space
@ -149,54 +147,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################ ################
# Training # Training
################ ################

View File

@ -57,8 +57,6 @@ import os
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import ( from trl import (
GRPOConfig, GRPOConfig,
@ -70,7 +68,7 @@ from trl import (
get_peft_config, get_peft_config,
get_quantization_config, get_quantization_config,
) )
from trl.rewards import think_format_reward from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space # Enable logging in a Hugging Face Space
@ -120,54 +118,6 @@ if __name__ == "__main__":
train_dataset = train_dataset.remove_columns(["messages", "problem"]) train_dataset = train_dataset.remove_columns(["messages", "problem"])
eval_dataset = eval_dataset.remove_columns(["messages", "problem"]) eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################ ################
# Training # Training
################ ################

View File

@ -57,8 +57,6 @@ import os
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import ( from trl import (
GRPOConfig, GRPOConfig,
@ -70,7 +68,7 @@ from trl import (
get_peft_config, get_peft_config,
get_quantization_config, get_quantization_config,
) )
from trl.rewards import think_format_reward from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space # Enable logging in a Hugging Face Space
@ -136,54 +134,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################ ################
# Training # Training
################ ################

View File

@ -87,8 +87,6 @@ import os
import torch import torch
import transformers import transformers
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from transformers import AutoConfig, AutoProcessor, GenerationConfig from transformers import AutoConfig, AutoProcessor, GenerationConfig
from trl import ( from trl import (
@ -102,7 +100,7 @@ from trl import (
get_peft_config, get_peft_config,
get_quantization_config, get_quantization_config,
) )
from trl.rewards import think_format_reward from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space # Enable logging in a Hugging Face Space
@ -192,54 +190,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training (same as GRPO VLM)
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################ ################
# Training # Training
################ ################

View File

@ -33,12 +33,10 @@ import os
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from peft import LoraConfig from peft import LoraConfig
from trl import RLOOConfig, RLOOTrainer from trl import RLOOConfig, RLOOTrainer
from trl.rewards import think_format_reward from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space # Enable logging in a Hugging Face Space
@ -67,52 +65,6 @@ def main():
train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"]) train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"]) eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])
# Reward function for training
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
# Training # Training
training_args = RLOOConfig( training_args = RLOOConfig(
output_dir="Qwen3-0.6B-RLOO", output_dir="Qwen3-0.6B-RLOO",

View File

@ -70,8 +70,6 @@ import os
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import ( from trl import (
ModelConfig, ModelConfig,
@ -83,7 +81,7 @@ from trl import (
get_peft_config, get_peft_config,
get_quantization_config, get_quantization_config,
) )
from trl.rewards import think_format_reward from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space # Enable logging in a Hugging Face Space
@ -149,54 +147,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################ ################
# Training # Training
################ ################

View File

@ -89,6 +89,9 @@ vlm = [
"torchvision", "torchvision",
"num2words==0.5.14" "num2words==0.5.14"
] ]
math_verify = [
"math-verify>=0.5.2",
]
dev = [ dev = [
# bco # bco
"scikit-learn", "scikit-learn",

View File

@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
from trl.rewards import get_soft_overlong_punishment, think_format_reward from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
from .testing_utils import TrlTestCase from .testing_utils import TrlTestCase, require_math_latex
class TestThinkFormatReward(TrlTestCase): class TestThinkFormatReward(TrlTestCase):
@ -85,3 +85,60 @@ class TestSoftOverlongPunishmentReward:
completion_ids = [[1] * 90] # 90 is between 80 and 100 completion_ids = [[1] * 90] # 90 is between 80 and 100
rewards = reward_fn(completion_ids) rewards = reward_fn(completion_ids)
assert round(abs(rewards[0] - -0.5), 4) == 0 assert round(abs(rewards[0] - -0.5), 4) == 0
class TestAccuracyReward:
@require_math_latex
def test_accuracy_reward_correct_answer(self):
"""Test accuracy_reward with a correct answer."""
completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]]
solution = [r"\frac{63}{400}", "63/400"]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 1.0
assert rewards[1] == 1.0
@require_math_latex
def test_accuracy_reward_wrong_answer(self):
"""Test accuracy_reward with an incorrect answer."""
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
solution = [r"\frac{63}{400}"]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 0.0
@require_math_latex
def test_accuracy_reward_wrong_answer_no_latex(self):
"""Test accuracy_reward with an incorrect answer and gold solution with no latex."""
completion = [[{"content": r"\boxed{3}"}]]
solution = ["6"]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 0.0
@require_math_latex
def test_accuracy_reward_unparseable_gold(self):
"""Test accuracy_reward with an unparseable gold solution."""
completion = [
[{"content": "Answer is forty two."}],
[{"content": "Some other content."}],
[{"content": r"Answer is \boxed{42}."}],
[{"content": r"Answer is \boxed{\mathbf{42}}."}], # Make response bold
[{"content": r"Answer is \boxed{\textbf{42}}."}], # Different latex command for bold
[{"content": r"Answer is \boxed{42}."}],
[{"content": r"Answer is \boxed{42.3456}."}],
]
solution = [
"Answer is forty two.",
"Answer is forty three.",
"Answer is 42.0", # Decimal point
"Answer is 42 43 okay?", # Extra space
"Answer is 42",
r"Answer is \n\boxed{42}", # Newline in gold solution
"Answer is 42.34560", # Extra trailing zero
]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 1.0 # Should revert to exact text match
assert rewards[1] == 0.0
assert rewards[2] == 1.0
assert rewards[3] == 1.0
assert rewards[4] == 1.0
assert rewards[5] == 1.0
assert rewards[6] == 1.0 # Should ignore trailing zeros

View File

@ -26,12 +26,19 @@ from transformers.testing_utils import torch_device
from transformers.utils import is_peft_available, is_rich_available, is_vision_available from transformers.utils import is_peft_available, is_rich_available, is_vision_available
from trl import BaseBinaryJudge, BasePairwiseJudge from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available from trl.import_utils import (
is_joblib_available,
is_llm_blender_available,
is_math_verify_available,
is_mergekit_available,
is_vllm_available,
)
require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit") require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft") require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich") require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")

View File

@ -31,6 +31,7 @@ _fastapi_available = _is_package_available("fastapi")
_joblib_available = _is_package_available("joblib") _joblib_available = _is_package_available("joblib")
_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True) _liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
_llm_blender_available = _is_package_available("llm_blender") _llm_blender_available = _is_package_available("llm_blender")
_math_verify_available = _is_package_available("math_verify")
_mergekit_available = _is_package_available("mergekit") _mergekit_available = _is_package_available("mergekit")
_pydantic_available = _is_package_available("pydantic") _pydantic_available = _is_package_available("pydantic")
_requests_available = _is_package_available("requests") _requests_available = _is_package_available("requests")
@ -61,6 +62,10 @@ def is_llm_blender_available() -> bool:
return _llm_blender_available return _llm_blender_available
def is_math_verify_available() -> bool:
return _math_verify_available
def is_mergekit_available() -> bool: def is_mergekit_available() -> bool:
return _mergekit_available return _mergekit_available

View File

@ -20,12 +20,14 @@ from ..import_utils import _LazyModule
_import_structure = { _import_structure = {
"accuracy_rewards": ["accuracy_reward"],
"format_rewards": ["think_format_reward"], "format_rewards": ["think_format_reward"],
"other_rewards": ["get_soft_overlong_punishment"], "other_rewards": ["get_soft_overlong_punishment"],
} }
if TYPE_CHECKING: if TYPE_CHECKING:
from .accuracy_rewards import accuracy_reward
from .format_rewards import think_format_reward from .format_rewards import think_format_reward
from .other_rewards import get_soft_overlong_punishment from .other_rewards import get_soft_overlong_punishment

View File

@ -0,0 +1,93 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from trl.import_utils import is_math_verify_available
if is_math_verify_available():
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
r"""
Reward function that checks if the completion is the same as the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
Args:
completions (`list[list[dict[str, str]]]`):
List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
containing the key `"content"` with the value being the text of the completion.
solution: (`list[str]`):
List of the raw-text solutions to the questions/problems/prompts.
**kwargs:
Additional keyword arguments. This function does not use them, but they are required in the function
signature to ensure compatibility with trainers like [`GRPOTrainer`].
Example:
```python
>>> from trl.rewards import accuracy_reward
>>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"]
>>> completion = [
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
... ]
>>> accuracy_reward(completion, solution)
[1.0, 0.0]
```
"""
if not is_math_verify_available():
raise ImportError("Please install the `math_verify` package to use accuracy_reward")
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Compute binary rewards if verifiable, `None` otherwise to skip this example
try:
reward = float(verify(gold_parsed, answer_parsed))
except Exception:
reward = None
else:
# If the gold solution is not parseable, we assign `None` to skip this example
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards

View File

@ -41,7 +41,7 @@ from trl import (
get_dataset, get_dataset,
get_peft_config, get_peft_config,
) )
from trl.rewards import get_soft_overlong_punishment, think_format_reward from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
reward_funcs_registry = { reward_funcs_registry = {
"accuracy_reward": accuracy_reward,
"think_format_reward": think_format_reward, "think_format_reward": think_format_reward,
"get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256), "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
} }
@ -68,6 +69,7 @@ class GRPOScriptArguments(ScriptArguments):
reward_funcs (`list[str]`, *optional*): reward_funcs (`list[str]`, *optional*):
Reward functions to use. Supported values are: Reward functions to use. Supported values are:
- `"accuracy_reward"`
- `"think_format_reward"` - `"think_format_reward"`
- `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`) - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`). - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
@ -83,7 +85,7 @@ class GRPOScriptArguments(ScriptArguments):
reward_funcs: Optional[list[str]] = field( reward_funcs: Optional[list[str]] = field(
default=None, default=None,
metadata={ metadata={
"help": "Reward functions to use. Supported values are: `think_format_reward`, " "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
"`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or " "`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
"any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)." "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
}, },

View File

@ -41,7 +41,7 @@ from trl import (
get_dataset, get_dataset,
get_peft_config, get_peft_config,
) )
from trl.rewards import get_soft_overlong_punishment, think_format_reward from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
reward_funcs_registry = { reward_funcs_registry = {
"accuracy_reward": accuracy_reward,
"think_format_reward": think_format_reward, "think_format_reward": think_format_reward,
"get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256), "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
} }
@ -68,6 +69,7 @@ class RLOOScriptArguments(ScriptArguments):
reward_funcs (`list[str]`, *optional*): reward_funcs (`list[str]`, *optional*):
Reward functions to use. Supported values are: Reward functions to use. Supported values are:
- `"accuracy_reward"`
- `"think_format_reward"` - `"think_format_reward"`
- `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`) - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`). - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
@ -83,7 +85,7 @@ class RLOOScriptArguments(ScriptArguments):
reward_funcs: Optional[list[str]] = field( reward_funcs: Optional[list[str]] = field(
default=None, default=None,
metadata={ metadata={
"help": "Reward functions to use. Supported values are: `think_format_reward`, " "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
"`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or " "`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
"any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)." "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
}, },