From b75654001a2bb95b4205ac2deeab401a2524ee68 Mon Sep 17 00:00:00 2001 From: Naveenraj Kamalakannan Date: Thu, 25 Sep 2025 16:31:14 -0400 Subject: [PATCH] disables ZeRO checkpoint loading path when stage=0 (#7586) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #7571 When ZeRO is disabled (stage 0) and bf16 is enabled, the current guard sets `load_zero_checkpoint=True`, which leads to `_load_zero_checkpoint` and `_restore_from_bit16_weights()` being called even though no ZeRO state exists. This PR removes the `self.bfloat16_enabled()` condition so that load_zero_checkpoint is tied strictly to `self.zero_optimization()`. Stage 0 (BF16/FP16/FP32): cleanly skips ZeRO checkpoint path. Stage ≥ 1: loads ZeRO partitioned optimizer state as before. cc @sfc-gh-truwase Signed-off-by: Naveenraj Kamalakannan Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 3d345adcb..a5c106836 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3115,7 +3115,7 @@ class DeepSpeedEngine(Module): load_module_only=load_module_only, custom_load_fn=custom_load_fn) - load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) + load_zero_checkpoint = load_path is not None and self.zero_optimization() if load_zero_checkpoint: if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)