Fix the GPU memory usage of ZeRO-Offload (only update stage_1_and_2.py) (#7309)

Signed-off-by: Armin Zhu <mingzhengzhu1998@gmail.com>

Fix the GPU memory usage of ZeRO-Offload with stages 1 and 2. Before the
fix, GPU memory usage was about 3x the size of the FP16 parameters
(params_FP16), because the H2D data copy mixed data types: the FP32
partition was first materialized on the GPU and only then downcast to 16
bits. After the fix, GPU memory usage is about 1x params_FP16; the H2D
copy now stages through a 16-bit pinned host-memory buffer.
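
For illustration, a minimal sketch of the two copy patterns (the names and sizes are hypothetical and a CUDA device is assumed; the standalone staging buffer here plays the role of `param_buffer_of_bit16_for_cpu_offload_groups`):

```python
import torch

numel = 4 * 1024 * 1024
fp32_partition = torch.randn(numel, dtype=torch.float32)                  # FP32 master weights on CPU
bit16_partition = torch.empty(numel, dtype=torch.float16, device="cuda")  # FP16 params on GPU

# Old pattern: .to("cuda") materializes a full FP32 copy on the GPU (2x the
# FP16 size) before the downcast, so peak GPU usage is about 3x params_FP16.
bit16_partition.data.copy_(fp32_partition.to("cuda").data)

# New pattern: downcast on the CPU into a pinned 16-bit staging buffer, then
# issue a same-dtype async H2D copy. Peak GPU usage stays about 1x params_FP16.
staging_bit16 = torch.zeros(numel, dtype=torch.float16).pin_memory()
staging_bit16.copy_(fp32_partition)                            # FP32 -> FP16 cast on CPU
bit16_partition.data.copy_(staging_bit16, non_blocking=True)   # same-dtype H2D transfer
```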
Committed by Armin Zhu on 2025-05-27 20:13:24 +08:00 (via GitHub)
parent b666844ffc
commit 17c8be0706

@@ -258,6 +258,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         # that this process will update
         self.single_partition_of_fp32_groups = []
 
+        # a 16-bit CPU param buffer for cpu offload
+        if self.cpu_offload:
+            self.param_buffer_of_bit16_for_cpu_offload_groups = []
+
         # param partition info
 
         # These are the parameters in each group that will not be updated by this process directly
@@ -406,6 +410,16 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             if self.cpu_offload:
                 weights_partition = get_accelerator().pin_memory(weights_partition)
+                temp_dtype = self.parallel_partitioned_bit16_groups[i][partition_id].dtype
+                temp_buffer_bit16 = torch.full(weights_partition.shape,
+                                               fill_value=0.0,
+                                               dtype=temp_dtype,
+                                               device=weights_partition.device)
+                if self.cpu_offload_pin_memory:
+                    temp_pinned = get_accelerator().pin_memory(temp_buffer_bit16)
+                    self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_pinned)
+                else:
+                    self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_buffer_bit16)
 
             self.single_partition_of_fp32_groups.append(weights_partition)
@@ -1887,8 +1901,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 # bit16_partitions[partition_id].data.copy_(fp32_partition.data)
                 bit16_partitions = self.parallel_partitioned_bit16_groups[i]
                 fp32_partition = self.single_partition_of_fp32_groups[i]
-                bit16_partitions[partition_id].data.copy_(
-                    fp32_partition.to(get_accelerator().current_device_name()).data)
+                bit16_partition_buffer = self.param_buffer_of_bit16_for_cpu_offload_groups[i]
+                bit16_partition_buffer.data.copy_(fp32_partition.data)
+                bit16_partitions[partition_id].data.copy_(bit16_partition_buffer.data, non_blocking=True)
 
             self.timers(OPTIMIZER_STEP_TIMER).stop()
         else:
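
Note on the design choice: the `non_blocking=True` copy at step time can only overlap with other work when the source lives in pinned host memory, which is why the staging buffer is pinned when `cpu_offload_pin_memory` is set; with an unpinned buffer the transfer degrades to an ordinary synchronous H2D copy.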