!213 修复trl依赖问题,rm训练性能问题

Merge pull request !213 from 幽若/master-521
This commit is contained in:
2025-05-21 13:39:57 +00:00
committed by i-robot
parent 3c1a3b0fcb
commit 8de5038f58
4 changed files with 10 additions and 7 deletions

View File

@@ -22,7 +22,7 @@ dependencies = ["tqdm",
"pyarrow == 16.1.0",
"openmind-hub >= 0.9.1",
"numpy < 2.0.0"]
-requires-python = ">= 3.8, < 3.11"
+requires-python = ">= 3.8, <= 3.11"
classifiers = [
"Development Status :: 1 - Planning",
"Intended Audience :: Developers",

View File

@@ -33,7 +33,7 @@ if is_torch_available():
else:
from mindformers.trainer.utils import get_last_checkpoint
-if is_torch_available() and is_transformers_available() and is_trl_available():
+if is_trl_available():
from trl.trainer.dpo_config import DPOConfig
from trl.trainer.reward_config import RewardConfig

View File

@@ -10,7 +10,7 @@
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
+import argparse
import random
from dataclasses import dataclass
from typing import List, Any, Optional, Union, Dict
@@ -51,6 +51,7 @@ class RewardDataCollatorWithPadding:
"""
tokenizer: PreTrainedTokenizerBase
+args: argparse.Namespace
padding: Union[bool, str] = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
@@ -61,10 +62,12 @@
margin = []
# check if we have a margin. If we do, we need to batch it as well
has_margin = "margin" in features[0]
+max_length = 0
+if self.args.max_length:
+max_length = self.args.max_length
+else:
+max_length = 1024
for feature in features:
# check if the keys are named as expected
+max_length = max(max_length, len(feature["input_ids_chosen"]), len(feature["input_ids_rejected"]))
keys_exist = (
"input_ids_chosen" in feature
and "input_ids_rejected" in feature
@@ -132,7 +135,7 @@ def run_rm(
train_args = args.reward_args
train_args.remove_unused_columns = False
-data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)
+data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer, args=args)
trainer = reward_trainer.RewardTrainer(
model=model,
args=train_args,

View File

@@ -139,7 +139,7 @@ def is_transformers_available():
@lru_cache
def is_trl_available():
-return _is_package_available("trl")
+return _is_package_available("trl") and _is_package_available("transformers") and is_torch_available()
@lru_cache