mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Warnings pointing to RFC (#4224)
This commit is contained in:
committed by
GitHub
parent
c38cb69ec7
commit
7a0a615d50
@ -16,6 +16,7 @@ import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from operator import itemgetter
|
||||
@ -360,6 +361,13 @@ class BCOTrainer(BaseTrainer):
|
||||
embedding_func: Optional[Callable] = None,
|
||||
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()):
|
||||
raise ImportError(
|
||||
"BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`."
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
@ -142,6 +144,13 @@ class CPOTrainer(BaseTrainer):
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
elif not isinstance(model, str):
|
||||
|
@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -38,11 +40,7 @@ from ..models import prepare_deepspeed
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
from .gkd_config import GKDConfig
|
||||
from .sft_trainer import SFTTrainer
|
||||
from .utils import (
|
||||
DataCollatorForChatML,
|
||||
disable_dropout_in_model,
|
||||
empty_cache,
|
||||
)
|
||||
from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
@ -127,6 +125,13 @@ class GKDTrainer(SFTTrainer):
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
formatting_func: Optional[Callable] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
# Ensure Trainer does not drop non-signature columns used by the collator (e.g., "prompts")
|
||||
args.remove_unused_columns = False
|
||||
# Respect a user-provided data_collator; otherwise, provide a ChatML collator that
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from operator import itemgetter
|
||||
@ -353,6 +355,13 @@ class KTOTrainer(BaseTrainer):
|
||||
model_adapter_name: Optional[str] = None,
|
||||
ref_adapter_name: Optional[str] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if type(args) is TrainingArguments:
|
||||
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
||||
|
||||
|
@ -206,6 +206,13 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
) -> None:
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if ref_model is model:
|
||||
raise ValueError(
|
||||
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
@ -144,6 +146,13 @@ class ORPOTrainer(BaseTrainer):
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
elif not isinstance(model, str):
|
||||
|
@ -17,6 +17,7 @@ import math
|
||||
import os
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
@ -159,6 +160,13 @@ class PPOTrainer(BaseTrainer):
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
) -> None:
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if ref_model is model:
|
||||
raise ValueError(
|
||||
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
||||
|
@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
import warnings
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
@ -117,6 +119,13 @@ class PRMTrainer(BaseTrainer):
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
|
||||
model = prepare_peft_model(model, peft_config, args)
|
||||
|
||||
|
@ -275,6 +275,13 @@ class RLOOTrainer(BaseTrainer):
|
||||
ref_policy=None,
|
||||
data_collator=None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
# Handle deprecated parameters
|
||||
if config is not None:
|
||||
warnings.warn(
|
||||
|
Reference in New Issue
Block a user