SuperOffload Release (#7559)

This PR introduces **SuperOffload**—an optimizer designed for Superchips
(NVIDIA GH200 & GB200, AMD MI300A) with high CPU–GPU bandwidth. It
enables **full fine-tuning** of **GPT-OSS-20B, Qwen3-14B, and Phi-4** on
a single GH200 GPU, achieving up to **~500 TFLOPS**, using Hugging Face
Transformers and DeepSpeed—no custom modeling code required.

SuperOffload extends ZeRO-Offload with fine-grained control and CPU Adam
rollback utilities, allowing GPU execution to overlap with the CPU Adam
optimizer step. This reduces GPU idle time and improves overall training
efficiency.

Key changes:
- New `SuperOffloadOptimizer_Stage3` optimizer.
- C++/CUDA binding for `adam_rollback` to revert one optimization step.
- Config additions, including `super_offload` and `cpuadam_cores_perc`.

A detailed blog and tutorial will be available soon.
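For reference, a minimal ZeRO Stage 3 configuration sketch that exercises the new options is shown below. Only `super_offload` and `cpuadam_cores_perc` are introduced by this PR; the surrounding keys and values are illustrative assumptions.

```python
# Illustrative sketch only: a ZeRO-3 config enabling SuperOffload.
# Only "super_offload" and "cpuadam_cores_perc" are new in this PR;
# all other keys/values are placeholder assumptions.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,
            "ratio": 1.0,               # fraction of optimizer states offloaded to CPU Adam
            "super_offload": True,      # enable the SuperOffload path (ZeRO Stage 3 only)
            "cpuadam_cores_perc": 0.9,  # share of CPU cores reserved for the CPU Adam worker
        },
        "sub_group_size": 100_000_000,
    },
}

# model, optimizer = ...  # e.g. a Hugging Face model and torch.optim.AdamW
# engine, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)
```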

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Authored by Xinyu Lian on 2025-09-24 08:09:23 -05:00; committed by GitHub.
Parent 17d80ce440 · commit af56ed4d37
11 changed files with 1078 additions and 2 deletions


@@ -284,6 +284,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820)
34. Stas Bekman, Samyam Rajbhandari, Michael Wyatt, Jeff Rasley, Tunji Ruwase, Zhewei Yao, Aurick Qiao, Yuxiong He. (2025) Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences [arXiv:2506.13996](https://arxiv.org/abs/2506.13996)
35. Tingfeng Lan, Yusen Wu, Bin Ma, Zhaoyuan Su, Rui Yang, Tekin Bicer, Masahiro Tanaka, Olatunji Ruwase, Dong Li, Yue Cheng. (2025) ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates [arXiv:2505.12242](https://arxiv.org/abs/2505.12242)
36. Xinyu Lian, Masahiro Tanaka, Olatunji Ruwase, Minjia Zhang. (2026) SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips [ASPLOS 2026](https://www.asplos-conference.org/asplos2026)
# Videos
1. DeepSpeed KDD 2020 Tutorial


@@ -8,6 +8,7 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}


@@ -236,6 +236,102 @@ int ds_adam_step(int optimizer_id,
return 0;
}
void adamw_rollback_inplace(float* params,
const float* grads,
float* momentum,
float* variance,
size_t param_size,
float learning_rate,
float beta1,
float beta2,
float eps,
float weight_decay,
int& step_count)
{
const float lr = learning_rate;
const float lambda = weight_decay;
const float beta1_pow = std::pow(beta1, step_count);
const float beta2_pow = std::pow(beta2, step_count);
const float one_minus_beta1 = 1.0f - beta1;
const float one_minus_beta2 = 1.0f - beta2;
const float lr_lambda = lr * lambda;
const float one_minus_lr_lambda = 1.0f - lr_lambda;
#pragma omp parallel for
for (size_t i = 0; i < param_size; ++i) {
const float bias_correction1 = 1.0f - beta1_pow;
const float bias_correction2 = 1.0f - beta2_pow;
const float m_hat = momentum[i] / bias_correction1;
const float v_hat = variance[i] / bias_correction2;
const float denominator = std::sqrt(v_hat) + eps;
// Rollback parameter update
const float update = lr * m_hat / denominator;
float new_param = (params[i] + update) / one_minus_lr_lambda;
// Handle numerical instability
if (!std::isfinite(new_param)) { new_param = 0.0f; }
params[i] = new_param;
const float grad = grads[i];
momentum[i] = (momentum[i] - one_minus_beta1 * grad) / beta1;
variance[i] = (variance[i] - one_minus_beta2 * grad * grad) / beta2;
}
--step_count;
}
int ds_adam_rollback(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq)
{
try {
// Validate tensor types - rollback currently only supports float32
if (params.scalar_type() != torch::kFloat32 || grads.scalar_type() != torch::kFloat32 ||
exp_avg.scalar_type() != torch::kFloat32 ||
exp_avg_sq.scalar_type() != torch::kFloat32) {
printf("Error: Adam rollback currently only supports float32 tensors\n");
return -1;
}
float* params_ptr = params.data_ptr<float>();
const float* grads_ptr = grads.data_ptr<float>();
float* momentum_ptr = exp_avg.data_ptr<float>();
float* variance_ptr = exp_avg_sq.data_ptr<float>();
const size_t param_size = params.numel();
int step_count = static_cast<int>(step);
adamw_rollback_inplace(params_ptr,
grads_ptr,
momentum_ptr,
variance_ptr,
param_size,
lr,
beta1,
beta2,
epsilon,
weight_decay,
step_count);
return 0;
} catch (const std::exception& e) {
printf("Error in Adam rollback for optimizer #%d: %s\n", optimizer_id, e.what());
return -1;
}
}
int destroy_adam_optimizer(int optimizer_id)
{
s_optimizers.erase(optimizer_id);
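The new `adamw_rollback_inplace` routine above inverts a single AdamW-style update: it re-applies the parameter update in reverse, undoes the decoupled weight decay, unwinds the first- and second-moment updates, and decrements the step count. A small PyTorch sketch of that arithmetic follows (not part of the PR; hyperparameters and tensor sizes are arbitrary, and exact recovery assumes the rollback sees the same gradient and step count as the forward step):

```python
# Illustrative sketch of the arithmetic that adamw_rollback_inplace undoes.
import torch

lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.01
p, g = torch.randn(8), torch.randn(8)
m, v = torch.zeros(8), torch.zeros(8)
p0, m0, v0 = p.clone(), m.clone(), v.clone()
step = 1

# Forward AdamW-style update implied by the rollback kernel.
m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g * g
m_hat = m / (1 - beta1 ** step)
v_hat = v / (1 - beta2 ** step)
p = p * (1 - lr * wd) - lr * m_hat / (v_hat.sqrt() + eps)

# Rollback, mirroring adamw_rollback_inplace: re-add the update, undo weight
# decay, then invert the moment updates and decrement the step count.
p = (p + lr * m_hat / (v_hat.sqrt() + eps)) / (1 - lr * wd)
m = (m - (1 - beta1) * g) / beta1
v = (v - (1 - beta2) * g * g) / beta2
step -= 1

assert torch.allclose(p, p0, atol=1e-6)
assert torch.allclose(m, m0) and torch.allclose(v, v0)
```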


@@ -217,4 +217,17 @@ int ds_adam_step(int optimizer_id,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq);
int ds_adam_rollback(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq);
int destroy_adam_optimizer(int optimizer_id);


@@ -164,3 +164,86 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'])
return loss
@torch.no_grad()
def step_subgroup(self, subgroup_id: int, closure=None):
"""Update the model parameters in a single subgroup (by index)."""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
# Intended device for step
device = torch.device('cpu')
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
"sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."
state = self.state[subgroup_id]
if len(state) == 0:
state['step'] = 0
state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
state['step'] += 1
beta1, beta2 = group['betas']
self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'])
return loss
@torch.no_grad()
def rollback_subgroup(self, sub_group_id: int, closure=None):
"""
Rollback the optimizer state for a specific subgroup.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
# Intended device for step
device = torch.device('cpu')
# Validate subgroup state exists and is initialized
if sub_group_id not in self.state or len(self.state[sub_group_id]) == 0:
raise RuntimeError(f"Cannot rollback optimizer state for sub_group_id {sub_group_id} "
f"as it has not been initialized.")
subgroup_state = self.state[sub_group_id]
# Check if we can rollback (step count must be > 0)
if subgroup_state.get('step', 0) <= 0:
raise RuntimeError(f"Cannot rollback sub_group_id {sub_group_id}: "
f"step count is {subgroup_state.get('step', 0)}")
for _, group in enumerate(self.param_groups):
for _, param in enumerate(group['params']):
if param.grad is None:
continue
assert param.device == device, (
f"CPUAdam param is on {param.device} and must be 'cpu', "
f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.")
# Decrement step count
subgroup_state['step'] -= 1
# Extract hyperparameters
beta1, beta2 = group['betas']
self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2,
group['eps'], group['weight_decay'], group['bias_correction'],
param.data, param.grad.data, subgroup_state['exp_avg'],
subgroup_state['exp_avg_sq'])
return loss


@@ -886,6 +886,12 @@ class DeepSpeedEngine(Module):
def zero_partial_offload(self):
return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)
def super_offload(self):
return getattr(self._config.zero_config.offload_optimizer, "super_offload", False)
def cpuadam_cores_perc(self):
return getattr(self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9)
def zero_sub_group_size(self):
return self._config.zero_config.sub_group_size
@@ -1826,7 +1832,10 @@ class DeepSpeedEngine(Module):
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
optimizer = DeepSpeedZeroOptimizer_Stage3(
from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3
Stage3ZeroOptimizer = DeepSpeedZeroOptimizer_Stage3 if not self.super_offload(
) else SuperOffloadOptimizer_Stage3
optimizer = Stage3ZeroOptimizer(
self.module,
optimizer,
timers=timers,
@@ -1864,6 +1873,7 @@ class DeepSpeedEngine(Module):
zeropp_loco_param=self.zeropp_loco_param(),
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
enable_sanity_checks=self.is_sanity_checks_enabled(),
cpuadam_cores_perc=self.cpuadam_cores_perc(),
)
else:


@@ -0,0 +1,412 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import sys
import time
import torch
from typing import List
from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, TaskKeys, ResultKeys, EventTypes
from deepspeed.runtime.zero.partition_parameters import Parameter, Tensor
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.utils import logger
from deepspeed.accelerator import get_accelerator
OPTIMIZER_STEP_TIMER = 'optimizer_step'
class SuperOffloadOptimizer_Stage3(DeepSpeedZeroOptimizer_Stage3):
def __init__(
self,
module,
init_optimizer,
timers,
ds_config,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True,
contiguous_gradients=True,
reduce_bucket_size=500000000,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
max_live_parameters=1000000000,
param_persistence_threshold=100000,
model_persistence_threshold=sys.maxsize,
dp_process_group=None,
reduce_scatter=True,
overlap_comm=False,
offload_optimizer_config=None,
offload_param_config=None,
sub_group_size=1000000000000,
offload_ratio=0.0,
mpu=None,
clip_grad=0.0,
gradient_accumulation_dtype=torch.float32,
communication_data_type=torch.float16,
postscale_gradients=True,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1,
elastic_checkpoint=False,
aio_config=None,
all2all_process_group=None,
zero_hpz_partition_size=1,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
zero_module_granularity_threshold=0,
zeropp_loco_param=None,
log_trace_cache_warnings=False,
enable_sanity_checks=False,
cpuadam_cores_perc=0.8,
):
self.sub_group_to_param_num = {}
self.params_in_ipg_bucket_buffer = []
self._cur_bucket_index = -1
self.async_cpuadam_num = 0
self.max_grad_numel = 0
super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
max_reuse_distance, max_live_parameters, param_persistence_threshold,
model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
offload_optimizer_config, offload_param_config, sub_group_size, offload_ratio, mpu, clip_grad,
gradient_accumulation_dtype, communication_data_type, postscale_gradients,
gradient_predivide_factor, gradient_accumulation_steps, elastic_checkpoint, aio_config,
all2all_process_group, zero_hpz_partition_size, zero_quantized_weights,
zero_quantized_nontrainable_weights, zero_module_granularity_threshold, zeropp_loco_param,
log_trace_cache_warnings, enable_sanity_checks)
optimizer_config = {
"lr": self.optimizer.param_groups[0]["lr"],
"betas": self.optimizer.param_groups[0]["betas"],
"eps": self.optimizer.param_groups[0]["eps"],
"weight_decay": self.optimizer.param_groups[0]["weight_decay"],
"amsgrad": self.optimizer.param_groups[0]["amsgrad"]
}
self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_config,
cpuadam_cores_perc=cpuadam_cores_perc,
max_grad_numel=self.max_grad_numel)
def _create_fp16_sub_groups(self, params_group):
params_group_numel = sum([param.partition_numel() for param in params_group])
sub_group_size = self.sub_group_size
if sub_group_size is None or sub_group_size >= params_group_numel:
return [params_group]
sub_groups = []
sub_group = []
local_sub_group_size = 0
for param in params_group:
sub_group.append(param)
local_sub_group_size += param.partition_numel()
if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]):
self.max_grad_numel = max(self.max_grad_numel, local_sub_group_size)
sub_groups.append(sub_group)
self.sub_group_to_param_num[len(sub_groups) - 1] = len(sub_group)
sub_group = []
local_sub_group_size = 0
return sub_groups
def _optimizer_step(self, sub_group_id):
param_group_id = self.sub_group_to_group_id[sub_group_id]
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
def step_with_gradscaler(optimizer):
if self.torch_autocast_gradscaler:
self.torch_autocast_gradscaler.step(optimizer)
self.torch_autocast_gradscaler.update()
else:
optimizer.step()
cur_device = self.subgroup_to_device[sub_group_id]
if cur_device != 'cpu':
self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param]
step_with_gradscaler(self.backup_optimizer)
self.backup_optimizer.param_groups[param_group_id]['params'] = []
def reduce_independent_p_g_buckets_and_remove_grads(self, param):
comm_dtype = self.get_param_comm_dtype(param)
bucket = self.ipg_buckets[comm_dtype]
i, _, _ = self.grad_position[self.get_param_id(param)]
if len(bucket.params) == 0:
self._cur_bucket_index = i
if getattr(param, "ds_grad_is_ready", True):
self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)
# If this is a single-parameter sub-group, reduce immediately
if self.sub_group_to_param_num[self._cur_bucket_index] == 1:
self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)
elif i != self._cur_bucket_index:
# Parameter belongs to different sub-group, buffer it
self.params_in_ipg_bucket_buffer.append(param)
else:
# Parameter belongs to current bucket
if getattr(param, "ds_grad_is_ready", True):
self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)
# Check if bucket is complete
if self.sub_group_to_param_num[self._cur_bucket_index] == len(bucket.params):
self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)
# Process buffered parameters
while self.params_in_ipg_bucket_buffer:
buffered_param = self.params_in_ipg_bucket_buffer.pop(0)
ci, _, _ = self.grad_position[self.get_param_id(buffered_param)]
self._cur_bucket_index = ci
if getattr(buffered_param, "ds_grad_is_ready", True):
self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(buffered_param)
@instrument_w_nvtx
def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
if self.subgroup_to_device[sub_group_id] == 'cpu':
self._unflatten_partitioned_parameters(sub_group_id)
return
if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
self.fp32_partitioned_groups_flat[sub_group_id].data)
self._unflatten_partitioned_parameters(sub_group_id)
else:
self._partitioned_params_swap_out(sub_group_id)
@instrument_w_nvtx
def _reassign_or_swap_out_partitioned_parameters_async(self, sub_group_id, updated_param):
"""Asynchronously update partitioned parameters with optimized values."""
self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(updated_param, non_blocking=True)
@instrument_w_nvtx
def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
# print("[DEBUG] partition_grads called")
buffers = []
device_buffers = {}
buffer_numel_min = {}
buffer_numel_max = {}
for param, grad_partition in zip(params_to_release, grad_partitions):
i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
if self.is_gradient_accumulation_boundary:
self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_partition)
buffer_numel = grad_partition.numel()
buffers.append(grad_partition)
if i not in device_buffers:
device_buffers[i] = []
device_buffers[i].append(grad_partition)
if i not in buffer_numel_min:
buffer_numel_min[i] = dest_offset
buffer_numel_max[i] = dest_offset + buffer_numel
else:
buffer_numel_min[i] = min(buffer_numel_min[i], dest_offset)
buffer_numel_max[i] = max(buffer_numel_max[i], dest_offset + buffer_numel)
if self.is_gradient_accumulation_boundary:
for i in buffer_numel_min.keys():
fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
0, buffer_numel_min[i], buffer_numel_max[i] - buffer_numel_min[i])
concatenated_buffer = torch.cat(device_buffers[i], dim=0).float()
if self.subgroup_to_device[i] == 'cpu':
# Trigger asynchronous CPU optimization
param_group_id = self.sub_group_to_group_id[i]
fp32_param = self.fp32_partitioned_groups_flat[i]
self.superoffload_cpu_optimizer.async_step(param_group_id, i, fp32_param.data,
concatenated_buffer.data)
self.async_cpuadam_num += 1
# Check for completed async operations
result = self.superoffload_cpu_optimizer.get_result()
if result is not None:
self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
result[ResultKeys.UPDATED_PARAM])
self.async_cpuadam_num -= 1
fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)
else:
fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)
# Clean up parameter gradients
for param in params_to_release:
if not get_accelerator().is_synchronized_device():
param.grad.record_stream(get_accelerator().current_stream())
param.grad = None
@instrument_w_nvtx
def step(self, closure=None):
"""
Not supporting closure.
"""
# Wait for any pending asynchronous CPU optimizer operations
self._wait_for_async_operations()
self._pre_step()
self._partition_all_parameters()
if self._overflow_check_and_loss_scale_update():
self._handle_overflow_rollback()
return
norm_groups = self._get_norm_groups()
scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
timer_names = set()
timer_names.add(OPTIMIZER_STEP_TIMER)
self.timers(OPTIMIZER_STEP_TIMER).start()
if self.check_clip_grads(scaled_global_grad_norm):
self._handle_gradient_clipping(scaled_global_grad_norm)
for sub_group_id, group in enumerate(self.fp16_groups):
# Prepare optimizer states, gradients and fp32 parameters for update
self._prepare_sub_group(sub_group_id, timer_names)
# Scale the fp32 gradients
self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
# Apply the optimizer step on the sub group and copy fp32 parameters to fp16
self._optimizer_step(sub_group_id)
# Put fp16 parameters in appropriate location
self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
# Release memory or swap out optimizer states of fp32 parameters
self._release_sub_group(sub_group_id, timer_names)
self.timers(OPTIMIZER_STEP_TIMER).stop()
self._post_step(timer_names)
def _wait_for_async_operations(self, timeout_seconds=60):
"""Wait for all pending asynchronous CPU optimizer operations to complete with timeout error.
Args:
timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
"""
if self.async_cpuadam_num > 0:
logger.info(f"[INFO] {self.async_cpuadam_num} asynchronous CPU optimizer operations pending...")
if self.async_cpuadam_num == 0:
return
start_time = time.time()
initial_pending_ops = self.async_cpuadam_num
while self.async_cpuadam_num > 0:
result = self.superoffload_cpu_optimizer.get_result()
if result is None:
current_time = time.time()
elapsed_time = current_time - start_time
# Throw error if we've been waiting longer than the timeout
if elapsed_time >= timeout_seconds:
raise RuntimeError(
f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
f"Still waiting for {self.async_cpuadam_num}/{initial_pending_ops} async operations to complete. "
f"This indicates a deadlock or critical performance issue in the CPU optimizer.")
time.sleep(0.001) # 1ms sleep
continue
self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
result[ResultKeys.UPDATED_PARAM])
self.async_cpuadam_num -= 1
def _wait_for_single_async_result(self, event_type: str, timeout_seconds=60):
"""Wait for a single asynchronous CPU-Adam optimizer operation with timeout.
Args:
event_type (str): Type of operation expected ('adam_step' or 'rollback').
timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
"""
start_time = time.time()
while True:
result = self.superoffload_cpu_optimizer.get_result(expected_event_type=event_type)
if result is not None:
self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
result[ResultKeys.UPDATED_PARAM])
break
current_time = time.time()
elapsed_time = current_time - start_time
# Throw error if we've been waiting longer than the timeout
if elapsed_time >= timeout_seconds:
raise RuntimeError(f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
f"This indicates a deadlock or critical performance issue in the CPU optimizer.")
time.sleep(0.001) # 1ms sleep
def _sync_cpu_optimizer_step(self,
param_group_id: int,
sub_group_id: int,
fp32_param_data,
fp32_grad_data,
rollback: bool = False,
timeout_seconds: int = 60):
event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
self.superoffload_cpu_optimizer.async_step(param_group_id,
sub_group_id,
fp32_param_data,
fp32_grad_data,
rollback=rollback)
# Wait for completion
self._wait_for_single_async_result(event_type, timeout_seconds)
def _handle_overflow_rollback(self):
"""Handle gradient overflow by rolling back CPU optimizer states."""
for sub_group_id, _ in enumerate(self.fp16_groups):
if self.subgroup_to_device[sub_group_id] == 'cpu':
param_group_id = self.sub_group_to_group_id[sub_group_id]
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
# Trigger rollback
self._sync_cpu_optimizer_step(param_group_id,
sub_group_id,
fp32_param.data,
fp32_param.grad.data,
rollback=True)
def _handle_gradient_clipping(self, scaled_global_grad_norm):
"""Handle gradient clipping with CPU optimizer rollback and re-optimization."""
for sub_group_id, _ in enumerate(self.fp16_groups):
if self.subgroup_to_device[sub_group_id] == 'cpu':
param_group_id = self.sub_group_to_group_id[sub_group_id]
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
# Rollback CPU optimizer states
self._sync_cpu_optimizer_step(param_group_id,
sub_group_id,
fp32_param.data,
fp32_param.grad.data,
rollback=True)
# Clip gradients and re-optimize
self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
self._sync_cpu_optimizer_step(param_group_id,
sub_group_id,
fp32_param.data,
fp32_param.grad.data,
rollback=False)
@instrument_w_nvtx
def check_clip_grads(self, total_norm):
"""Check if gradients need to be clipped based on the global norm."""
unscaled_norm = total_norm / self.loss_scale
return self.clip_grad and unscaled_norm > self.clip_grad


@@ -0,0 +1,273 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
SuperOffload utilities for running CPU optimizers in separate processes.
"""
from typing import Dict, Optional, Any
import torch
import torch.multiprocessing as mp
import psutil
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
class TaskKeys:
PARAM_DATA = "param_data"
PARAM_GRAD = "param_grad"
PARAM_GROUP_ID = "param_group_id"
SUB_GROUP_ID = "sub_group_id"
ROLLBACK = "rollback"
class ResultKeys:
UPDATED_PARAM = "updated_param"
EVENT_TYPE = "event_type"
class EventTypes:
ADAM_STEP = "adam_step"
ROLLBACK = "rollback"
def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp.SimpleQueue,
optimizer_config: Dict[str, Any], max_grad_numel: int) -> None:
"""
This function runs in a separate process and continuously processes optimization
tasks from the parameter queue. It creates a DeepSpeedCPUAdam optimizer and
applies optimization steps to parameters received from the main process.
Args:
param_queue: Queue for receiving optimization tasks
result_queue: Queue for sending back optimization results
optimizer_config: Configuration dictionary for the optimizer containing
lr, betas, eps, weight_decay, and amsgrad parameters
max_grad_numel: Maximum number of elements expected in gradient tensors
"""
# Initialize dummy parameter for optimizer creation
cpu_tensor = torch.randn(1, device="cpu")
cpu_param = torch.nn.Parameter(cpu_tensor)
try:
optimizer = DeepSpeedCPUAdam([cpu_param],
lr=optimizer_config["lr"],
betas=optimizer_config["betas"],
eps=optimizer_config["eps"],
weight_decay=optimizer_config["weight_decay"],
amsgrad=optimizer_config["amsgrad"])
except KeyError as e:
error_msg = f"Missing required optimizer config key: {e}"
logger.error(error_msg)
result_queue.put({"error": error_msg})
return
# Pre-allocate reusable pinned memory buffer for gradients
pinned_grad_buffer = torch.empty(max_grad_numel, dtype=torch.float32, device='cpu', pin_memory=True)
while True:
try:
task = param_queue.get()
if task is None:
logger.debug("Received termination signal, shutting down worker")
break
param_data = task[TaskKeys.PARAM_DATA]
param_grad = task[TaskKeys.PARAM_GRAD]
param_group_id = task[TaskKeys.PARAM_GROUP_ID]
sub_group_id = task[TaskKeys.SUB_GROUP_ID]
rollback = task.get(TaskKeys.ROLLBACK, False)
logger.debug(f"Processing param_group_id: {param_group_id}, sub_group_id: {sub_group_id}")
del task[TaskKeys.PARAM_DATA]
del task[TaskKeys.PARAM_GRAD]
task.clear()
grad_numel = param_grad.numel()
if grad_numel > max_grad_numel:
error_msg = (
f"Gradient size {grad_numel} exceeds pre-allocated buffer size {max_grad_numel}. "
f"This indicates insufficient buffer allocation. Please increase max_grad_numel parameter.")
result_queue.put({"error": error_msg})
break
param_grad_cpu = pinned_grad_buffer[:grad_numel].view_as(param_grad)
param_grad_cpu.copy_(param_grad, non_blocking=True)
fp32_param = torch.nn.Parameter(param_data)
fp32_param.grad = param_grad_cpu
optimizer.param_groups[param_group_id]['params'] = [fp32_param]
if rollback:
logger.debug(f"Rolling back optimizer state for sub_group_id: {sub_group_id}")
optimizer.rollback_subgroup(sub_group_id)
else:
optimizer.step_subgroup(sub_group_id)
# Send result back to main process
event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
result_queue.put({
TaskKeys.PARAM_GROUP_ID: param_group_id,
TaskKeys.SUB_GROUP_ID: sub_group_id,
ResultKeys.UPDATED_PARAM: fp32_param.data,
ResultKeys.EVENT_TYPE: event_type,
})
# Clean up references to free memory
optimizer.param_groups[param_group_id]['params'] = []
del param_grad_cpu, fp32_param.grad, fp32_param, param_grad, param_data
except KeyError as e:
error_msg = f"Missing required task key: {e}"
logger.error(error_msg)
result_queue.put({"error": error_msg})
break
except Exception as e:
error_msg = f"Unexpected error in worker process: {e}"
logger.error(error_msg)
result_queue.put({"error": error_msg})
break
# Clean up pinned memory buffer
if 'pinned_grad_buffer' in locals():
del pinned_grad_buffer
logger.debug("Cleaned up pinned memory buffer")
logger.debug("Worker process terminated")
class SuperOffloadCPUOptimizer:
def __init__(self,
optimizer_config: Dict[str, Any],
cpuadam_cores_perc: float = 0.8,
max_grad_numel: int = 1000000) -> None:
if not 0 < cpuadam_cores_perc <= 1:
raise ValueError("cpuadam_cores_perc must be between 0 and 1")
self.max_grad_numel = max_grad_numel
self.mp_context = mp.get_context('spawn')
self.param_queue = self.mp_context.SimpleQueue()
self.result_queue = self.mp_context.SimpleQueue()
self.cpuadam_process = self.mp_context.Process(
target=superoffload_optimizer_worker,
args=(self.param_queue, self.result_queue, optimizer_config, max_grad_numel),
daemon=True,
)
self.cpuadam_process.start()
# Set CPU affinity for better performance isolation
self._set_cpu_affinity(cpuadam_cores_perc)
def _set_cpu_affinity(self, cpuadam_cores_perc: float) -> None:
"""
Set CPU affinity for the main (Pytorch) process and worker (CPU Adam) process.
Args:
cpuadam_cores_perc: Percentage of cores to allocate to the worker (CPU Adam) process
"""
try:
current_process = psutil.Process()
all_cores = current_process.cpu_affinity()
num_cores = len(all_cores)
split_idx = int((1 - cpuadam_cores_perc) * num_cores)
pt_cores = all_cores[:split_idx]
cpuadam_cores = all_cores[split_idx:]
# Set affinity for main process (PyTorch)
current_process.cpu_affinity(pt_cores)
# Set affinity for optimizer process (CPU Adam)
optimizer_process = psutil.Process(self.cpuadam_process.pid)
optimizer_process.cpu_affinity(cpuadam_cores)
logger.debug(f"Set CPU affinity - PyTorch cores: {pt_cores}, "
f"Optimizer cores: {cpuadam_cores}")
except (psutil.AccessDenied, psutil.NoSuchProcess, AttributeError) as e:
logger.debug(f"Could not set CPU affinities for superoffload optimizer process: {e}")
except Exception as e:
logger.warning(f"Unexpected error setting CPU affinity: {e}")
def async_step(self,
param_group_id: int,
sub_group_id: int,
fp32_param: torch.Tensor,
fp32_grad: torch.Tensor,
rollback: bool = False) -> None:
"""
Queue parameter for optimization in the worker process.
"""
if not self.cpuadam_process.is_alive():
raise RuntimeError("Worker process is not alive")
self.param_queue.put({
TaskKeys.PARAM_DATA: fp32_param,
TaskKeys.PARAM_GRAD: fp32_grad,
TaskKeys.PARAM_GROUP_ID: param_group_id,
TaskKeys.SUB_GROUP_ID: sub_group_id,
TaskKeys.ROLLBACK: rollback,
})
def get_result(self, expected_event_type: str = None) -> Optional[Dict[str, Any]]:
"""
Get result from worker process with optional event type validation.
Args:
expected_event_type (str, optional): Expected event type ('adam_step' or 'rollback').
If provided, validates that the result matches.
"""
if self.result_queue.empty():
return None
result = self.result_queue.get()
if "error" in result:
raise RuntimeError(f"Error in worker process: {result['error']}")
# Validate event type if expected_event_type is provided
if expected_event_type is not None:
result_event_type = result.get(ResultKeys.EVENT_TYPE)
if result_event_type != expected_event_type:
raise RuntimeError(f"Event type mismatch: expected '{expected_event_type}', got '{result_event_type}'")
return result
def close(self) -> None:
"""
Shutdown the worker process gracefully.
Sends termination signal to worker and waits for clean shutdown.
If the process doesn't terminate within the timeout, it will be forcefully killed.
"""
if not self.cpuadam_process.is_alive():
logger.debug("Worker process already terminated")
return
# Send termination signal
self.param_queue.put(None)
# Wait for graceful shutdown
self.cpuadam_process.join(timeout=5)
if self.cpuadam_process.is_alive():
logger.warning("Optimizer process did not terminate cleanly within timeout, "
"forcefully terminating")
self.cpuadam_process.terminate()
self.cpuadam_process.join(timeout=2)
# Last resort: kill the process
if self.cpuadam_process.is_alive():
logger.error("Failed to terminate optimizer process, killing it")
self.cpuadam_process.kill()
self.cpuadam_process.join()
logger.debug("SuperOffload CPU optimizer closed successfully")


@@ -93,6 +93,12 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
ratio: float = Field(1.0, ge=0.0, le=1.0)
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
super_offload: bool = False
""" Enable high performance CPU offloading for Superchips. Only valid with ZeRO Stage 3."""
cpuadam_cores_perc: float = Field(0.8, ge=0.0, le=1.0)
""" Percentage of CPU cores to use for CPU Adam. Only valid with ZeRO Stage 3 and super_offload=True."""
@model_validator(mode="after")
def set_pipeline(self):
pipeline = self.pipeline_read or self.pipeline_write


@@ -178,6 +178,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
zeropp_loco_param=None,
log_trace_cache_warnings=False,
enable_sanity_checks=False,
cpuadam_cores_perc=0.8,
):
see_memory_usage("Stage 3 initialize beginning", force=True)
@@ -873,7 +874,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
sub_group_size = len(self.fp16_partitioned_groups_flat)
# print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n")
for i in range(sub_group_size):
if i < int(self.partial_offload * sub_group_size):
if i >= int((1 - self.partial_offload) * sub_group_size):
self.subgroup_to_device[i] = 'cpu'
else:
self.subgroup_to_device[i] = get_accelerator()._name
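The revised condition places the CPU-resident sub-groups at the tail of the list instead of the head. A small sketch of the resulting mapping (the sub-group count, ratio, and the `'cuda'` device name are illustrative assumptions):

```python
# Illustrative mapping: which sub-groups land on CPU for a given offload ratio.
partial_offload = 0.75   # offload_optimizer.ratio
sub_group_count = 8      # number of partitioned fp16 sub-groups

subgroup_to_device = {
    i: "cpu" if i >= int((1 - partial_offload) * sub_group_count) else "cuda"
    for i in range(sub_group_count)
}
# Sub-groups 0-1 stay on the accelerator, sub-groups 2-7 go to CPU,
# i.e. 6 of 8 (75%) of the optimizer partitions are offloaded.
```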


@@ -132,3 +132,183 @@ class TestCPUAdamGPUError(DistributedTest):
param.grad = torch.randn(model_size, device=device)
with pytest.raises(AssertionError):
optimizer.step()
class TestCPUAdamSubgroup(DistributedTest):
world_size = 1
reuse_dist_env = True
requires_cuda_env = False
if not get_accelerator().is_available():
init_distributed = False
set_dist_env = False
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
@pytest.mark.parametrize('model_size', [64, 128, 1024])
def test_step_subgroup_basic(self, dtype, model_size):
"""Test basic functionality of step_subgroup method."""
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
from deepspeed.ops.adam import DeepSpeedCPUAdam
# Create parameters
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
param = torch.nn.Parameter(cpu_data)
optimizer = DeepSpeedCPUAdam([param])
# Set gradient
param.grad = torch.randn(model_size, device='cpu').to(dtype)
# Store initial parameter values
initial_param = param.data.clone()
# Test step_subgroup with subgroup_id=0
subgroup_id = 0
optimizer.step_subgroup(subgroup_id)
# Verify parameter was updated
assert not torch.equal(param.data, initial_param), "Parameters should be updated after step_subgroup"
# Verify optimizer state was created for subgroup
assert subgroup_id in optimizer.state, "Optimizer state should be created for subgroup"
assert optimizer.state[subgroup_id]['step'] == 1, "Step count should be 1"
assert 'exp_avg' in optimizer.state[subgroup_id], "exp_avg should be in state"
assert 'exp_avg_sq' in optimizer.state[subgroup_id], "exp_avg_sq should be in state"
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
def test_step_subgroup_multiple_calls(self, dtype):
"""Test multiple calls to step_subgroup increment step count correctly."""
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
from deepspeed.ops.adam import DeepSpeedCPUAdam
model_size = 64
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
param = torch.nn.Parameter(cpu_data)
optimizer = DeepSpeedCPUAdam([param])
subgroup_id = 0
# Perform multiple steps
for step in range(1, 4):
param.grad = torch.randn(model_size, device='cpu').to(dtype)
optimizer.step_subgroup(subgroup_id)
# Verify step count increments
assert optimizer.state[subgroup_id]['step'] == step, f"Step count should be {step}"
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
def test_rollback_subgroup_basic(self, dtype):
"""Test basic functionality of rollback_subgroup method."""
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
from deepspeed.ops.adam import DeepSpeedCPUAdam
model_size = 64
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
param = torch.nn.Parameter(cpu_data)
optimizer = DeepSpeedCPUAdam([param])
subgroup_id = 0
param.grad = torch.randn(model_size, device='cpu').to(dtype)
# First, perform a step to initialize state
optimizer.step_subgroup(subgroup_id)
assert optimizer.state[subgroup_id]['step'] == 1
# Store parameter state after step
param_after_step = param.data.clone()
exp_avg_after_step = optimizer.state[subgroup_id]['exp_avg'].clone()
exp_avg_sq_after_step = optimizer.state[subgroup_id]['exp_avg_sq'].clone()
# Now rollback
optimizer.rollback_subgroup(subgroup_id)
# Verify step count decremented
assert optimizer.state[subgroup_id]['step'] == 0, "Step count should be decremented after rollback"
def test_rollback_subgroup_uninitialized_error(self):
"""Test that rollback_subgroup raises error for uninitialized subgroup."""
from deepspeed.ops.adam import DeepSpeedCPUAdam
model_size = 64
param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
optimizer = DeepSpeedCPUAdam([param])
# Try to rollback uninitialized subgroup
with pytest.raises(RuntimeError, match="Cannot rollback optimizer state for sub_group_id 0"):
optimizer.rollback_subgroup(0)
def test_rollback_subgroup_zero_step_error(self):
"""Test that rollback_subgroup raises error when step count is already 0."""
from deepspeed.ops.adam import DeepSpeedCPUAdam
model_size = 64
param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
optimizer = DeepSpeedCPUAdam([param])
subgroup_id = 0
param.grad = torch.randn(model_size, device='cpu')
# Initialize state by doing one step
optimizer.step_subgroup(subgroup_id)
# Rollback once (step should become 0)
optimizer.rollback_subgroup(subgroup_id)
assert optimizer.state[subgroup_id]['step'] == 0
# Try to rollback again - should raise error
with pytest.raises(RuntimeError, match="Cannot rollback sub_group_id 0: step count is 0"):
optimizer.rollback_subgroup(subgroup_id)
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
def test_step_rollback_sequence(self, dtype):
"""Test sequence of step_subgroup and rollback_subgroup operations."""
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
from deepspeed.ops.adam import DeepSpeedCPUAdam
model_size = 64
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
param = torch.nn.Parameter(cpu_data)
optimizer = DeepSpeedCPUAdam([param])
subgroup_id = 0
param.grad = torch.randn(model_size, device='cpu').to(dtype)
# Perform multiple steps
for step in range(1, 4):
optimizer.step_subgroup(subgroup_id)
assert optimizer.state[subgroup_id]['step'] == step
# Rollback steps one by one
for step in range(2, -1, -1):
optimizer.rollback_subgroup(subgroup_id)
assert optimizer.state[subgroup_id]['step'] == step
def test_multiple_subgroups(self):
"""Test that different subgroups maintain independent state."""
from deepspeed.ops.adam import DeepSpeedCPUAdam
model_size = 64
param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
optimizer = DeepSpeedCPUAdam([param])
param.grad = torch.randn(model_size, device='cpu')
# Step different subgroups
optimizer.step_subgroup(0)
optimizer.step_subgroup(1)
optimizer.step_subgroup(0) # Step subgroup 0 again
# Verify independent step counts
assert optimizer.state[0]['step'] == 2, "Subgroup 0 should have step count 2"
assert optimizer.state[1]['step'] == 1, "Subgroup 1 should have step count 1"
# Rollback subgroup 0 only
optimizer.rollback_subgroup(0)
assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"