mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Fix checkpoint api (#1714)
This commit is contained in:
@@ -2460,7 +2460,7 @@ class DeepSpeedEngine(Module):
|
||||
tag,
|
||||
load_optimizer_states=load_optimizer_states)
|
||||
if not success:
|
||||
-self.optimizer._restore_from_fp16_weights()
|
||||
+self.optimizer._restore_from_bit16_weights()
|
||||
|
||||
return load_path, client_states
|
||||
|
||||
|
@@ -2971,13 +2971,13 @@ class DeepSpeedZeroOptimizer_Stage3(object):
|
||||
current.data.copy_(saved.data)
|
||||
|
||||
# Restore base optimizer fp32 weights from ZeRO fp16 weights
|
||||
-def _restore_from_fp16_weights(self):
|
||||
+def _restore_from_bit16_weights(self):
|
||||
for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat):
|
||||
fp32_partition.data.copy_(fp16_partitions.data)
|
||||
|
||||
# Refresh the fp32 master params from the fp16 copies.
|
||||
def refresh_fp32_params(self):
|
||||
-self._restore_from_fp16_weights()
|
||||
+self._restore_from_bit16_weights()
|
||||
|
||||
# Extract flattened partition for current rank from all partitions
|
||||
def _get_flattened_partition(self, all_partition_states):
|
||||
|
Reference in New Issue
Block a user