Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)
Use transformers utilities when possible (#2064)

* use transformers' availability functions
* require from transformers
* rm file
* fix no peft
* fix import
* don't alter _peft_available
* fix require_diffusers
* style
* transformers>=4.40 and add back `is_liger_kernel_available`
Commit 07f0e687cb (parent dc2bd07408), committed via GitHub.
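The recurring change in the diff below is mechanical: device- and package-availability helpers that TRL previously defined in `trl.import_utils` are replaced by the equivalents that `transformers` already ships. A minimal standalone sketch of the resulting pattern (not taken verbatim from any one file in the diff):

# Sketch of the pattern applied throughout this PR: availability checks now
# come from transformers instead of trl.import_utils.
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def empty_device_cache() -> None:
    # Pick the cache-clearing call for whichever accelerator backend is present.
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
    elif is_torch_npu_available():
        torch.npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()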
@@ -27,11 +27,12 @@ from transformers import (
     AutoTokenizer,
     BitsAndBytesConfig,
     HfArgumentParser,
+    is_torch_npu_available,
+    is_torch_xpu_available,
     set_seed,
 )

 from trl import SFTConfig, SFTTrainer
-from trl.import_utils import is_npu_available, is_xpu_available
 from trl.trainer import ConstantLengthDataset


@@ -197,9 +198,9 @@ trainer.model.save_pretrained(output_dir)

 # Free memory for merging weights
 del base_model
-if is_xpu_available():
+if is_torch_xpu_available():
     torch.xpu.empty_cache()
-elif is_npu_available():
+elif is_torch_npu_available():
     torch.npu.empty_cache()
 else:
     torch.cuda.empty_cache()
@@ -20,9 +20,7 @@ import numpy as np
 import torch
 from datasets import load_dataset
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from trl.import_utils import is_npu_available, is_xpu_available
+from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_npu_available, is_torch_xpu_available


 toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
@@ -66,9 +64,9 @@ BATCH_SIZE = args.batch_size
 output_file = args.output_file
 max_new_tokens = args.max_new_tokens
 context_length = args.context_length
-if is_xpu_available():
+if is_torch_xpu_available():
     device = torch.xpu.current_device()
-elif is_npu_available():
+elif is_torch_npu_available():
     device = torch.npu.current_device()
 else:
     device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
@@ -137,9 +135,9 @@ for model_id in tqdm(MODELS_TO_TEST):
     print(f"Model: {model_id} - Mean: {mean} - Std: {std}")

     model = None
-    if is_xpu_available():
+    if is_torch_xpu_available():
         torch.xpu.empty_cache()
-    elif is_npu_available():
+    elif is_torch_npu_available():
         torch.npu.empty_cache()
     else:
         torch.cuda.empty_cache()
@@ -33,10 +33,9 @@ import torch
 import torch.nn as nn
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
-from transformers import CLIPModel, CLIPProcessor, HfArgumentParser
+from transformers import CLIPModel, CLIPProcessor, HfArgumentParser, is_torch_npu_available, is_torch_xpu_available

 from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
-from trl.import_utils import is_npu_available, is_xpu_available


 @dataclass
@@ -116,9 +115,9 @@ def aesthetic_scorer(hub_model_id, model_filename):
         model_filename=model_filename,
         dtype=torch.float32,
     )
-    if is_npu_available():
+    if is_torch_npu_available():
         scorer = scorer.npu()
-    elif is_xpu_available():
+    elif is_torch_xpu_available():
         scorer = scorer.xpu()
     else:
         scorer = scorer.cuda()
@@ -24,11 +24,10 @@ from accelerate import Accelerator, PartialState
 from datasets import load_dataset
 from peft import LoraConfig
 from tqdm import tqdm
-from transformers import AutoTokenizer, HfArgumentParser, pipeline
+from transformers import AutoTokenizer, HfArgumentParser, is_torch_npu_available, is_torch_xpu_available, pipeline

 from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
 from trl.core import LengthSampler
-from trl.import_utils import is_npu_available, is_xpu_available


 tqdm.pandas()
@@ -142,9 +141,9 @@ ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=datase
 # to the same device as the PPOTrainer.
 device = ppo_trainer.accelerator.device
 if ppo_trainer.accelerator.num_processes == 1:
-    if is_xpu_available():
+    if is_torch_xpu_available():
         device = "xpu:0"
-    elif is_npu_available():
+    elif is_torch_npu_available():
         device = "npu:0"
     else:
         device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
@@ -19,11 +19,16 @@ from accelerate import PartialState
 from datasets import load_dataset
 from peft import LoraConfig
 from tqdm import tqdm
-from transformers import AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
+from transformers import (
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    HfArgumentParser,
+    is_torch_npu_available,
+    is_torch_xpu_available,
+)

 from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
 from trl.core import LengthSampler
-from trl.import_utils import is_npu_available, is_xpu_available


 input_min_text_length = 6
@@ -86,7 +91,7 @@ nf4_config = BitsAndBytesConfig(
 )
 model = AutoModelForCausalLMWithValueHead.from_pretrained(
     script_args.model_name,
-    device_map={"": "xpu:0"} if is_xpu_available() else {"": "npu:0"} if is_npu_available else {"": 0},
+    device_map={"": "xpu:0"} if is_torch_xpu_available() else {"": "npu:0"} if is_torch_npu_available else {"": 0},
     peft_config=lora_config,
     quantization_config=nf4_config,
     reward_adapter=script_args.rm_adapter,
@@ -1,7 +1,7 @@
 datasets>=1.17.0
 torch>=1.4.0
 tqdm
-transformers>=4.39.0
+transformers>=4.40.0
 accelerate
 peft>=0.3.0
 tyro>=0.5.7
setup.py (2 changed lines)
@@ -77,7 +77,7 @@ __version__ = "0.11.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc

 REQUIRED_PKGS = [
     "torch>=1.4.0",
-    "transformers>=4.39.0",
+    "transformers>=4.40.0",
     "numpy>=1.18.2;platform_system!='Windows'",
     "numpy<2;platform_system=='Windows'",
     "accelerate",
@@ -21,10 +21,11 @@ from accelerate.utils.memory import release_memory
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers.testing_utils import require_bitsandbytes, require_peft, require_torch_accelerator, torch_device
+from transformers.utils import is_peft_available

-from trl import DPOConfig, DPOTrainer, is_peft_available
+from trl import DPOConfig, DPOTrainer

-from ..testing_utils import require_bitsandbytes, require_non_cpu, require_peft, torch_device
 from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST


@@ -32,7 +33,7 @@ if is_peft_available():
     from peft import LoraConfig, PeftModel


-@require_non_cpu
+@require_torch_accelerator
 class DPOTrainerSlowTester(unittest.TestCase):
     def setUp(self):
         self.dataset = load_dataset("trl-internal-testing/mlabonne-chatml-dpo-pairs-copy", split="train[:10%]")
@@ -21,17 +21,18 @@ from accelerate.utils.memory import release_memory
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_peft,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
+)
+from transformers.utils import is_peft_available

-from trl import SFTConfig, SFTTrainer, is_peft_available
+from trl import SFTConfig, SFTTrainer
 from trl.models.utils import setup_chat_format

-from ..testing_utils import (
-    require_bitsandbytes,
-    require_liger_kernel,
-    require_multi_accelerator,
-    require_non_cpu,
-    require_peft,
-)
+from ..testing_utils import require_liger_kernel
 from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS


@@ -39,7 +40,7 @@ if is_peft_available():
     from peft import LoraConfig, PeftModel


-@require_non_cpu
+@require_torch_accelerator
 class SFTTrainerSlowTester(unittest.TestCase):
     def setUp(self):
         self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
@@ -270,7 +271,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
     @parameterized.expand(
         list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, DEVICE_MAP_OPTIONS))
     )
-    @require_multi_accelerator
+    @require_torch_multi_accelerator
     def test_sft_trainer_transformers_mp_gc_device_map(
         self, model_name, packing, gradient_checkpointing_kwargs, device_map
     ):
@@ -16,8 +16,9 @@ import unittest

 import torch
 from parameterized import parameterized
+from transformers.utils import is_peft_available

-from trl import is_diffusers_available, is_peft_available
+from trl import is_diffusers_available

 from .testing_utils import require_diffusers

@@ -20,11 +20,12 @@ from accelerate import Accelerator
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers.testing_utils import require_peft

 from trl import BCOConfig, BCOTrainer
 from trl.trainer.bco_trainer import _process_tokens, _tokenize

-from .testing_utils import require_no_wandb, require_peft
+from .testing_utils import require_no_wandb


 class BCOTrainerTester(unittest.TestCase):
@@ -18,11 +18,10 @@ import torch
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers.testing_utils import require_peft

 from trl import CPOConfig, CPOTrainer

-from .testing_utils import require_peft
-

 class CPOTrainerTester(unittest.TestCase):
     def setUp(self):
@@ -15,8 +15,9 @@ import gc
 import unittest

 import torch
+from transformers.utils import is_peft_available

-from trl import is_diffusers_available, is_peft_available
+from trl import is_diffusers_available

 from .testing_utils import require_diffusers

@@ -28,11 +28,12 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
 )
+from transformers.testing_utils import require_bitsandbytes, require_peft

 from trl import DPOConfig, DPOTrainer, FDivergenceType
 from trl.trainer.dpo_trainer import _build_tokenized_answer, _truncate_tokens

-from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft
+from .testing_utils import require_no_wandb


 class TestBuildTokenizedAnswer(unittest.TestCase):
@@ -18,11 +18,12 @@ import torch
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers.testing_utils import require_peft

 from trl import KTOConfig, KTOTrainer
 from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize

-from .testing_utils import require_no_wandb, require_peft
+from .testing_utils import require_no_wandb


 class KTOTrainerTester(unittest.TestCase):
@@ -18,9 +18,8 @@ from unittest.mock import patch

 import pytest
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from .testing_utils import is_peft_available, require_peft
+from transformers import AutoTokenizer
+from transformers.utils import import_utils


 class DummyDataset(torch.utils.data.Dataset):
@@ -70,27 +69,14 @@ EXPECTED_STATS = [
 ]


-@require_peft
 class TestPeftDependancy(unittest.TestCase):
     def setUp(self):
         self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
         self.seq_to_seq_model_id = "trl-internal-testing/tiny-random-T5ForConditionalGeneration"

-        if is_peft_available():
-            from peft import LoraConfig, get_peft_model
-
-            lora_config = LoraConfig(
-                r=16,
-                lora_alpha=32,
-                lora_dropout=0.05,
-                bias="none",
-                task_type="CAUSAL_LM",
-            )
-
-            causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
-            self.peft_model = get_peft_model(causal_lm_model, lora_config)
-
     def test_no_peft(self):
+        _peft_available = import_utils._peft_available
+        import_utils._peft_available = False  # required so that is_peft_available() returns False
         with patch.dict(sys.modules, {"peft": None}):
             from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead

@@ -100,8 +86,11 @@ class TestPeftDependancy(unittest.TestCase):

             _trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
             _trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id)
+        import_utils._peft_available = _peft_available

     def test_imports_no_peft(self):
+        _peft_available = import_utils._peft_available
+        import_utils._peft_available = False  # required so that is_peft_available() returns False
         with patch.dict(sys.modules, {"peft": None}):
             from trl import (  # noqa: F401
                 AutoModelForCausalLMWithValueHead,
@@ -110,8 +99,11 @@ class TestPeftDependancy(unittest.TestCase):
                 PPOTrainer,
                 PreTrainedModelWrapper,
             )
+        import_utils._peft_available = _peft_available

     def test_ppo_trainer_no_peft(self):
+        _peft_available = import_utils._peft_available
+        import_utils._peft_available = False  # required so that is_peft_available() returns False
         with patch.dict(sys.modules, {"peft": None}):
             from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

@@ -154,3 +146,4 @@ class TestPeftDependancy(unittest.TestCase):
             # check expected stats
             for stat in EXPECTED_STATS:
                 assert stat in train_stats
+            import_utils._peft_available = _peft_available
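The tests above simulate an uninstalled optional dependency by flipping the cached availability flag and blanking the module out of `sys.modules`. A hedged, self-contained sketch of the second half of that idea (the `mypkg` name is a placeholder, not something from the diff):

import sys
import unittest
from unittest.mock import patch


class MissingDependencySimulation(unittest.TestCase):
    def test_import_without_mypkg(self):
        # Mapping a module name to None in sys.modules makes importing it
        # raise ImportError inside the patched block.
        with patch.dict(sys.modules, {"mypkg": None}):
            with self.assertRaises(ImportError):
                import mypkg  # noqa: F401


if __name__ == "__main__":
    unittest.main()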
@@ -18,11 +18,10 @@ import torch
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers.testing_utils import require_peft

 from trl import ORPOConfig, ORPOTrainer

-from .testing_utils import require_peft
-

 class ORPOTrainerTester(unittest.TestCase):
     def setUp(self):
@@ -17,15 +17,15 @@ import unittest

 import torch
 from transformers import AutoModelForCausalLM
+from transformers.testing_utils import require_bitsandbytes, require_peft
+from transformers.utils import is_peft_available

-from trl import AutoModelForCausalLMWithValueHead, is_peft_available
+from trl import AutoModelForCausalLMWithValueHead


 if is_peft_available():
     from peft import LoraConfig, get_peft_model

-from .testing_utils import require_bitsandbytes, require_peft
-

 @require_peft
 class PeftModelTester(unittest.TestCase):
@@ -25,12 +25,12 @@ from huggingface_hub import HfApi
 from parameterized import parameterized
 from requests.exceptions import HTTPError
 from transformers import AutoTokenizer
+from transformers.testing_utils import require_peft, require_torch_multi_accelerator

 from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
 from trl.core import respond_to_batch

 from .testing_constants import CI_HUB_ENDPOINT, CI_HUB_USER
-from .testing_utils import require_multi_accelerator, require_peft


 EXPECTED_STATS = [
@@ -1038,7 +1038,7 @@ class PPOTrainerTester(unittest.TestCase):
         )

     @require_peft
-    @require_multi_accelerator
+    @require_torch_multi_accelerator
     def test_peft_model_ppo_trainer_multi_gpu(self):
         from peft import LoraConfig, get_peft_model
         from transformers import AutoModelForCausalLM
@@ -18,12 +18,11 @@ import pytest
 import torch
 from datasets import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
+from transformers.testing_utils import require_peft

 from trl import RewardConfig, RewardTrainer
 from trl.trainer import compute_accuracy

-from .testing_utils import require_peft
-

 class RewardTrainerTester(unittest.TestCase):
     def test_accuracy_metrics(self):
@@ -26,14 +26,14 @@ from transformers import (
     AutoTokenizer,
     LlavaForConditionalGeneration,
     TrainingArguments,
+    is_vision_available,
 )
+from transformers.testing_utils import require_peft, require_vision
+from transformers.utils import is_peft_available

 from trl import SFTConfig, SFTTrainer
-from trl.import_utils import is_peft_available, is_pil_available
 from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM

-from .testing_utils import require_peft, requires_pil


 def formatting_prompts_func(example):
     text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
@@ -51,7 +51,7 @@ def formatting_prompts_func_batched(example):
 if is_peft_available():
     from peft import LoraConfig, PeftModel

-if is_pil_available():
+if is_vision_available():
     from PIL import Image as PILImage


@@ -99,7 +99,7 @@ class SFTTrainerTester(unittest.TestCase):
             "trl-internal-testing/zen", "standard_prompt_completion"
         )

-        if is_pil_available():
+        if is_vision_available():
             self.dummy_vsft_instruction_dataset = Dataset.from_dict(
                 {
                     "messages": [
@@ -1159,7 +1159,7 @@ class SFTTrainerTester(unittest.TestCase):
         assert len(trainer.train_dataset["input_ids"]) == len(self.conversational_lm_dataset["train"])
         assert len(trainer.eval_dataset["input_ids"]) == len(self.conversational_lm_dataset["test"])

-    @requires_pil
+    @require_vision
     def test_sft_trainer_skip_prepare_dataset(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = SFTConfig(
@@ -1210,7 +1210,7 @@ class SFTTrainerTester(unittest.TestCase):
             )
             assert trainer.train_dataset.features == self.dummy_dataset.features

-    @requires_pil
+    @require_vision
     def test_sft_trainer_llava(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = SFTConfig(
@@ -16,8 +16,9 @@ import unittest

 import torch
 from transformers import AutoTokenizer
+from transformers.testing_utils import require_peft
+from transformers.utils import is_peft_available

-from trl import is_peft_available
 from trl.trainer.model_config import ModelConfig
 from trl.trainer.utils import decode_and_strip_padding, get_peft_config, pad

@@ -25,8 +26,6 @@ from trl.trainer.utils import decode_and_strip_padding, get_peft_config, pad
 if is_peft_available():
     from peft import LoraConfig

-from .testing_utils import require_peft
-

 class TestPad(unittest.TestCase):
     def test_pad_1_dim_left(self):
@@ -13,127 +13,27 @@
 # limitations under the License.
 import unittest

-import torch
-from accelerate.test_utils.testing import get_backend
+from transformers import is_wandb_available

-from trl import (
-    is_bitsandbytes_available,
-    is_diffusers_available,
-    is_liger_available,
-    is_peft_available,
-    is_pil_available,
-    is_wandb_available,
-    is_xpu_available,
-)
-
-
-torch_device, device_count, memory_allocated_func = get_backend()
-
-
-def require_peft(test_case):
-    """
-    Decorator marking a test that requires peft. Skips the test if peft is not available.
-    """
-    if not is_peft_available():
-        test_case = unittest.skip("test requires peft")(test_case)
-    return test_case
-
-
-def require_bitsandbytes(test_case):
-    """
-    Decorator marking a test that requires bnb. Skips the test if bnb is not available.
-    """
-    if not is_bitsandbytes_available():
-        test_case = unittest.skip("test requires bnb")(test_case)
-    return test_case
+from trl import is_diffusers_available, is_liger_kernel_available


 def require_diffusers(test_case):
     """
     Decorator marking a test that requires diffusers. Skips the test if diffusers is not available.
     """
-    if not is_diffusers_available():
-        test_case = unittest.skip("test requires diffusers")(test_case)
-    return test_case
-
-
-def requires_pil(test_case):
-    """
-    Decorator marking a test that requires PIL. Skips the test if pil is not available.
-    """
-    if not is_pil_available():
-        test_case = unittest.skip("test requires PIL")(test_case)
-    return test_case
-
-
-def require_wandb(test_case, required: bool = True):
-    """
-    Decorator marking a test that requires wandb. Skips the test if wandb is not available.
-    """
-    # XOR, i.e.:
-    # skip if available and required = False and
-    # skip if not available and required = True
-    if is_wandb_available() ^ required:
-        test_case = unittest.skip("test requires wandb")(test_case)
-    return test_case
+    return unittest.skipUnless(is_diffusers_available(), "test requires diffusers")(test_case)


 def require_no_wandb(test_case):
     """
     Decorator marking a test that requires no wandb. Skips the test if wandb is available.
     """
-    return require_wandb(test_case, required=False)
-
-
-def require_torch_multi_gpu(test_case):
-    """
-    Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs.
-    """
-    if torch.cuda.device_count() < 2:
-        test_case = unittest.skip("test requires multiple GPUs")(test_case)
-    return test_case
-
-
-def require_torch_gpu(test_case):
-    """
-    Decorator marking a test that requires GPUs. Skips the test if there is no GPU.
-    """
-    if not torch.cuda.is_available():
-        test_case = unittest.skip("test requires GPU")(test_case)
-    return test_case
-
-
-def require_torch_multi_xpu(test_case):
-    """
-    Decorator marking a test that requires multiple XPUs. Skips the test if there aren't enough XPUs.
-    """
-    if torch.xpu.device_count() < 2 and is_xpu_available():
-        test_case = unittest.skip("test requires multiple XPUs")(test_case)
-    return test_case
+    return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case)


 def require_liger_kernel(test_case):
     """
-    Decorator marking a test that requires liger-kernel. Also skip the test if there is no GPU.
+    Decorator marking a test that requires liger_kernel. Skips the test if liger_kernel is not available.
     """
-    if not (torch.cuda.is_available() and is_liger_available()):
-        test_case = unittest.skip("test requires GPU and liger-kernel")(test_case)
-    return test_case
-
-
-def require_non_cpu(test_case):
-    """
-    Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
-    hardware accelerator available.
-    """
-    return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
-
-
-def require_multi_accelerator(test_case):
-    """
-    Decorator marking a test that requires multiple hardware accelerators. These tests are skipped on a machine without
-    multiple accelerators.
-    """
-    return unittest.skipUnless(
-        torch_device != "cpu" and device_count > 1, "test requires multiple hardware accelerators"
-    )(test_case)
+    return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
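The rewritten helpers above all reduce to the same one-liner: wrap `unittest.skipUnless` around an availability predicate. A hedged sketch of that decorator shape, using a made-up `foo` package rather than anything defined in trl:

import importlib.util
import unittest


def is_foo_available() -> bool:
    # Placeholder predicate: true when the hypothetical "foo" package is importable.
    return importlib.util.find_spec("foo") is not None


def require_foo(test_case):
    """Skip the decorated test when "foo" is not installed."""
    return unittest.skipUnless(is_foo_available(), "test requires foo")(test_case)


@require_foo
class FooTests(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)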
@@ -32,16 +32,9 @@ _import_structure = {
         "BestOfNSampler",
     ],
     "import_utils": [
-        "is_bitsandbytes_available",
         "is_diffusers_available",
-        "is_npu_available",
-        "is_peft_available",
-        "is_pil_available",
-        "is_wandb_available",
-        "is_xpu_available",
+        "is_liger_kernel_available",
         "is_llmblender_available",
-        "is_openai_available",
-        "is_liger_available",
     ],
     "models": [
         "AutoModelForCausalLMWithValueHead",
@@ -138,18 +131,7 @@ if TYPE_CHECKING:
     from .core import set_seed
     from .environment import TextEnvironment, TextHistory
     from .extras import BestOfNSampler
-    from .import_utils import (
-        is_bitsandbytes_available,
-        is_diffusers_available,
-        is_npu_available,
-        is_peft_available,
-        is_pil_available,
-        is_wandb_available,
-        is_xpu_available,
-        is_llmblender_available,
-        is_openai_available,
-        is_liger_available,
-    )
+    from .import_utils import is_diffusers_available, is_liger_kernel_available, is_llmblender_available
     from .models import (
         AutoModelForCausalLMWithValueHead,
         AutoModelForSeq2SeqLMWithValueHead,
trl/core.py (12 changed lines)
@@ -22,9 +22,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
-from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
-
-from .import_utils import is_npu_available, is_xpu_available
+from transformers import TopKLogitsWarper, TopPLogitsWarper, is_torch_npu_available, is_torch_xpu_available


 try:
@@ -230,9 +228,9 @@ def set_seed(seed: int) -> None:
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    if is_xpu_available():
+    if is_torch_xpu_available():
         torch.xpu.manual_seed_all(seed)
-    elif is_npu_available():
+    elif is_torch_npu_available():
         torch.npu.manual_seed_all(seed)
     else:
         torch.cuda.manual_seed_all(seed)
@@ -258,11 +256,11 @@ class PPODecorators:
     def empty_device_cache(cls):
         yield
         if cls.optimize_device_cache:
-            if is_xpu_available():
+            if is_torch_xpu_available():
                 gc.collect()
                 torch.xpu.empty_cache()
                 gc.collect()
-            elif is_npu_available():
+            elif is_torch_npu_available():
                 gc.collect()
                 torch.npu.empty_cache()
                 gc.collect()
@@ -14,28 +14,43 @@
 import importlib
 import os
 import sys
-from importlib.util import find_spec
 from itertools import chain
 from types import ModuleType
 from typing import Any

+from transformers.utils.import_utils import _is_package_available


 if sys.version_info < (3, 8):
     _is_python_greater_3_8 = False
 else:
     _is_python_greater_3_8 = True


-def is_peft_available() -> bool:
-    return find_spec("peft") is not None
+_diffusers_available = _is_package_available("diffusers")
+_unsloth_available = _is_package_available("unsloth")
+_rich_available = _is_package_available("rich")
+_liger_kernel_available = _is_package_available("liger_kernel")
+_llmblender_available = _is_package_available("llm_blender")


-def is_liger_available() -> bool:
-    return find_spec("liger_kernel") is not None
+def is_diffusers_available() -> bool:
+    return _diffusers_available


 def is_unsloth_available() -> bool:
-    return find_spec("unsloth") is not None
+    return _unsloth_available


+def is_rich_available() -> bool:
+    return _rich_available
+
+
+def is_liger_kernel_available() -> bool: # replace by transformers.import_utils.is_liger_kernel_available() from v4.45
+    return _liger_kernel_available
+
+
+def is_llmblender_available() -> bool:
+    return _llmblender_available
+
+
 def is_accelerate_greater_20_0() -> bool:
@@ -74,72 +89,6 @@ def is_torch_greater_2_0() -> bool:
     return torch_version >= "2.0"


-def is_diffusers_available() -> bool:
-    return find_spec("diffusers") is not None
-
-
-def is_pil_available() -> bool:
-    return find_spec("PIL") is not None
-
-
-def is_bitsandbytes_available() -> bool:
-    import torch
-
-    # bnb can be imported without GPU but is not usable.
-    return find_spec("bitsandbytes") is not None and torch.cuda.is_available()
-
-
-def is_torchvision_available() -> bool:
-    return find_spec("torchvision") is not None
-
-
-def is_rich_available() -> bool:
-    return find_spec("rich") is not None
-
-
-def is_wandb_available() -> bool:
-    return find_spec("wandb") is not None
-
-
-def is_sklearn_available() -> bool:
-    return find_spec("sklearn") is not None
-
-
-def is_llmblender_available() -> bool:
-    return find_spec("llm_blender") is not None
-
-
-def is_openai_available() -> bool:
-    return find_spec("openai") is not None
-
-
-def is_xpu_available() -> bool:
-    if is_accelerate_greater_20_0():
-        import accelerate
-
-        return accelerate.utils.is_xpu_available()
-    else:
-        if find_spec("intel_extension_for_pytorch") is None:
-            return False
-        try:
-            import torch
-
-            return hasattr(torch, "xpu") and torch.xpu.is_available()
-        except RuntimeError:
-            return False
-
-
-def is_npu_available() -> bool:
-    """Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
-    if find_spec("torch") is None or find_spec("torch_npu") is None:
-        return False
-
-    import torch
-    import torch_npu  # noqa: F401
-
-    return hasattr(torch, "npu") and torch.npu.is_available()
-
-
 class _LazyModule(ModuleType):
     """
     Module class that surfaces all objects but only performs associated imports when the objects are requested.
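The new trl/import_utils.py above probes each optional dependency once at import time with transformers' `_is_package_available` helper and exposes the cached result through a tiny accessor. A hedged sketch of extending that pattern for one more, hypothetical package named `somepkg`:

# Assumes the same private helper the diff imports; it reports whether the
# distribution is installed without actually importing it.
from transformers.utils.import_utils import _is_package_available

# Probe once at import time and cache the result in a module-level flag.
_somepkg_available = _is_package_available("somepkg")


def is_somepkg_available() -> bool:
    return _somepkg_available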
@@ -18,9 +18,7 @@ import torch.nn as nn
 import torchvision
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
-from transformers import CLIPModel
-
-from trl.import_utils import is_npu_available, is_xpu_available
+from transformers import CLIPModel, is_torch_npu_available, is_torch_xpu_available


 class MLP(nn.Module):
@@ -82,9 +80,9 @@ def aesthetic_scorer(hub_model_id, model_filename):
         model_filename=model_filename,
         dtype=torch.float32,
     )
-    if is_npu_available():
+    if is_torch_npu_available():
         scorer = scorer.npu()
-    elif is_xpu_available():
+    elif is_torch_xpu_available():
         scorer = scorer.xpu()
     else:
         scorer = scorer.cuda()
@@ -28,9 +28,10 @@ from huggingface_hub.utils import (
     RepositoryNotFoundError,
 )
 from safetensors.torch import load_file as safe_load_file
-from transformers import GenerationMixin, PreTrainedModel
+from transformers import GenerationMixin, PreTrainedModel, is_torch_npu_available, is_torch_xpu_available
+from transformers.utils import is_peft_available

-from ..import_utils import is_npu_available, is_peft_available, is_transformers_greater_than, is_xpu_available
+from ..import_utils import is_transformers_greater_than


 if is_peft_available():
@@ -403,9 +404,9 @@ class PreTrainedModelWrapper(nn.Module):
             The current device.
         """
         state = PartialState()
-        if is_xpu_available():
+        if is_torch_xpu_available():
             return f"xpu:{state.local_process_index}"
-        elif is_npu_available():
+        elif is_torch_npu_available():
             return f"npu:{state.local_process_index}"
         else:
             return state.local_process_index if torch.cuda.is_available() else "cpu"
@@ -24,9 +24,9 @@ import torch
 import torch.utils.checkpoint as checkpoint
 from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
+from transformers.utils import is_peft_available

 from ..core import randn_tensor
-from ..import_utils import is_peft_available
 from .sd_utils import convert_state_dict_to_diffusers

@@ -13,9 +13,8 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, is_torch_npu_available, is_torch_xpu_available

-from ..import_utils import is_npu_available, is_xpu_available
 from .modeling_base import PreTrainedModelWrapper


@@ -251,9 +250,9 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):

             first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]
             if isinstance(first_device, int):
-                if is_npu_available():
+                if is_torch_npu_available():
                     first_device = f"npu:{first_device}"
-                elif is_xpu_available():
+                elif is_torch_xpu_available():
                     first_device = f"xpu:{first_device}"
                 else:
                     first_device = f"cuda:{first_device}"
@@ -18,8 +18,9 @@ import warnings
 from dataclasses import dataclass, field
 from typing import Any, Dict, Literal, Optional, Tuple

+from transformers import is_bitsandbytes_available, is_torchvision_available
+
 from ..core import flatten_dict
-from ..import_utils import is_bitsandbytes_available, is_torchvision_available


 @dataclass
@@ -39,11 +39,13 @@ from transformers import (
     PreTrainedTokenizerBase,
     Trainer,
     TrainingArguments,
+    is_sklearn_available,
+    is_wandb_available,
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput, has_length
+from transformers.utils import is_peft_available

-from ..import_utils import is_peft_available, is_sklearn_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
 from .bco_config import BCOConfig
 from .utils import (
@@ -29,12 +29,18 @@ import torch.nn.functional as F
 from accelerate import PartialState
 from datasets import Dataset
 from torch.utils.data import DataLoader
-from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
+from transformers import (
+    AutoModelForCausalLM,
+    DataCollator,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    Trainer,
+    is_wandb_available,
+)
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_torch_fx_proxy
+from transformers.utils import is_peft_available, is_torch_fx_proxy

-from ..import_utils import is_peft_available, is_wandb_available
 from .cpo_config import CPOConfig
 from .utils import (
     DPODataCollatorWithPadding,
@@ -18,8 +18,9 @@ import warnings
 from dataclasses import dataclass, field
 from typing import Literal, Optional

+from transformers import is_bitsandbytes_available, is_torchvision_available
+
 from ..core import flatten_dict
-from ..import_utils import is_bitsandbytes_available, is_torchvision_available


 @dataclass
@@ -36,12 +36,13 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase,
     Trainer,
+    is_wandb_available,
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
+from transformers.utils import is_peft_available

-from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
 from .callbacks import SyncRefModelCallback
 from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from accelerate.utils import is_deepspeed_available
 from transformers import AutoModelForCausalLM, GenerationConfig, PreTrainedModel

-from ..import_utils import is_liger_available
+from ..import_utils import is_liger_kernel_available
 from ..models import PreTrainedModelWrapper
 from ..models.utils import unwrap_model_for_generation
 from .gkd_config import GKDConfig
@@ -33,7 +33,7 @@ from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache
 if is_deepspeed_available():
     import deepspeed

-if is_liger_available():
+if is_liger_kernel_available():
     from liger_kernel.transformers import AutoLigerKernelForCausalLM

@@ -28,9 +28,9 @@ from transformers import (
     TrainingArguments,
 )
 from transformers.trainer_utils import EvalLoopOutput
+from transformers.utils import is_peft_available

 from ..core import PPODecorators
-from ..import_utils import is_peft_available
 from .utils import trl_sanitze_kwargs_for_tagging

@@ -21,8 +21,9 @@ from typing import List, Optional, Union
 import numpy as np
 from accelerate import Accelerator
 from huggingface_hub import InferenceClient
+from transformers.utils import is_openai_available

-from ..import_utils import is_llmblender_available, is_openai_available
+from ..import_utils import is_llmblender_available


 if is_llmblender_available():
@@ -38,11 +38,12 @@ from transformers import (
     PreTrainedTokenizerBase,
     Trainer,
     TrainingArguments,
+    is_wandb_available,
 )
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput, has_length
+from transformers.utils import is_peft_available

-from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
 from .kto_config import KTOConfig
 from .utils import (
@@ -25,17 +25,19 @@ from accelerate import PartialState
 from datasets import Dataset
 from packaging import version
 from torch.utils.data import DataLoader, IterableDataset
-from transformers import DataCollator, GenerationConfig, PreTrainedTokenizerBase, Trainer, TrainerCallback
+from transformers import (
+    DataCollator,
+    GenerationConfig,
+    PreTrainedTokenizerBase,
+    Trainer,
+    TrainerCallback,
+    is_apex_available,
+)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.trainer_utils import EvalPrediction, seed_worker
 from transformers.training_args import OptimizerNames
-from transformers.utils import (
-    is_apex_available,
-    is_sagemaker_mp_enabled,
-    logging,
-)
+from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

-from ..import_utils import is_peft_available
 from ..models import create_reference_model
 from ..models.utils import unwrap_model_for_generation
 from .judges import BasePairwiseJudge
@@ -32,12 +32,19 @@ from accelerate import PartialState
 from accelerate.utils import is_deepspeed_available
 from datasets import Dataset
 from torch.utils.data import DataLoader
-from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
+from transformers import (
+    AutoModelForCausalLM,
+    DataCollator,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    Trainer,
+    is_torch_xla_available,
+    is_wandb_available,
+)
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
-from transformers.utils import is_torch_fx_proxy, is_torch_xla_available
+from transformers.utils import is_peft_available, is_torch_fx_proxy

-from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper
 from .orpo_config import ORPOConfig
 from .utils import (
@@ -20,12 +20,12 @@ from typing import Literal, Optional

 import numpy as np
 import tyro
+from transformers import is_wandb_available
 from typing_extensions import Annotated

 from trl.trainer.utils import exact_div

 from ..core import flatten_dict
-from ..import_utils import is_wandb_available


 JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)]
@@ -35,6 +35,8 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
+    is_torch_npu_available,
+    is_torch_xpu_available,
 )

 from ..core import (
@@ -52,7 +54,7 @@ from ..core import (
     stack_dicts,
     stats_to_np,
 )
-from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
+from ..import_utils import is_torch_greater_2_0
 from ..models import (
     SUPPORTED_ARCHITECTURES,
     PreTrainedModelWrapper,
@@ -379,9 +381,9 @@ class PPOTrainer(BaseTrainer):
         if not getattr(self.model, "is_sequential_parallel", False):
             self.current_device = self.accelerator.device
         else:
-            if is_xpu_available():
+            if is_torch_xpu_available():
                 self.current_device = torch.device("xpu:0")
-            elif is_npu_available():
+            elif is_torch_npu_available():
                 self.current_device = torch.device("npu:0")
             else:
                 self.current_device = torch.device("cuda:0")
@@ -27,8 +27,8 @@ from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase,
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_pt_utils import nested_detach
 from transformers.trainer_utils import EvalPrediction
+from transformers.utils import is_peft_available

-from ..import_utils import is_peft_available
 from .reward_config import RewardConfig
 from .utils import (
     RewardDataCollatorWithPadding,
@@ -37,9 +37,10 @@ from transformers import (
 from transformers.modeling_utils import unwrap_model
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalPrediction
+from transformers.utils import is_peft_available

 from ..extras.dataset_formatting import get_formatting_func_from_dataset
-from ..import_utils import is_liger_available, is_peft_available
+from ..import_utils import is_liger_kernel_available
 from .sft_config import SFTConfig
 from .utils import (
     ConstantLengthDataset,
@@ -53,7 +54,7 @@ from .utils import (
 if is_peft_available():
     from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

-if is_liger_available():
+if is_liger_kernel_available():
     from liger_kernel.transformers import AutoLigerKernelForCausalLM

@@ -37,12 +37,13 @@ from transformers import (
     TrainingArguments,
 )
 from transformers.utils import (
+    is_peft_available,
     is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_xpu_available,
 )

-from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available
+from ..import_utils import is_unsloth_available
 from ..trainer.model_config import ModelConfig


@@ -902,7 +903,7 @@ def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesC


 def get_kbit_device_map() -> Optional[Dict[str, int]]:
-    if is_xpu_available():
+    if is_torch_xpu_available():
         return {"": f"xpu:{PartialState().local_process_index}"}
     elif torch.cuda.is_available():
         return {"": PartialState().local_process_index}
@@ -18,11 +18,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from datasets import Dataset, IterableDataset
-from transformers import PreTrainedTokenizerBase, TrainerCallback
+from transformers import PreTrainedTokenizerBase, TrainerCallback, is_apex_available
 from transformers.modeling_utils import PreTrainedModel
 from transformers.trainer_utils import EvalPrediction
 from transformers.training_args import OptimizerNames
-from transformers.utils import is_apex_available

 from ..models.utils import unwrap_model_for_generation
 from .online_dpo_trainer import OnlineDPOTrainer