SuperOffload Release (#7559)
This PR introduces **SuperOffload**, an optimizer designed for Superchips (NVIDIA GH200 & GB200, AMD MI300A) with high CPU-GPU bandwidth. It enables **full fine-tuning** of **GPT-OSS-20B, Qwen3-14B, and Phi-4** on a single GH200 GPU, achieving up to **~500 TFLOPS**, using Hugging Face Transformers and DeepSpeed with no custom modeling code required.

SuperOffload extends ZeRO-Offload with fine-grained control and CPUAdam rollback utilities, allowing GPU execution to overlap with CPUAdam. This reduces GPU idle time and improves overall efficiency.

Key changes:
- New `SuperOffloadOptimizer_Stage3` optimizer.
- C++/CUDA binding for `adam_rollback` to revert one optimization step.
- Config additions including `super_offload` and `cpuadam_cores_perc`.

A detailed blog and tutorial will be available soon.

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
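For context, a sketch of how the new options surface in a DeepSpeed config once this PR is in place; everything except the two new `offload_optimizer` fields is an illustrative placeholder:

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "super_offload": True,      # new in this PR: use the SuperOffload stage-3 optimizer
            "cpuadam_cores_perc": 0.9,  # new in this PR: share of CPU cores reserved for CPU Adam
        },
    },
}
```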
@@ -164,3 +164,86 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                                              group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                              state['exp_avg'], state['exp_avg_sq'])
         return loss
+
+    @torch.no_grad()
+    def step_subgroup(self, subgroup_id: int, closure=None):
+        """Update the model parameters in a single subgroup (by index)."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        # Intended device for step
+        device = torch.device('cpu')
+
+        for group in self.param_groups:
+            for p in group['params']:
+
+                if p.grad is None:
+                    continue
+
+                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
+                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."
+
+                state = self.state[subgroup_id]
+
+                if len(state) == 0:
+                    state['step'] = 0
+
+                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
+
+                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
+
+                state['step'] += 1
+                beta1, beta2 = group['betas']
+                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
+                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
+                                             state['exp_avg'], state['exp_avg_sq'])
+        return loss
+
+    @torch.no_grad()
+    def rollback_subgroup(self, sub_group_id: int, closure=None):
+        """
+        Rollback the optimizer state for a specific subgroup.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        # Intended device for step
+        device = torch.device('cpu')
+
+        # Validate subgroup state exists and is initialized
+        if sub_group_id not in self.state or len(self.state[sub_group_id]) == 0:
+            raise RuntimeError(f"Cannot rollback optimizer state for sub_group_id {sub_group_id} "
+                               f"as it has not been initialized.")
+
+        subgroup_state = self.state[sub_group_id]
+
+        # Check if we can rollback (step count must be > 0)
+        if subgroup_state.get('step', 0) <= 0:
+            raise RuntimeError(f"Cannot rollback sub_group_id {sub_group_id}: "
+                               f"step count is {subgroup_state.get('step', 0)}")
+
+        for _, group in enumerate(self.param_groups):
+            for _, param in enumerate(group['params']):
+                if param.grad is None:
+                    continue
+
+                assert param.device == device, (
+                    f"CPUAdam param is on {param.device} and must be 'cpu', "
+                    f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.")
+
+                # Decrement step count
+                subgroup_state['step'] -= 1
+
+                # Extract hyperparameters
+                beta1, beta2 = group['betas']
+
+                self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2,
+                                               group['eps'], group['weight_decay'], group['bias_correction'],
+                                               param.data, param.grad.data, subgroup_state['exp_avg'],
+                                               subgroup_state['exp_avg_sq'])
+        return loss
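As a quick illustration of how the two new CPUAdam hooks pair up, a minimal sketch (hypothetical flat partition, sizes illustrative, not taken from this PR): a sub-group's update can be applied eagerly and undone later if a global check such as overflow detection invalidates it.

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Hypothetical flat fp32 partition for one sub-group, resident on CPU.
flat_param = torch.nn.Parameter(torch.zeros(1024))
flat_param.grad = torch.randn(1024)

opt = DeepSpeedCPUAdam([flat_param], lr=1e-4)

opt.step_subgroup(0)      # eagerly apply Adam for sub-group 0
opt.rollback_subgroup(0)  # revert that step, e.g. after a gradient-overflow check fails
```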
@@ -886,6 +886,12 @@ class DeepSpeedEngine(Module):
     def zero_partial_offload(self):
         return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)
 
+    def super_offload(self):
+        return getattr(self._config.zero_config.offload_optimizer, "super_offload", False)
+
+    def cpuadam_cores_perc(self):
+        return getattr(self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9)
+
     def zero_sub_group_size(self):
         return self._config.zero_config.sub_group_size
 
@@ -1826,7 +1832,10 @@ class DeepSpeedEngine(Module):
 
             log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
             from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
-            optimizer = DeepSpeedZeroOptimizer_Stage3(
+            from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3
+            Stage3ZeroOptimizer = DeepSpeedZeroOptimizer_Stage3 if not self.super_offload(
+            ) else SuperOffloadOptimizer_Stage3
+            optimizer = Stage3ZeroOptimizer(
                 self.module,
                 optimizer,
                 timers=timers,
@@ -1864,6 +1873,7 @@ class DeepSpeedEngine(Module):
                 zeropp_loco_param=self.zeropp_loco_param(),
                 log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
                 enable_sanity_checks=self.is_sanity_checks_enabled(),
+                cpuadam_cores_perc=self.cpuadam_cores_perc(),
             )
 
         else:
deepspeed/runtime/superoffload/superoffload_stage3.py (new file, 412 lines)
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import sys
import time
import torch
from typing import List

from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, TaskKeys, ResultKeys, EventTypes
from deepspeed.runtime.zero.partition_parameters import Parameter, Tensor
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.utils import logger
from deepspeed.accelerator import get_accelerator

OPTIMIZER_STEP_TIMER = 'optimizer_step'


class SuperOffloadOptimizer_Stage3(DeepSpeedZeroOptimizer_Stage3):

    def __init__(
        self,
        module,
        init_optimizer,
        timers,
        ds_config,
        static_loss_scale=1.0,
        dynamic_loss_scale=False,
        dynamic_loss_args=None,
        verbose=True,
        contiguous_gradients=True,
        reduce_bucket_size=500000000,
        prefetch_bucket_size=50000000,
        max_reuse_distance=1000000000,
        max_live_parameters=1000000000,
        param_persistence_threshold=100000,
        model_persistence_threshold=sys.maxsize,
        dp_process_group=None,
        reduce_scatter=True,
        overlap_comm=False,
        offload_optimizer_config=None,
        offload_param_config=None,
        sub_group_size=1000000000000,
        offload_ratio=0.0,
        mpu=None,
        clip_grad=0.0,
        gradient_accumulation_dtype=torch.float32,
        communication_data_type=torch.float16,
        postscale_gradients=True,
        gradient_predivide_factor=1.0,
        gradient_accumulation_steps=1,
        elastic_checkpoint=False,
        aio_config=None,
        all2all_process_group=None,
        zero_hpz_partition_size=1,
        zero_quantized_weights=False,
        zero_quantized_nontrainable_weights=False,
        zero_module_granularity_threshold=0,
        zeropp_loco_param=None,
        log_trace_cache_warnings=False,
        enable_sanity_checks=False,
        cpuadam_cores_perc=0.8,
    ):

        self.sub_group_to_param_num = {}
        self.params_in_ipg_bucket_buffer = []
        self._cur_bucket_index = -1
        self.async_cpuadam_num = 0
        self.max_grad_numel = 0

        super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
                         dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
                         max_reuse_distance, max_live_parameters, param_persistence_threshold,
                         model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
                         offload_optimizer_config, offload_param_config, sub_group_size, offload_ratio, mpu, clip_grad,
                         gradient_accumulation_dtype, communication_data_type, postscale_gradients,
                         gradient_predivide_factor, gradient_accumulation_steps, elastic_checkpoint, aio_config,
                         all2all_process_group, zero_hpz_partition_size, zero_quantized_weights,
                         zero_quantized_nontrainable_weights, zero_module_granularity_threshold, zeropp_loco_param,
                         log_trace_cache_warnings, enable_sanity_checks)

        optimizer_config = {
            "lr": self.optimizer.param_groups[0]["lr"],
            "betas": self.optimizer.param_groups[0]["betas"],
            "eps": self.optimizer.param_groups[0]["eps"],
            "weight_decay": self.optimizer.param_groups[0]["weight_decay"],
            "amsgrad": self.optimizer.param_groups[0]["amsgrad"]
        }
        self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_config,
                                                                   cpuadam_cores_perc=cpuadam_cores_perc,
                                                                   max_grad_numel=self.max_grad_numel)

    def _create_fp16_sub_groups(self, params_group):

        params_group_numel = sum([param.partition_numel() for param in params_group])
        sub_group_size = self.sub_group_size

        if sub_group_size is None or sub_group_size >= params_group_numel:
            return [params_group]

        sub_groups = []
        sub_group = []
        local_sub_group_size = 0

        for param in params_group:
            sub_group.append(param)
            local_sub_group_size += param.partition_numel()

            if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]):
                self.max_grad_numel = max(self.max_grad_numel, local_sub_group_size)
                sub_groups.append(sub_group)
                self.sub_group_to_param_num[len(sub_groups) - 1] = len(sub_group)

                sub_group = []
                local_sub_group_size = 0

        return sub_groups

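    # Example (toy numbers, not from this PR): with partition sizes [2, 2, 1, 4] and
    # sub_group_size=3, the greedy packing above yields sub-groups [[2, 2], [1, 4]],
    # sub_group_to_param_num == {0: 2, 1: 2}, and max_grad_numel == 5.
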
    def _optimizer_step(self, sub_group_id):
        param_group_id = self.sub_group_to_group_id[sub_group_id]
        fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]

        def step_with_gradscaler(optimizer):
            if self.torch_autocast_gradscaler:
                self.torch_autocast_gradscaler.step(optimizer)
                self.torch_autocast_gradscaler.update()
            else:
                optimizer.step()

        cur_device = self.subgroup_to_device[sub_group_id]
        if cur_device != 'cpu':
            self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param]
            step_with_gradscaler(self.backup_optimizer)
            self.backup_optimizer.param_groups[param_group_id]['params'] = []

    def reduce_independent_p_g_buckets_and_remove_grads(self, param):
        comm_dtype = self.get_param_comm_dtype(param)
        bucket = self.ipg_buckets[comm_dtype]
        i, _, _ = self.grad_position[self.get_param_id(param)]

        if len(bucket.params) == 0:
            self._cur_bucket_index = i
            if getattr(param, "ds_grad_is_ready", True):
                self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)

            # If this is a single-parameter sub-group, reduce immediately
            if self.sub_group_to_param_num[self._cur_bucket_index] == 1:
                self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)

        elif i != self._cur_bucket_index:
            # Parameter belongs to different sub-group, buffer it
            self.params_in_ipg_bucket_buffer.append(param)
        else:
            # Parameter belongs to current bucket
            if getattr(param, "ds_grad_is_ready", True):
                self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)

            # Check if bucket is complete
            if self.sub_group_to_param_num[self._cur_bucket_index] == len(bucket.params):
                self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)

                # Process buffered parameters
                while self.params_in_ipg_bucket_buffer:
                    buffered_param = self.params_in_ipg_bucket_buffer.pop(0)
                    ci, _, _ = self.grad_position[self.get_param_id(buffered_param)]
                    self._cur_bucket_index = ci
                    if getattr(buffered_param, "ds_grad_is_ready", True):
                        self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(buffered_param)

    @instrument_w_nvtx
    def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
        if self.subgroup_to_device[sub_group_id] == 'cpu':
            self._unflatten_partitioned_parameters(sub_group_id)
            return

        if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
            self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
                self.fp32_partitioned_groups_flat[sub_group_id].data)
            self._unflatten_partitioned_parameters(sub_group_id)
        else:
            self._partitioned_params_swap_out(sub_group_id)

    @instrument_w_nvtx
    def _reassign_or_swap_out_partitioned_parameters_async(self, sub_group_id, updated_param):
        """Asynchronously update partitioned parameters with optimized values."""
        self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(updated_param, non_blocking=True)

    @instrument_w_nvtx
    def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
        # print("[DEBUG] partition_grads called")
        buffers = []
        device_buffers = {}
        buffer_numel_min = {}
        buffer_numel_max = {}

        for param, grad_partition in zip(params_to_release, grad_partitions):
            i, dest_offset, _ = self.grad_position[self.get_param_id(param)]

            if self.is_gradient_accumulation_boundary:
                self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_partition)

            buffer_numel = grad_partition.numel()
            buffers.append(grad_partition)

            if i not in device_buffers:
                device_buffers[i] = []
            device_buffers[i].append(grad_partition)

            if i not in buffer_numel_min:
                buffer_numel_min[i] = dest_offset
                buffer_numel_max[i] = dest_offset + buffer_numel
            else:
                buffer_numel_min[i] = min(buffer_numel_min[i], dest_offset)
                buffer_numel_max[i] = max(buffer_numel_max[i], dest_offset + buffer_numel)

        if self.is_gradient_accumulation_boundary:
            for i in buffer_numel_min.keys():
                fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
                    0, buffer_numel_min[i], buffer_numel_max[i] - buffer_numel_min[i])
                concatenated_buffer = torch.cat(device_buffers[i], dim=0).float()

                if self.subgroup_to_device[i] == 'cpu':
                    # Trigger asynchronous CPU optimization
                    param_group_id = self.sub_group_to_group_id[i]
                    fp32_param = self.fp32_partitioned_groups_flat[i]

                    self.superoffload_cpu_optimizer.async_step(param_group_id, i, fp32_param.data,
                                                               concatenated_buffer.data)
                    self.async_cpuadam_num += 1

                    # Check for completed async operations
                    result = self.superoffload_cpu_optimizer.get_result()
                    if result is not None:
                        self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
                                                                                result[ResultKeys.UPDATED_PARAM])
                        self.async_cpuadam_num -= 1

                    fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)
                else:
                    fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)

        # Clean up parameter gradients
        for param in params_to_release:
            if not get_accelerator().is_synchronized_device():
                param.grad.record_stream(get_accelerator().current_stream())
            param.grad = None

    @instrument_w_nvtx
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # Wait for any pending asynchronous CPU optimizer operations
        self._wait_for_async_operations()

        self._pre_step()
        self._partition_all_parameters()

        if self._overflow_check_and_loss_scale_update():
            self._handle_overflow_rollback()
            return

        norm_groups = self._get_norm_groups()
        scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale

        timer_names = set()
        timer_names.add(OPTIMIZER_STEP_TIMER)
        self.timers(OPTIMIZER_STEP_TIMER).start()

        if self.check_clip_grads(scaled_global_grad_norm):
            self._handle_gradient_clipping(scaled_global_grad_norm)

        for sub_group_id, group in enumerate(self.fp16_groups):
            # Prepare optimizer states, gradients and fp32 parameters for update
            self._prepare_sub_group(sub_group_id, timer_names)

            # Scale the fp32 gradients
            self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)

            # Apply the optimizer step on the sub group and copy fp32 parameters to fp16
            self._optimizer_step(sub_group_id)

            # Put fp16 parameters in appropriate location
            self._reassign_or_swap_out_partitioned_parameters(sub_group_id)

            # Release memory or swap out optimizer states of fp32 parameters
            self._release_sub_group(sub_group_id, timer_names)

        self.timers(OPTIMIZER_STEP_TIMER).stop()
        self._post_step(timer_names)

    def _wait_for_async_operations(self, timeout_seconds=60):
        """Wait for all pending asynchronous CPU optimizer operations to complete with timeout error.

        Args:
            timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
        """
        if self.async_cpuadam_num > 0:
            logger.info(f"[INFO] {self.async_cpuadam_num} asynchronous CPU optimizer operations pending...")
        if self.async_cpuadam_num == 0:
            return

        start_time = time.time()
        initial_pending_ops = self.async_cpuadam_num

        while self.async_cpuadam_num > 0:
            result = self.superoffload_cpu_optimizer.get_result()
            if result is None:
                current_time = time.time()
                elapsed_time = current_time - start_time

                # Throw error if we've been waiting longer than the timeout
                if elapsed_time >= timeout_seconds:
                    raise RuntimeError(
                        f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
                        f"Still waiting for {self.async_cpuadam_num}/{initial_pending_ops} async operations to complete. "
                        f"This indicates a deadlock or critical performance issue in the CPU optimizer.")

                time.sleep(0.001)  # 1ms sleep
                continue

            self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
                                                                    result[ResultKeys.UPDATED_PARAM])
            self.async_cpuadam_num -= 1

    def _wait_for_single_async_result(self, event_type: str, timeout_seconds=60):
        """Wait for a single asynchronous CPU-Adam optimizer operation with timeout.

        Args:
            event_type (str): Type of operation expected ('adam_step' or 'rollback').
            timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
        """
        start_time = time.time()

        while True:
            result = self.superoffload_cpu_optimizer.get_result(expected_event_type=event_type)
            if result is not None:
                self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
                                                                        result[ResultKeys.UPDATED_PARAM])
                break

            current_time = time.time()
            elapsed_time = current_time - start_time

            # Throw error if we've been waiting longer than the timeout
            if elapsed_time >= timeout_seconds:
                raise RuntimeError(f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
                                   f"This indicates a deadlock or critical performance issue in the CPU optimizer.")

            time.sleep(0.001)  # 1ms sleep

    def _sync_cpu_optimizer_step(self,
                                 param_group_id: int,
                                 sub_group_id: int,
                                 fp32_param_data,
                                 fp32_grad_data,
                                 rollback: bool = False,
                                 timeout_seconds: int = 60):
        event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
        self.superoffload_cpu_optimizer.async_step(param_group_id,
                                                   sub_group_id,
                                                   fp32_param_data,
                                                   fp32_grad_data,
                                                   rollback=rollback)
        # Wait for completion
        self._wait_for_single_async_result(event_type, timeout_seconds)

    def _handle_overflow_rollback(self):
        """Handle gradient overflow by rolling back CPU optimizer states."""
        for sub_group_id, _ in enumerate(self.fp16_groups):
            if self.subgroup_to_device[sub_group_id] == 'cpu':
                param_group_id = self.sub_group_to_group_id[sub_group_id]
                fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]

                # Trigger rollback
                self._sync_cpu_optimizer_step(param_group_id,
                                              sub_group_id,
                                              fp32_param.data,
                                              fp32_param.grad.data,
                                              rollback=True)

    def _handle_gradient_clipping(self, scaled_global_grad_norm):
        """Handle gradient clipping with CPU optimizer rollback and re-optimization."""
        for sub_group_id, _ in enumerate(self.fp16_groups):
            if self.subgroup_to_device[sub_group_id] == 'cpu':
                param_group_id = self.sub_group_to_group_id[sub_group_id]
                fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]

                # Rollback CPU optimizer states
                self._sync_cpu_optimizer_step(param_group_id,
                                              sub_group_id,
                                              fp32_param.data,
                                              fp32_param.grad.data,
                                              rollback=True)

                # Clip gradients and re-optimize
                self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)

                self._sync_cpu_optimizer_step(param_group_id,
                                              sub_group_id,
                                              fp32_param.data,
                                              fp32_param.grad.data,
                                              rollback=False)

    @instrument_w_nvtx
    def check_clip_grads(self, total_norm):
        """Check if gradients need to be clipped based on the global norm."""
        unscaled_norm = total_norm / self.loss_scale
        return self.clip_grad and unscaled_norm > self.clip_grad
deepspeed/runtime/superoffload/superoffload_utils.py (new file, 273 lines)
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
SuperOffload utilities for running CPU optimizers in separate processes.
"""

from typing import Dict, Optional, Any
import torch
import torch.multiprocessing as mp
import psutil

from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger


class TaskKeys:
    PARAM_DATA = "param_data"
    PARAM_GRAD = "param_grad"
    PARAM_GROUP_ID = "param_group_id"
    SUB_GROUP_ID = "sub_group_id"
    ROLLBACK = "rollback"


class ResultKeys:
    UPDATED_PARAM = "updated_param"
    EVENT_TYPE = "event_type"


class EventTypes:
    ADAM_STEP = "adam_step"
    ROLLBACK = "rollback"


def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp.SimpleQueue,
                                  optimizer_config: Dict[str, Any], max_grad_numel: int) -> None:
    """
    This function runs in a separate process and continuously processes optimization
    tasks from the parameter queue. It creates a DeepSpeedCPUAdam optimizer and
    applies optimization steps to parameters received from the main process.

    Args:
        param_queue: Queue for receiving optimization tasks
        result_queue: Queue for sending back optimization results
        optimizer_config: Configuration dictionary for the optimizer containing
            lr, betas, eps, weight_decay, and amsgrad parameters
        max_grad_numel: Maximum number of elements expected in gradient tensors
    """
    # Initialize dummy parameter for optimizer creation
    cpu_tensor = torch.randn(1, device="cpu")
    cpu_param = torch.nn.Parameter(cpu_tensor)

    try:
        optimizer = DeepSpeedCPUAdam([cpu_param],
                                     lr=optimizer_config["lr"],
                                     betas=optimizer_config["betas"],
                                     eps=optimizer_config["eps"],
                                     weight_decay=optimizer_config["weight_decay"],
                                     amsgrad=optimizer_config["amsgrad"])
    except KeyError as e:
        error_msg = f"Missing required optimizer config key: {e}"
        logger.error(error_msg)
        result_queue.put({"error": error_msg})
        return

    # Pre-allocate reusable pinned memory buffer for gradients
    pinned_grad_buffer = torch.empty(max_grad_numel, dtype=torch.float32, device='cpu', pin_memory=True)

    while True:
        try:
            task = param_queue.get()

            if task is None:
                logger.debug("Received termination signal, shutting down worker")
                break

            param_data = task[TaskKeys.PARAM_DATA]
            param_grad = task[TaskKeys.PARAM_GRAD]
            param_group_id = task[TaskKeys.PARAM_GROUP_ID]
            sub_group_id = task[TaskKeys.SUB_GROUP_ID]
            rollback = task.get(TaskKeys.ROLLBACK, False)

            logger.debug(f"Processing param_group_id: {param_group_id}, sub_group_id: {sub_group_id}")

            del task[TaskKeys.PARAM_DATA]
            del task[TaskKeys.PARAM_GRAD]
            task.clear()

            grad_numel = param_grad.numel()
            if grad_numel > max_grad_numel:
                error_msg = (
                    f"Gradient size {grad_numel} exceeds pre-allocated buffer size {max_grad_numel}. "
                    f"This indicates insufficient buffer allocation. Please increase max_grad_numel parameter.")
                result_queue.put({"error": error_msg})
                break

            param_grad_cpu = pinned_grad_buffer[:grad_numel].view_as(param_grad)
            param_grad_cpu.copy_(param_grad, non_blocking=True)

            fp32_param = torch.nn.Parameter(param_data)
            fp32_param.grad = param_grad_cpu

            optimizer.param_groups[param_group_id]['params'] = [fp32_param]

            if rollback:
                logger.debug(f"Rolling back optimizer state for sub_group_id: {sub_group_id}")
                optimizer.rollback_subgroup(sub_group_id)
            else:
                optimizer.step_subgroup(sub_group_id)

            # Send result back to main process
            event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
            result_queue.put({
                TaskKeys.PARAM_GROUP_ID: param_group_id,
                TaskKeys.SUB_GROUP_ID: sub_group_id,
                ResultKeys.UPDATED_PARAM: fp32_param.data,
                ResultKeys.EVENT_TYPE: event_type,
            })

            # Clean up references to free memory
            optimizer.param_groups[param_group_id]['params'] = []
            del param_grad_cpu, fp32_param.grad, fp32_param, param_grad, param_data

        except KeyError as e:
            error_msg = f"Missing required task key: {e}"
            logger.error(error_msg)
            result_queue.put({"error": error_msg})
            break
        except Exception as e:
            error_msg = f"Unexpected error in worker process: {e}"
            logger.error(error_msg)
            result_queue.put({"error": error_msg})
            break

    # Clean up pinned memory buffer
    if 'pinned_grad_buffer' in locals():
        del pinned_grad_buffer
        logger.debug("Cleaned up pinned memory buffer")

    logger.debug("Worker process terminated")


class SuperOffloadCPUOptimizer:

    def __init__(self,
                 optimizer_config: Dict[str, Any],
                 cpuadam_cores_perc: float = 0.8,
                 max_grad_numel: int = 1000000) -> None:
        if not 0 < cpuadam_cores_perc <= 1:
            raise ValueError("cpuadam_cores_perc must be between 0 and 1")

        self.max_grad_numel = max_grad_numel
        self.mp_context = mp.get_context('spawn')
        self.param_queue = self.mp_context.SimpleQueue()
        self.result_queue = self.mp_context.SimpleQueue()

        self.cpuadam_process = self.mp_context.Process(
            target=superoffload_optimizer_worker,
            args=(self.param_queue, self.result_queue, optimizer_config, max_grad_numel),
            daemon=True,
        )
        self.cpuadam_process.start()

        # Set CPU affinity for better performance isolation
        self._set_cpu_affinity(cpuadam_cores_perc)

    def _set_cpu_affinity(self, cpuadam_cores_perc: float) -> None:
        """
        Set CPU affinity for the main (PyTorch) process and the worker (CPU Adam) process.

        Args:
            cpuadam_cores_perc: Percentage of cores to allocate to the worker (CPU Adam) process
        """
        try:
            current_process = psutil.Process()
            all_cores = current_process.cpu_affinity()
            num_cores = len(all_cores)

            split_idx = int((1 - cpuadam_cores_perc) * num_cores)
            pt_cores = all_cores[:split_idx]
            cpuadam_cores = all_cores[split_idx:]

            # Set affinity for main process (PyTorch)
            current_process.cpu_affinity(pt_cores)

            # Set affinity for optimizer process (CPU Adam)
            optimizer_process = psutil.Process(self.cpuadam_process.pid)
            optimizer_process.cpu_affinity(cpuadam_cores)

            logger.debug(f"Set CPU affinity - PyTorch cores: {pt_cores}, "
                         f"Optimizer cores: {cpuadam_cores}")

        except (psutil.AccessDenied, psutil.NoSuchProcess, AttributeError) as e:
            logger.debug(f"Could not set CPU affinities for superoffload optimizer process: {e}")
        except Exception as e:
            logger.warning(f"Unexpected error setting CPU affinity: {e}")

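    # Example (illustrative numbers): with 72 visible cores and cpuadam_cores_perc=0.9,
    # split_idx = int(0.1 * 72) = 7, so the PyTorch process keeps cores 0-6 and the
    # CPU Adam worker process is pinned to cores 7-71.
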
    def async_step(self,
                   param_group_id: int,
                   sub_group_id: int,
                   fp32_param: torch.Tensor,
                   fp32_grad: torch.Tensor,
                   rollback: bool = False) -> None:
        """
        Queue a parameter for optimization in the worker process.
        """
        if not self.cpuadam_process.is_alive():
            raise RuntimeError("Worker process is not alive")

        self.param_queue.put({
            TaskKeys.PARAM_DATA: fp32_param,
            TaskKeys.PARAM_GRAD: fp32_grad,
            TaskKeys.PARAM_GROUP_ID: param_group_id,
            TaskKeys.SUB_GROUP_ID: sub_group_id,
            TaskKeys.ROLLBACK: rollback,
        })

    def get_result(self, expected_event_type: str = None) -> Optional[Dict[str, Any]]:
        """
        Get a result from the worker process with optional event type validation.

        Args:
            expected_event_type (str, optional): Expected event type ('adam_step' or 'rollback').
                If provided, validates that the result matches.
        """
        if self.result_queue.empty():
            return None

        result = self.result_queue.get()

        if "error" in result:
            raise RuntimeError(f"Error in worker process: {result['error']}")

        # Validate event type if expected_event_type is provided
        if expected_event_type is not None:
            result_event_type = result.get(ResultKeys.EVENT_TYPE)
            if result_event_type != expected_event_type:
                raise RuntimeError(f"Event type mismatch: expected '{expected_event_type}', got '{result_event_type}'")

        return result

    def close(self) -> None:
        """
        Shut down the worker process gracefully.

        Sends a termination signal to the worker and waits for a clean shutdown.
        If the process doesn't terminate within the timeout, it will be forcefully killed.
        """
        if not self.cpuadam_process.is_alive():
            logger.debug("Worker process already terminated")
            return

        # Send termination signal
        self.param_queue.put(None)

        # Wait for graceful shutdown
        self.cpuadam_process.join(timeout=5)

        if self.cpuadam_process.is_alive():
            logger.warning("Optimizer process did not terminate cleanly within timeout, "
                           "forcefully terminating")
            self.cpuadam_process.terminate()
            self.cpuadam_process.join(timeout=2)

            # Last resort: kill the process
            if self.cpuadam_process.is_alive():
                logger.error("Failed to terminate optimizer process, killing it")
                self.cpuadam_process.kill()
                self.cpuadam_process.join()

        logger.debug("SuperOffload CPU optimizer closed successfully")
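As an illustration of the worker protocol above, a minimal sketch that drives `SuperOffloadCPUOptimizer` on its own (tensor sizes and hyperparameters are illustrative, not from this PR):

```python
import torch
from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, ResultKeys

if __name__ == "__main__":
    # Illustrative flat fp32 partition and its gradient, both resident in CPU memory.
    fp32_param = torch.zeros(4096)
    fp32_grad = torch.randn(4096)

    cpu_opt = SuperOffloadCPUOptimizer(optimizer_config={
        "lr": 1e-4,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "weight_decay": 0.0,
        "amsgrad": False
    },
                                       cpuadam_cores_perc=0.9,
                                       max_grad_numel=fp32_grad.numel())

    # Queue one Adam step for (param_group 0, sub_group 0) and poll until the worker replies.
    cpu_opt.async_step(0, 0, fp32_param, fp32_grad)
    result = None
    while result is None:
        result = cpu_opt.get_result()
    fp32_param.copy_(result[ResultKeys.UPDATED_PARAM])

    cpu_opt.close()
```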
@@ -93,6 +93,12 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
     ratio: float = Field(1.0, ge=0.0, le=1.0)
     """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
 
+    super_offload: bool = False
+    """ Enable high performance CPU offloading for Superchips. Only valid with ZeRO Stage 3."""
+
+    cpuadam_cores_perc: float = Field(0.8, ge=0.0, le=1.0)
+    """ Percentage of CPU cores to use for CPU Adam. Only valid with ZeRO Stage 3 and super_offload=True."""
+
     @model_validator(mode="after")
     def set_pipeline(self):
         pipeline = self.pipeline_read or self.pipeline_write
@@ -178,6 +178,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                  zeropp_loco_param=None,
                  log_trace_cache_warnings=False,
                  enable_sanity_checks=False,
+                 cpuadam_cores_perc=0.8,
                  ):
         see_memory_usage("Stage 3 initialize beginning", force=True)
 
@@ -873,7 +874,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         sub_group_size = len(self.fp16_partitioned_groups_flat)
         # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n")
         for i in range(sub_group_size):
-            if i < int(self.partial_offload * sub_group_size):
+            if i >= int((1 - self.partial_offload) * sub_group_size):
                 self.subgroup_to_device[i] = 'cpu'
             else:
                 self.subgroup_to_device[i] = get_accelerator()._name