From 8e2d5516ca672d630ab4975ad675e09d3a7e50bd Mon Sep 17 00:00:00 2001 From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Thu, 16 Oct 2025 01:01:07 +0100 Subject: [PATCH] Add accuracy reward (#4270) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec --- docs/source/rewards.md | 10 ++-- examples/scripts/grpo_vlm.py | 52 +---------------- examples/scripts/gspo.py | 52 +---------------- examples/scripts/gspo_vlm.py | 52 +---------------- examples/scripts/online_dpo_vlm.py | 52 +---------------- examples/scripts/rloo.py | 50 +--------------- examples/scripts/rloo_vlm.py | 52 +---------------- pyproject.toml | 3 + tests/test_rewards.py | 61 +++++++++++++++++++- tests/testing_utils.py | 9 ++- trl/import_utils.py | 5 ++ trl/rewards/__init__.py | 2 + trl/rewards/accuracy_rewards.py | 93 ++++++++++++++++++++++++++++++ trl/scripts/grpo.py | 6 +- trl/scripts/rloo.py | 6 +- 15 files changed, 189 insertions(+), 316 deletions(-) create mode 100644 trl/rewards/accuracy_rewards.py diff --git a/docs/source/rewards.md b/docs/source/rewards.md index d7f23a7ed..d18d6402f 100644 --- a/docs/source/rewards.md +++ b/docs/source/rewards.md @@ -2,14 +2,14 @@ This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`]. -## Format rewards +## accuracy_reward -### think_format_reward +[[autodoc]] rewards.accuracy_reward + +## think_format_reward [[autodoc]] rewards.think_format_reward -## Other rewards - -### get_soft_overlong_punishment +## get_soft_overlong_punishment [[autodoc]] rewards.get_soft_overlong_punishment diff --git a/examples/scripts/grpo_vlm.py b/examples/scripts/grpo_vlm.py index 44c87026b..62ddb975d 100644 --- a/examples/scripts/grpo_vlm.py +++ b/examples/scripts/grpo_vlm.py @@ -70,8 +70,6 @@ import os import torch from datasets import load_dataset -from latex2sympy2_extended import NormalizationConfig -from math_verify import LatexExtractionConfig, parse, verify from trl import ( GRPOConfig, @@ -83,7 +81,7 @@ from trl import ( get_peft_config, get_quantization_config, ) -from trl.rewards import think_format_reward +from trl.rewards import accuracy_reward, think_format_reward # Enable logging in a Hugging Face Space @@ -149,54 +147,6 @@ if __name__ == "__main__": train_dataset = dataset["train"] eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None - ################ - # Reward Function for Training - ################ - def accuracy_reward(completions, solution: list[str], **kwargs): - """Reward function that checks if the completion matches the ground truth. - - If both gold and prediction are parseable → use math verification. - - If not parseable → compare as normalized text. 
- """ - rewards = [] - contents = [completion[0]["content"] for completion in completions] - for content, sol in zip(contents, solution): - try: - gold_parsed = parse(sol, extraction_mode="first_match") - except Exception: - gold_parsed = [] - - if len(gold_parsed) != 0: - # Try parsing predicted answer too - try: - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - reward = float(verify(gold_parsed, answer_parsed)) - except Exception as e: - print(f"verify failed: {e}, answer: {content}, gold: {sol}") - reward = None - else: - # fallback to text match - reward = float(content.strip().lower() == sol.strip().lower()) - - rewards.append(reward) - - return rewards - ################ # Training ################ diff --git a/examples/scripts/gspo.py b/examples/scripts/gspo.py index 31dd4a987..3c587fdae 100644 --- a/examples/scripts/gspo.py +++ b/examples/scripts/gspo.py @@ -57,8 +57,6 @@ import os import torch from datasets import load_dataset -from latex2sympy2_extended import NormalizationConfig -from math_verify import LatexExtractionConfig, parse, verify from trl import ( GRPOConfig, @@ -70,7 +68,7 @@ from trl import ( get_peft_config, get_quantization_config, ) -from trl.rewards import think_format_reward +from trl.rewards import accuracy_reward, think_format_reward # Enable logging in a Hugging Face Space @@ -120,54 +118,6 @@ if __name__ == "__main__": train_dataset = train_dataset.remove_columns(["messages", "problem"]) eval_dataset = eval_dataset.remove_columns(["messages", "problem"]) - ################ - # Reward Function for Training - ################ - def accuracy_reward(completions, solution: list[str], **kwargs): - """Reward function that checks if the completion matches the ground truth. - - If both gold and prediction are parseable → use math verification. - - If not parseable → compare as normalized text. 
- """ - rewards = [] - contents = [completion[0]["content"] for completion in completions] - for content, sol in zip(contents, solution): - try: - gold_parsed = parse(sol, extraction_mode="first_match") - except Exception: - gold_parsed = [] - - if len(gold_parsed) != 0: - # Try parsing predicted answer too - try: - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - reward = float(verify(gold_parsed, answer_parsed)) - except Exception as e: - print(f"verify failed: {e}, answer: {content}, gold: {sol}") - reward = None - else: - # fallback to text match - reward = float(content.strip().lower() == sol.strip().lower()) - - rewards.append(reward) - - return rewards - ################ # Training ################ diff --git a/examples/scripts/gspo_vlm.py b/examples/scripts/gspo_vlm.py index d44d7fd9d..cff9e241c 100644 --- a/examples/scripts/gspo_vlm.py +++ b/examples/scripts/gspo_vlm.py @@ -57,8 +57,6 @@ import os import torch from datasets import load_dataset -from latex2sympy2_extended import NormalizationConfig -from math_verify import LatexExtractionConfig, parse, verify from trl import ( GRPOConfig, @@ -70,7 +68,7 @@ from trl import ( get_peft_config, get_quantization_config, ) -from trl.rewards import think_format_reward +from trl.rewards import accuracy_reward, think_format_reward # Enable logging in a Hugging Face Space @@ -136,54 +134,6 @@ if __name__ == "__main__": train_dataset = dataset["train"] eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None - ################ - # Reward Function for Training - ################ - def accuracy_reward(completions, solution: list[str], **kwargs): - """Reward function that checks if the completion matches the ground truth. - - If both gold and prediction are parseable → use math verification. - - If not parseable → compare as normalized text. 
- """ - rewards = [] - contents = [completion[0]["content"] for completion in completions] - for content, sol in zip(contents, solution): - try: - gold_parsed = parse(sol, extraction_mode="first_match") - except Exception: - gold_parsed = [] - - if len(gold_parsed) != 0: - # Try parsing predicted answer too - try: - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - reward = float(verify(gold_parsed, answer_parsed)) - except Exception as e: - print(f"verify failed: {e}, answer: {content}, gold: {sol}") - reward = None - else: - # fallback to text match - reward = float(content.strip().lower() == sol.strip().lower()) - - rewards.append(reward) - - return rewards - ################ # Training ################ diff --git a/examples/scripts/online_dpo_vlm.py b/examples/scripts/online_dpo_vlm.py index 39246bb42..62b4e9867 100644 --- a/examples/scripts/online_dpo_vlm.py +++ b/examples/scripts/online_dpo_vlm.py @@ -87,8 +87,6 @@ import os import torch import transformers from datasets import load_dataset -from latex2sympy2_extended import NormalizationConfig -from math_verify import LatexExtractionConfig, parse, verify from transformers import AutoConfig, AutoProcessor, GenerationConfig from trl import ( @@ -102,7 +100,7 @@ from trl import ( get_peft_config, get_quantization_config, ) -from trl.rewards import think_format_reward +from trl.rewards import accuracy_reward, think_format_reward # Enable logging in a Hugging Face Space @@ -192,54 +190,6 @@ if __name__ == "__main__": train_dataset = dataset["train"] eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None - ################ - # Reward Function for Training (same as GRPO VLM) - ################ - def accuracy_reward(completions, solution: list[str], **kwargs): - """Reward function that checks if the completion matches the ground truth. - - If both gold and prediction are parseable → use math verification. - - If not parseable → compare as normalized text. 
- """ - rewards = [] - contents = [completion[0]["content"] for completion in completions] - for content, sol in zip(contents, solution): - try: - gold_parsed = parse(sol, extraction_mode="first_match") - except Exception: - gold_parsed = [] - - if len(gold_parsed) != 0: - # Try parsing predicted answer too - try: - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - reward = float(verify(gold_parsed, answer_parsed)) - except Exception as e: - print(f"verify failed: {e}, answer: {content}, gold: {sol}") - reward = None - else: - # fallback to text match - reward = float(content.strip().lower() == sol.strip().lower()) - - rewards.append(reward) - - return rewards - ################ # Training ################ diff --git a/examples/scripts/rloo.py b/examples/scripts/rloo.py index bc599f7b9..abeabb45b 100644 --- a/examples/scripts/rloo.py +++ b/examples/scripts/rloo.py @@ -33,12 +33,10 @@ import os import torch from datasets import load_dataset -from latex2sympy2_extended import NormalizationConfig -from math_verify import LatexExtractionConfig, parse, verify from peft import LoraConfig from trl import RLOOConfig, RLOOTrainer -from trl.rewards import think_format_reward +from trl.rewards import accuracy_reward, think_format_reward # Enable logging in a Hugging Face Space @@ -67,52 +65,6 @@ def main(): train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"]) eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"]) - # Reward function for training - def accuracy_reward(completions, solution: list[str], **kwargs): - """Reward function that checks if the completion matches the ground truth. - - If both gold and prediction are parseable → use math verification. - - If not parseable → compare as normalized text. 
- """ - rewards = [] - contents = [completion[0]["content"] for completion in completions] - for content, sol in zip(contents, solution): - try: - gold_parsed = parse(sol, extraction_mode="first_match") - except Exception: - gold_parsed = [] - - if len(gold_parsed) != 0: - # Try parsing predicted answer too - try: - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - reward = float(verify(gold_parsed, answer_parsed)) - except Exception as e: - print(f"verify failed: {e}, answer: {content}, gold: {sol}") - reward = None - else: - # fallback to text match - reward = float(content.strip().lower() == sol.strip().lower()) - - rewards.append(reward) - - return rewards - # Training training_args = RLOOConfig( output_dir="Qwen3-0.6B-RLOO", diff --git a/examples/scripts/rloo_vlm.py b/examples/scripts/rloo_vlm.py index 2613ba7df..a98674db1 100644 --- a/examples/scripts/rloo_vlm.py +++ b/examples/scripts/rloo_vlm.py @@ -70,8 +70,6 @@ import os import torch from datasets import load_dataset -from latex2sympy2_extended import NormalizationConfig -from math_verify import LatexExtractionConfig, parse, verify from trl import ( ModelConfig, @@ -83,7 +81,7 @@ from trl import ( get_peft_config, get_quantization_config, ) -from trl.rewards import think_format_reward +from trl.rewards import accuracy_reward, think_format_reward # Enable logging in a Hugging Face Space @@ -149,54 +147,6 @@ if __name__ == "__main__": train_dataset = dataset["train"] eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None - ################ - # Reward Function for Training - ################ - def accuracy_reward(completions, solution: list[str], **kwargs): - """Reward function that checks if the completion matches the ground truth. - - If both gold and prediction are parseable → use math verification. - - If not parseable → compare as normalized text. 
- """ - rewards = [] - contents = [completion[0]["content"] for completion in completions] - for content, sol in zip(contents, solution): - try: - gold_parsed = parse(sol, extraction_mode="first_match") - except Exception: - gold_parsed = [] - - if len(gold_parsed) != 0: - # Try parsing predicted answer too - try: - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode="first_match", - ) - reward = float(verify(gold_parsed, answer_parsed)) - except Exception as e: - print(f"verify failed: {e}, answer: {content}, gold: {sol}") - reward = None - else: - # fallback to text match - reward = float(content.strip().lower() == sol.strip().lower()) - - rewards.append(reward) - - return rewards - ################ # Training ################ diff --git a/pyproject.toml b/pyproject.toml index 3ed8cc018..787415bec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,9 @@ vlm = [ "torchvision", "num2words==0.5.14" ] +math_verify = [ + "math-verify>=0.5.2", +] dev = [ # bco "scikit-learn", diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 0764ce5d9..f6341c688 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -13,9 +13,9 @@ # limitations under the License. -from trl.rewards import get_soft_overlong_punishment, think_format_reward +from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_math_latex class TestThinkFormatReward(TrlTestCase): @@ -85,3 +85,60 @@ class TestSoftOverlongPunishmentReward: completion_ids = [[1] * 90] # 90 is between 80 and 100 rewards = reward_fn(completion_ids) assert round(abs(rewards[0] - -0.5), 4) == 0 + + +class TestAccuracyReward: + @require_math_latex + def test_accuracy_reward_correct_answer(self): + """Test accuracy_reward with a correct answer.""" + completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]] + solution = [r"\frac{63}{400}", "63/400"] + rewards = accuracy_reward(completion, solution) + assert rewards[0] == 1.0 + assert rewards[1] == 1.0 + + @require_math_latex + def test_accuracy_reward_wrong_answer(self): + """Test accuracy_reward with an incorrect answer.""" + completion = [[{"content": r"\boxed{\frac{64}{400}}"}]] + solution = [r"\frac{63}{400}"] + rewards = accuracy_reward(completion, solution) + assert rewards[0] == 0.0 + + @require_math_latex + def test_accuracy_reward_wrong_answer_no_latex(self): + """Test accuracy_reward with an incorrect answer and gold solution with no latex.""" + completion = [[{"content": r"\boxed{3}"}]] + solution = ["6"] + rewards = accuracy_reward(completion, solution) + assert rewards[0] == 0.0 + + @require_math_latex + def test_accuracy_reward_unparseable_gold(self): + """Test accuracy_reward with an unparseable gold solution.""" + completion = [ + [{"content": "Answer is forty two."}], + [{"content": "Some other content."}], + [{"content": r"Answer is \boxed{42}."}], + [{"content": r"Answer is \boxed{\mathbf{42}}."}], # Make response bold + [{"content": r"Answer is \boxed{\textbf{42}}."}], # Different latex command for bold + [{"content": r"Answer is \boxed{42}."}], + [{"content": r"Answer is \boxed{42.3456}."}], + ] + solution = [ + "Answer is forty two.", + "Answer is 
forty three.", + "Answer is 42.0", # Decimal point + "Answer is 42 43 okay?", # Extra space + "Answer is 42", + r"Answer is \n\boxed{42}", # Newline in gold solution + "Answer is 42.34560", # Extra trailing zero + ] + rewards = accuracy_reward(completion, solution) + assert rewards[0] == 1.0 # Should revert to exact text match + assert rewards[1] == 0.0 + assert rewards[2] == 1.0 + assert rewards[3] == 1.0 + assert rewards[4] == 1.0 + assert rewards[5] == 1.0 + assert rewards[6] == 1.0 # Should ignore trailing zeros diff --git a/tests/testing_utils.py b/tests/testing_utils.py index fb2dec9cf..1d4992f3d 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -26,12 +26,19 @@ from transformers.testing_utils import torch_device from transformers.utils import is_peft_available, is_rich_available, is_vision_available from trl import BaseBinaryJudge, BasePairwiseJudge -from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available +from trl.import_utils import ( + is_joblib_available, + is_llm_blender_available, + is_math_verify_available, + is_mergekit_available, + is_vllm_available, +) require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") +require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify") require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit") require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft") require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich") diff --git a/trl/import_utils.py b/trl/import_utils.py index 10709dc54..4d8a9c84c 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -31,6 +31,7 @@ _fastapi_available = _is_package_available("fastapi") _joblib_available = _is_package_available("joblib") _liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True) _llm_blender_available = _is_package_available("llm_blender") +_math_verify_available = _is_package_available("math_verify") _mergekit_available = _is_package_available("mergekit") _pydantic_available = _is_package_available("pydantic") _requests_available = _is_package_available("requests") @@ -61,6 +62,10 @@ def is_llm_blender_available() -> bool: return _llm_blender_available +def is_math_verify_available() -> bool: + return _math_verify_available + + def is_mergekit_available() -> bool: return _mergekit_available diff --git a/trl/rewards/__init__.py b/trl/rewards/__init__.py index 4eb45a35c..b92384410 100644 --- a/trl/rewards/__init__.py +++ b/trl/rewards/__init__.py @@ -20,12 +20,14 @@ from ..import_utils import _LazyModule _import_structure = { + "accuracy_rewards": ["accuracy_reward"], "format_rewards": ["think_format_reward"], "other_rewards": ["get_soft_overlong_punishment"], } if TYPE_CHECKING: + from .accuracy_rewards import accuracy_reward from .format_rewards import think_format_reward from .other_rewards import get_soft_overlong_punishment diff --git a/trl/rewards/accuracy_rewards.py b/trl/rewards/accuracy_rewards.py new file mode 100644 index 000000000..8ad2e9d17 --- /dev/null +++ b/trl/rewards/accuracy_rewards.py @@ -0,0 +1,93 @@ +# Copyright 2020-2025 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from trl.import_utils import is_math_verify_available + + +if is_math_verify_available(): + from latex2sympy2_extended import NormalizationConfig + from math_verify import LatexExtractionConfig, parse, verify + + +def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]: + r""" + Reward function that checks if the completion is the same as the ground truth. + - If both gold and prediction are parseable → use math verification. + - If not parseable → compare as normalized text. + + Args: + completions (`list[list[dict[str, str]]]`): + List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary + containing the key `"content"` with the value being the text of the completion. + solution: (`list[str]`): + List of the raw-text solutions to the questions/problems/prompts. + **kwargs: + Additional keyword arguments. This function does not use them, but they are required in the function + signature to ensure compatibility with trainers like [`GRPOTrainer`]. + Example: + ```python + >>> from trl.rewards import accuracy_reward + + >>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"] + >>> completion = [ + ... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}], + ... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}], + ... 
] + >>> accuracy_reward(completion, solution) + [1.0, 0.0] + ``` + """ + if not is_math_verify_available(): + raise ImportError("Please install the `math_verify` package to use accuracy_reward") + + contents = [completion[0]["content"] for completion in completions] + rewards = [] + for content, sol in zip(contents, solution): + gold_parsed = parse( + sol, + extraction_mode="first_match", + ) + if len(gold_parsed) != 0: + # We require the answer to be provided in correct latex (no malformed operators) + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + boxed="all", + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + # Compute binary rewards if verifiable, `None` otherwise to skip this example + try: + reward = float(verify(gold_parsed, answer_parsed)) + except Exception: + reward = None + else: + # If the gold solution is not parseable, we assign `None` to skip this example + reward = float(content.strip().lower() == sol.strip().lower()) + rewards.append(reward) + + return rewards diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py index de91941d0..5c93ca1e5 100644 --- a/trl/scripts/grpo.py +++ b/trl/scripts/grpo.py @@ -41,7 +41,7 @@ from trl import ( get_dataset, get_peft_config, ) -from trl.rewards import get_soft_overlong_punishment, think_format_reward +from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward logger = logging.get_logger(__name__) @@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") reward_funcs_registry = { + "accuracy_reward": accuracy_reward, "think_format_reward": think_format_reward, "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256), } @@ -68,6 +69,7 @@ class GRPOScriptArguments(ScriptArguments): reward_funcs (`list[str]`, *optional*): Reward functions to use. Supported values are: + - `"accuracy_reward"` - `"think_format_reward"` - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`) - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`). @@ -83,7 +85,7 @@ class GRPOScriptArguments(ScriptArguments): reward_funcs: Optional[list[str]] = field( default=None, metadata={ - "help": "Reward functions to use. Supported values are: `think_format_reward`, " + "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, " "`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or " "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)." 
}, diff --git a/trl/scripts/rloo.py b/trl/scripts/rloo.py index 701a41746..a8438252b 100644 --- a/trl/scripts/rloo.py +++ b/trl/scripts/rloo.py @@ -41,7 +41,7 @@ from trl import ( get_dataset, get_peft_config, ) -from trl.rewards import get_soft_overlong_punishment, think_format_reward +from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward logger = logging.get_logger(__name__) @@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") reward_funcs_registry = { + "accuracy_reward": accuracy_reward, "think_format_reward": think_format_reward, "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256), } @@ -68,6 +69,7 @@ class RLOOScriptArguments(ScriptArguments): reward_funcs (`list[str]`, *optional*): Reward functions to use. Supported values are: + - `"accuracy_reward"` - `"think_format_reward"` - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`) - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`). @@ -83,7 +85,7 @@ class RLOOScriptArguments(ScriptArguments): reward_funcs: Optional[list[str]] = field( default=None, metadata={ - "help": "Reward functions to use. Supported values are: `think_format_reward`, " + "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, " "`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or " "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)." },
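Usage sketch after this patch (a minimal example, not part of the diff above; it assumes TRL is installed with the new optional `math_verify` extra, e.g. `pip install "trl[math_verify]"`, and the inputs mirror the new tests in `tests/test_rewards.py`):

```python
from trl.rewards import accuracy_reward

# Each completion is a list containing a single message dict with a "content" key,
# matching the format documented in the new trl/rewards/accuracy_rewards.py.
completions = [
    [{"role": "assistant", "content": r"My answer is \boxed{\frac{63}{400}}"}],
    [{"role": "assistant", "content": r"My answer is \boxed{\frac{64}{400}}"}],
]
solution = [r"\frac{63}{400}", r"\frac{63}{400}"]

# Parseable gold and prediction -> math verification; unparseable gold -> normalized
# text comparison; a verification error -> None, so that example is skipped.
print(accuracy_reward(completions, solution))  # [1.0, 0.0]
```

Because `accuracy_reward` is also registered in `reward_funcs_registry` in `trl/scripts/grpo.py` and `trl/scripts/rloo.py`, it should additionally be selectable by name when launching those scripts, e.g. `--reward_funcs accuracy_reward think_format_reward`.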