disables ZeRO checkpoint loading path when stage=0 (#7586)

Fixes #7571 

When ZeRO is disabled (stage 0) and bf16 is enabled, the current guard
sets `load_zero_checkpoint=True`, which leads to `_load_zero_checkpoint`
and `_restore_from_bit16_weights()` being called even though no ZeRO
state exists.

This PR removes the `self.bfloat16_enabled()` condition so that
load_zero_checkpoint is tied strictly to `self.zero_optimization()`.

Stage 0 (BF16/FP16/FP32): cleanly skips ZeRO checkpoint path.

Stage ≥ 1: loads ZeRO partitioned optimizer state as before.

cc @sfc-gh-truwase

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
Naveenraj Kamalakannan
2025-09-25 16:31:14 -04:00
committed by GitHub
parent 16c1bf429f
commit b75654001a

View File

@ -3115,7 +3115,7 @@ class DeepSpeedEngine(Module):
load_module_only=load_module_only,
custom_load_fn=custom_load_fn)
load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
load_zero_checkpoint = load_path is not None and self.zero_optimization()
if load_zero_checkpoint:
if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)