mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Fix the GPU memory usage of ZeRO-Offload (only update stage_1_and_2.py) (#7309)
Signed-off-by: Armin Zhu <mingzhengzhu1998@gmail.com> Fix the GPU memory usage of ZeRO-Offload with stages 1 and 2. Before this fix, the memory usage was about 3x that of the FP16 parameters, because the H2D data copy used a different data type. After the fix, the GPU memory usage is about 1x the FP16 parameters. The H2D memory copy now uses a 16-bit pinned memory buffer.
This commit is contained in:
@ -258,6 +258,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
# that this process will update
|
||||
self.single_partition_of_fp32_groups = []
|
||||
|
||||
# a 16-bit CPU param buffer for cpu offload
|
||||
if self.cpu_offload:
|
||||
self.param_buffer_of_bit16_for_cpu_offload_groups = []
|
||||
|
||||
# param partition info
|
||||
|
||||
# These are the parameters in each group that will not be updated by this process directly
|
||||
@ -406,6 +410,16 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
|
||||
if self.cpu_offload:
|
||||
weights_partition = get_accelerator().pin_memory(weights_partition)
|
||||
temp_dtype = self.parallel_partitioned_bit16_groups[i][partition_id].dtype
|
||||
temp_buffer_bit16 = torch.full(weights_partition.shape,
|
||||
fill_value=0.0,
|
||||
dtype=temp_dtype,
|
||||
device=weights_partition.device)
|
||||
if self.cpu_offload_pin_memory:
|
||||
temp_pinned = get_accelerator().pin_memory(temp_buffer_bit16)
|
||||
self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_pinned)
|
||||
else:
|
||||
self.param_buffer_of_bit16_for_cpu_offload_groups.append(temp_buffer_bit16)
|
||||
|
||||
self.single_partition_of_fp32_groups.append(weights_partition)
|
||||
|
||||
@ -1887,8 +1901,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
# bit16_partitions[partition_id].data.copy_(fp32_partition.data)
|
||||
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
|
||||
fp32_partition = self.single_partition_of_fp32_groups[i]
|
||||
bit16_partitions[partition_id].data.copy_(
|
||||
fp32_partition.to(get_accelerator().current_device_name()).data)
|
||||
bit16_partition_buffer = self.param_buffer_of_bit16_for_cpu_offload_groups[i]
|
||||
bit16_partition_buffer.data.copy_(fp32_partition.data)
|
||||
bit16_partitions[partition_id].data.copy_(bit16_partition_buffer.data, non_blocking=True)
|
||||
|
||||
self.timers(OPTIMIZER_STEP_TIMER).stop()
|
||||
else:
|
||||
|
Reference in New Issue
Block a user