mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
171 lines
6.4 KiB
Python
171 lines
6.4 KiB
Python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from itertools import chain
|
|
from typing import Optional
|
|
|
|
from datasets import load_dataset
|
|
from huggingface_hub import ModelCard
|
|
from transformers import HfArgumentParser
|
|
|
|
|
|
@dataclass
class ScriptArguments:
    r"""
    Command-line options for the Math-Shepherd dataset-generation script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    # Each field maps to one CLI flag parsed by `HfArgumentParser`.
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}
    )
    repo_id: str = field(
        default="trl-lib/math_shepherd",
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
    dataset_num_proc: Optional[int] = field(
        default=None, metadata={"help": "Number of workers to use for dataset processing."}
    )
|
|
|
|
|
|
def process_example(example):
    """Convert one raw Math-Shepherd row into the stepwise-supervision format.

    Args:
        example (`dict`):
            Row with an `"input"` string in which every solution step ends with
            the marker `"ки"`, and a `"label"` string of identical layout where
            each marker position holds `"+"` (correct step) or `"-"` (incorrect
            step).

    Returns:
        `dict`: `{"prompt": str, "completions": list[str], "labels": list[bool]}`
        where `labels[i]` is `True` iff `completions[i]` is marked correct.

    Raises:
        ValueError: If a marker position in the label is not "+"/"-", if a
            completion does not start with the marker, if the first step cannot
            be split off the prompt, or if the number of completions and labels
            differ.
    """
    # Replace the two-character marker "ки" with the single-character "ⶻ" so
    # that the size of the "input" matches the size of the "label"
    inputs = example["input"].replace("ки", "ⶻ")

    # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label)
    indexes = [m.start() for m in re.finditer("ⶻ", inputs)]

    # Sanity check that all marker positions in the label are "+" or "-".
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not all(example["label"][idx] in ["+", "-"] for idx in indexes):
        raise ValueError("Expected '+' or '-' in the label at every step-marker position.")

    # Get the labels
    labels = [example["label"][idx] == "+" for idx in indexes]

    # Split the inputs into steps (caution, the first step is missing here, it is the prompt)
    steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]

    # Remove the last step (single ⶻ)
    steps = steps[:-1]

    # Get the prompt (first part) and completions (rest)
    prompt = steps[0]
    completions = steps[1:]

    # Remove the heading "ⶻ" and the final whitespace from the completions
    if not all(completion.startswith("ⶻ") for completion in completions):
        raise ValueError("Every completion should start with the step marker 'ⶻ'.")
    completions = [completion[1:].strip() for completion in completions]

    # At this point, we need to retrieve the first step from the prompt.
    # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt.
    if prompt.startswith(
        (
            "Mr. Rocky",
            "Parker",
            "What is the smallest positive",
            " The Myth",
            "Let $\\mathbf{a}$",
            "Find the arithmetic",
            "Determine an ordered pair",
            "Determine the ordered pair",
            "At the Quill and Scroll stationery",
            "Round to the nearest",
            r"Calculate $\sqrt{10p}",
            r"Simplify $\sqrt{28x}",
        )
    ):
        # Some spotted datasets errors where there is an annotation in the prompt: we remove it
        labels = labels[1:]

    # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
    # (less common) "?". `str.partition` splits on the *first* occurrence only, so
    # a prompt that contains the separator more than once cannot break the
    # unpacking (a two-value `str.split` unpacking would raise ValueError there).
    elif "Step 1:" in prompt:
        prompt, _, first_step = prompt.partition("Step 1:")
        first_step = "Step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "step 1:" in prompt:
        prompt, _, first_step = prompt.partition("step 1:")
        first_step = "step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "?" in prompt:
        prompt, _, first_step = prompt.partition("?")
        prompt = prompt + "?"
        completions = [first_step.strip()] + completions
    else:
        raise ValueError(f"Prompt can't be processed: {prompt}")

    # Strip the prompt
    prompt = prompt.strip()

    # Sanity check that the length of the completions is the same as the length of the labels
    if len(completions) != len(labels):
        raise ValueError("Number of completions and number of labels should match.")

    return {"prompt": prompt, "completions": completions, "labels": labels}
|
|
|
|
|
|
model_card = ModelCard("""
|
|
---
|
|
tags: [trl]
|
|
---
|
|
|
|
# Math-Shepherd Dataset
|
|
|
|
## Summary
|
|
|
|
The Math-Shepherd dataset is a processed version of [Math-Shepherd dataset](peiyi9979/Math-Shepherd), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It provides step-by-step solutions to mathematical problems, enabling models to learn and verify each step of a solution, thereby enhancing their reasoning capabilities.
|
|
|
|
## Data Structure
|
|
|
|
- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
|
|
- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)
|
|
|
|
Columns:
|
|
- `"prompt"`: The problem statement.
|
|
- `"completions"`: A list of reasoning steps generated to solve the problem.
|
|
- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.
|
|
|
|
This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.
|
|
|
|
## Generation script
|
|
|
|
The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/math_shepherd.py).
|
|
""")
|
|
|
|
if __name__ == "__main__":
    # Parse CLI flags into a ScriptArguments instance.
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # Download the raw Math-Shepherd dataset (only a "train" split is loaded here).
    dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")

    # Convert every row to the stepwise-supervision format; the raw columns are
    # dropped since process_example returns prompt/completions/labels instead.
    dataset = dataset.map(
        process_example,
        remove_columns=["input", "label", "task"],
        num_proc=script_args.dataset_num_proc,
    )
    # Carve out a 5% test split; fixed seed keeps the split reproducible.
    dataset = dataset.train_test_split(test_size=0.05, seed=42)

    if script_args.push_to_hub:
        # Upload the processed splits and the dataset card to the Hub repo.
        dataset.push_to_hub(script_args.repo_id)
        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
|