Compare commits

...

2 Commits

2 changed files with 25 additions and 1 deletion


@@ -436,8 +436,13 @@ class Trainer:
         # Will reach this branch if the user has
         # 1. Used `.from_pretrained` or `.from_config` to initialize their model
         # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
+        # 3. Also not using quantization
         # New models init such as `MyModel()` will not hit this step
-        if is_deepspeed_zero3_enabled() and not getattr(model, "_transformers_zero3_init_used", True):
+        if (
+            not (hasattr(model, "hf_quantizer") and model.hf_quantizer.is_trainable)
+            and is_deepspeed_zero3_enabled()
+            and not getattr(model, "_transformers_zero3_init_used", True)
+        ):
             raise ValueError(
                 "Model was not initialized with `Zero-3` despite being configured for DeepSpeed Zero-3. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
             )
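
The guard above enforces the initialization order that the error message prescribes. A minimal sketch of the pattern that satisfies it, assuming DeepSpeed is installed and a ZeRO-3 JSON config exists (the model name and config path below are illustrative, not part of this diff):

from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Creating TrainingArguments first wires up the DeepSpeed plugin, so the
# subsequent `from_pretrained` call runs under ZeRO-3 init and sets
# `_transformers_zero3_init_used` on the model.
training_args = TrainingArguments(
    output_dir="out",
    deepspeed="ds_zero3_config.json",  # illustrative path to a ZeRO-3 config
)

# ZeRO-3-aware initialization, because the DeepSpeed state already exists.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Passes the guard: `_transformers_zero3_init_used` is True here.
trainer = Trainer(model=model, args=training_args)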


@@ -42,9 +42,11 @@ from transformers.testing_utils import (
     backend_device_count,
     execute_subprocess_async,
     mockenv_context,
+    require_bitsandbytes,
     require_deepspeed,
     require_optuna,
     require_torch_accelerator,
+    require_torch_gpu,
     require_torch_multi_accelerator,
     slow,
     torch_device,
@@ -734,6 +736,23 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         assert trainer.is_deepspeed_enabled
         assert model._transformers_zero3_init_used
 
+    @require_torch_gpu
+    @require_bitsandbytes
+    def test_missed_zero3_init_quantized(self):
+        from transformers import Trainer  # noqa
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            model = AutoModel.from_pretrained(T5_TINY, load_in_4bit=True)
+            training_args = TrainingArguments(
+                output_dir="./test_missed_zero3_init",
+                deepspeed=self.get_config_dict(ZERO3),
+            )
+            # Shouldn't raise an error in this case
+            _ = Trainer(
+                model=model,
+                args=training_args,
+            )
+
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
         file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
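
The new test relies on the short-circuit in the revised condition: a 4-bit bitsandbytes model carries a trainable `hf_quantizer`, so the first clause is False and the ValueError is never reached even though ZeRO-3 init was skipped. A toy sketch of that evaluation, using stand-in objects rather than the real Trainer internals:

from types import SimpleNamespace

# Stand-ins for the real model and DeepSpeed state (hypothetical values).
model = SimpleNamespace(hf_quantizer=SimpleNamespace(is_trainable=True))
zero3_enabled = True        # what is_deepspeed_zero3_enabled() would return
zero3_init_used = False     # the quantized model skipped ZeRO-3 init

should_raise = (
    not (hasattr(model, "hf_quantizer") and model.hf_quantizer.is_trainable)
    and zero3_enabled
    and not zero3_init_used
)
print(should_raise)  # False: the trainable quantizer exempts the model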