mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
4c71daf461 | |||
c1e9ea8ecf | |||
f66282424f | |||
14ef1aba15 |
@ -256,3 +256,30 @@ That's how `make test` is implemented (without the `pip install` line)!
|
||||
|
||||
You can specify a smaller set of tests to test only the feature
|
||||
you're working on.
|
||||
|
||||
### Deprecation and Backward Compatibility
|
||||
|
||||
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
|
||||
|
||||
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
|
||||
|
||||
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
|
||||
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
warnings.warn(
|
||||
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
|
||||
"Please use the `Trainer.bar` method instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
```
|
||||
|
||||
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
|
||||
|
||||
- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
|
||||
|
||||
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling removal around **5 minor releases** after the initial deprecation warning.
|
||||
|
||||
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
|
||||
|
@ -1,5 +1,11 @@
|
||||
# Judges
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
TRL Judges is an experimental API which is subject to change at any time.
|
||||
|
||||
</Tip>
|
||||
|
||||
TRL provides judges to easily compare two completions.
|
||||
|
||||
Make sure to have installed the required dependencies by running:
|
||||
|
4
setup.py
4
setup.py
@ -73,13 +73,13 @@ import os
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
__version__ = "0.12.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
__version__ = "0.12.2" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
|
||||
REQUIRED_PKGS = [
|
||||
"accelerate>=0.34.0",
|
||||
"datasets>=2.21.0",
|
||||
"rich", # rich shouldn't be a required package for trl, we should remove it from here
|
||||
"transformers>=4.46.0",
|
||||
"transformers<4.47.0",
|
||||
]
|
||||
EXTRAS = {
|
||||
# Windows is only partially supported by DeepSpeed, see https://github.com/microsoft/DeepSpeed/tree/master#windows
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "0.12.0.dev0"
|
||||
__version__ = "0.12.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -48,6 +48,7 @@ from transformers import (
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import EvalLoopOutput, has_length
|
||||
from transformers.utils import is_peft_available
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template
|
||||
from ..models import PreTrainedModelWrapper, create_reference_model
|
||||
@ -317,6 +318,7 @@ class BCOTrainer(Trainer):
|
||||
|
||||
_tag_names = ["trl", "bco"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module, str] = None,
|
||||
|
@ -44,6 +44,7 @@ from transformers import (
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_peft_available, is_torch_fx_proxy
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
|
||||
from .cpo_config import CPOConfig
|
||||
@ -103,6 +104,7 @@ class CPOTrainer(Trainer):
|
||||
|
||||
_tag_names = ["trl", "cpo"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
|
@ -213,7 +213,7 @@ class DPOTrainer(Trainer):
|
||||
],
|
||||
custom_message="Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.",
|
||||
)
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
|
@ -37,6 +37,7 @@ from transformers import (
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from transformers.utils import is_liger_kernel_available, is_peft_available
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..models import PreTrainedModelWrapper
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
@ -61,6 +62,7 @@ if is_wandb_available():
|
||||
class GKDTrainer(SFTTrainer):
|
||||
_tag_names = ["trl", "gkd"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
|
@ -33,6 +33,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_peft_available
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..core import PPODecorators
|
||||
from .utils import generate_model_card
|
||||
@ -80,6 +81,7 @@ class IterativeSFTTrainer(Trainer):
|
||||
|
||||
_tag_names = ["trl", "iterative-sft"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[PreTrainedModel] = None,
|
||||
|
@ -47,6 +47,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import EvalLoopOutput, has_length
|
||||
from transformers.utils import is_peft_available
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
|
||||
from ..models import PreTrainedModelWrapper, create_reference_model
|
||||
@ -312,6 +313,7 @@ class KTOTrainer(Trainer):
|
||||
|
||||
_tag_names = ["trl", "kto"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module, str] = None,
|
||||
|
@ -33,6 +33,7 @@ from transformers import (
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import is_apex_available
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import is_conversational, maybe_apply_chat_template
|
||||
from ..models.modeling_base import GeometricMixtureWrapper
|
||||
@ -93,6 +94,7 @@ class NashMDTrainer(OnlineDPOTrainer):
|
||||
|
||||
_tag_names = ["trl", "nash-md"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
|
@ -44,6 +44,7 @@ from transformers import (
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, seed_worker
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||
from ..models import create_reference_model
|
||||
@ -125,6 +126,7 @@ class OnlineDPOTrainer(Trainer):
|
||||
|
||||
_tag_names = ["trl", "online-dpo"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
|
@ -48,6 +48,7 @@ from transformers import (
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_peft_available, is_torch_fx_proxy
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
|
||||
from ..models import PreTrainedModelWrapper
|
||||
@ -114,6 +115,7 @@ class ORPOTrainer(Trainer):
|
||||
|
||||
_tag_names = ["trl", "orpo"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
|
@ -44,6 +44,7 @@ from transformers import (
|
||||
from transformers.integrations import get_reporting_integration_callbacks
|
||||
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
|
||||
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..core import masked_mean, masked_whiten
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
@ -90,6 +91,7 @@ class PolicyAndValueWrapper(nn.Module):
|
||||
class PPOTrainer(Trainer):
|
||||
_tag_names = ["trl", "ppo"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
config: PPOConfig,
|
||||
|
@ -39,6 +39,7 @@ from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_pt_utils import nested_detach
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from transformers.utils import is_peft_available
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template
|
||||
from .reward_config import RewardConfig
|
||||
@ -80,6 +81,7 @@ def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase")
|
||||
class RewardTrainer(Trainer):
|
||||
_tag_names = ["trl", "reward-trainer"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
|
@ -44,6 +44,7 @@ from transformers import (
|
||||
from transformers.integrations import get_reporting_integration_callbacks
|
||||
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
|
||||
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
from ..trainer.utils import (
|
||||
@ -71,6 +72,7 @@ INVALID_LOGPROB = 1.0
|
||||
class RLOOTrainer(Trainer):
|
||||
_tag_names = ["trl", "rloo"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
config: RLOOConfig,
|
||||
|
@ -124,7 +124,7 @@ class SFTTrainer(Trainer):
|
||||
],
|
||||
custom_message="Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.",
|
||||
)
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.16.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
|
@ -33,6 +33,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from ..data_utils import is_conversational, maybe_apply_chat_template
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
@ -92,6 +93,7 @@ class XPOTrainer(OnlineDPOTrainer):
|
||||
|
||||
_tag_names = ["trl", "xpo"]
|
||||
|
||||
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.13.0", raise_if_both_names=True)
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
|
Reference in New Issue
Block a user