@@ -22,7 +22,7 @@ dependencies = ["tqdm",
     "pyarrow == 16.1.0",
     "openmind-hub >= 0.9.1",
     "numpy < 2.0.0"]
-requires-python = ">= 3.8, < 3.11"
+requires-python = ">= 3.8, <= 3.11"
 classifiers = [
     "Development Status :: 1 - Planning",
     "Intended Audience :: Developers",
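For reference, a minimal sketch (not part of the change) of how the old and new requires-python specifiers behave under PEP 440, using the third-party packaging library. Note that "<= 3.11" admits 3.11.0 but still excludes later 3.11.x patch releases.

# Sketch only: compares the old and new requires-python specifiers.
# Requires the third-party "packaging" library (pip install packaging).
from packaging.specifiers import SpecifierSet

old_spec = SpecifierSet(">= 3.8, < 3.11")
new_spec = SpecifierSet(">= 3.8, <= 3.11")

for version in ("3.8", "3.10.14", "3.11", "3.11.4"):
    print(f"{version}: old={version in old_spec} new={version in new_spec}")

# Expected output:
# 3.8: old=True new=True
# 3.10.14: old=True new=True
# 3.11: old=False new=True
# 3.11.4: old=False new=False   (patch releases of 3.11 are still excluded)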
@@ -33,7 +33,7 @@ if is_torch_available():
 else:
     from mindformers.trainer.utils import get_last_checkpoint

-if is_torch_available() and is_transformers_available() and is_trl_available():
+if is_trl_available():
     from trl.trainer.dpo_config import DPOConfig
     from trl.trainer.reward_config import RewardConfig
@@ -10,7 +10,7 @@
 # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 # See the Mulan PSL v2 for more details.

 import argparse
 import random
 from dataclasses import dataclass
 from typing import List, Any, Optional, Union, Dict
@@ -51,6 +51,7 @@ class RewardDataCollatorWithPadding:
     """

     tokenizer: PreTrainedTokenizerBase
+    args: argparse.Namespace
     padding: Union[bool, str] = True
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"
@@ -61,10 +62,12 @@ class RewardDataCollatorWithPadding:
         margin = []
         # check if we have a margin. If we do, we need to batch it as well
         has_margin = "margin" in features[0]
+        max_length = 0
+        if self.args.max_length:
+            max_length = self.args.max_length
+        else:
+            max_length = 1024
         for feature in features:
             # check if the keys are named as expected
+            max_length = max(max_length, len(feature["input_ids_chosen"]), len(feature["input_ids_rejected"]))
             keys_exist = (
                 "input_ids_chosen" in feature
                 and "input_ids_rejected" in feature
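As a reading aid, here is a standalone sketch of the padding-length rule the collator applies after this change (the helper name resolve_max_length is illustrative, not from the project): the configured args.max_length wins when set, otherwise 1024, and the value then grows to fit the longest chosen/rejected sequence in the batch.

# Illustrative helper, not part of the diff: mirrors the new max_length logic.
from typing import Dict, List, Optional


def resolve_max_length(features: List[Dict[str, list]],
                       configured_max_length: Optional[int] = None) -> int:
    # Fall back to 1024 when no max_length was configured (or it is 0/None).
    max_length = configured_max_length if configured_max_length else 1024
    for feature in features:
        # Grow to fit the longest chosen/rejected sequence in the batch.
        max_length = max(
            max_length,
            len(feature["input_ids_chosen"]),
            len(feature["input_ids_rejected"]),
        )
    return max_length


# Example: configured limit of 512, but one rejected sequence has 600 tokens.
batch = [
    {"input_ids_chosen": list(range(300)), "input_ids_rejected": list(range(600))},
]
print(resolve_max_length(batch, configured_max_length=512))  # -> 600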
@@ -132,7 +135,7 @@ def run_rm(
     train_args = args.reward_args
     train_args.remove_unused_columns = False

-    data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)
+    data_collator = RewardDataCollatorWithPadding(tokenizer=tokenizer, args=args)
     trainer = reward_trainer.RewardTrainer(
         model=model,
         args=train_args,
@@ -139,7 +139,7 @@ def is_transformers_available():

 @lru_cache
 def is_trl_available():
-    return _is_package_available("trl")
+    return _is_package_available("trl") and _is_package_available("transformers") and is_torch_available()


 @lru_cache
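A minimal sketch of what the strengthened check amounts to, assuming _is_package_available is a thin wrapper over importlib.util.find_spec (its actual implementation is not shown in this diff), and with a plain package check standing in for the project's own is_torch_available(). Folding transformers and torch into is_trl_available() is what lets the import guard above shrink to a single call.

# Sketch only; assumes _is_package_available() wraps importlib.util.find_spec.
import importlib.util
from functools import lru_cache


def _is_package_available(name: str) -> bool:
    # True if the package can be imported in the current environment.
    return importlib.util.find_spec(name) is not None


@lru_cache
def is_trl_available() -> bool:
    # trl itself imports transformers and torch, so checking all three up
    # front keeps downstream guards to a single call. The real code uses the
    # project's is_torch_available() for the last term; a package check
    # stands in for it here.
    return (
        _is_package_available("trl")
        and _is_package_available("transformers")
        and _is_package_available("torch")
    )


print(is_trl_available())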