Fix checkpoint api (#1714)

This commit is contained in:
Olatunji Ruwase
2022-01-21 06:32:48 -08:00
committed by GitHub
parent 4912e0ad7e
commit e40558ded2
2 changed files with 3 additions and 3 deletions

View File

@ -2460,7 +2460,7 @@ class DeepSpeedEngine(Module):
tag,
load_optimizer_states=load_optimizer_states)
if not success:
self.optimizer._restore_from_fp16_weights()
self.optimizer._restore_from_bit16_weights()
return load_path, client_states

View File

@ -2971,13 +2971,13 @@ class DeepSpeedZeroOptimizer_Stage3(object):
current.data.copy_(saved.data)
# Restore base optimizer fp32 weights from ZeRO fp16 weights
def _restore_from_fp16_weights(self):
def _restore_from_bit16_weights(self):
for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
fp32_partition.data.copy_(fp16_partitions.data)
# Refresh the fp32 master params from the fp16 copies.
def refresh_fp32_params(self):
self._restore_from_fp16_weights()
self._restore_from_bit16_weights()
# Extract flattened partition for current rank from all partitions
def _get_flattened_partition(self, all_partition_states):