mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
SuperOffload Release (#7559)
This PR introduces **SuperOffload**—an optimizer designed for Superchips (Nvidia GH200 & GB200, AMD MI300A) with high CPU–GPU bandwidth. It enables **full fine-tuning** of **GPT-OSS-20B, Qwen3-14B, and Phi-4** on a single GH200 GPU, achieving up to **~500 TFLOPS**, using Hugging Face Transformers and DeepSpeed—no custom modeling code required. SuperOffload extends ZeRO-Offload with fine-grained control and CPUAdam rollback utilities, allowing GPU execution to overlap with CPUAdam. This reduces GPU idle time and improves overall efficiency. Key changes: - New SuperOffloadOptimizer_Stage3 optimizer. - C++/CUDA binding for adam_rollback to revert one optimization step. - Config additions including super_offload and cpuadam_cores_perc. A detailed blog and tutorial will be available soon. --------- Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
@ -284,6 +284,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
|
|||||||
33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820)
|
33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820)
|
||||||
34. Stas Bekman, Samyam Rajbhandari, Michael Wyatt, Jeff Rasley, Tunji Ruwase, Zhewei Yao, Aurick Qiao, Yuxiong He. (2025) Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences [arXiv:2506.13996](https://arxiv.org/abs/2506.13996)
|
34. Stas Bekman, Samyam Rajbhandari, Michael Wyatt, Jeff Rasley, Tunji Ruwase, Zhewei Yao, Aurick Qiao, Yuxiong He. (2025) Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences [arXiv:2506.13996](https://arxiv.org/abs/2506.13996)
|
||||||
35. Tingfeng Lan, Yusen Wu, Bin Ma, Zhaoyuan Su, Rui Yang, Tekin Bicer, Masahiro Tanaka, Olatunji Ruwase, Dong Li, Yue Cheng. (2025) ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates [arXiv:2505.12242](https://arxiv.org/abs/2505.12242)
|
35. Tingfeng Lan, Yusen Wu, Bin Ma, Zhaoyuan Su, Rui Yang, Tekin Bicer, Masahiro Tanaka, Olatunji Ruwase, Dong Li, Yue Cheng. (2025) ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates [arXiv:2505.12242](https://arxiv.org/abs/2505.12242)
|
||||||
|
36. Xinyu Lian, Masahiro Tanaka, Olatunji Ruwase, Minjia Zhang. (2026) SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips [ASPLOS 2026](https://www.asplos-conference.org/asplos2026)
|
||||||
|
|
||||||
# Videos
|
# Videos
|
||||||
1. DeepSpeed KDD 2020 Tutorial
|
1. DeepSpeed KDD 2020 Tutorial
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||||
{
|
{
|
||||||
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
|
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
|
||||||
|
m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)");
|
||||||
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
|
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
|
||||||
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
|
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
|
||||||
}
|
}
|
||||||
|
@ -236,6 +236,102 @@ int ds_adam_step(int optimizer_id,
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void adamw_rollback_inplace(float* params,
|
||||||
|
const float* grads,
|
||||||
|
float* momentum,
|
||||||
|
float* variance,
|
||||||
|
size_t param_size,
|
||||||
|
float learning_rate,
|
||||||
|
float beta1,
|
||||||
|
float beta2,
|
||||||
|
float eps,
|
||||||
|
float weight_decay,
|
||||||
|
int& step_count)
|
||||||
|
{
|
||||||
|
const float lr = learning_rate;
|
||||||
|
const float lambda = weight_decay;
|
||||||
|
const float beta1_pow = std::pow(beta1, step_count);
|
||||||
|
const float beta2_pow = std::pow(beta2, step_count);
|
||||||
|
const float one_minus_beta1 = 1.0f - beta1;
|
||||||
|
const float one_minus_beta2 = 1.0f - beta2;
|
||||||
|
const float lr_lambda = lr * lambda;
|
||||||
|
const float one_minus_lr_lambda = 1.0f - lr_lambda;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (size_t i = 0; i < param_size; ++i) {
|
||||||
|
const float bias_correction1 = 1.0f - beta1_pow;
|
||||||
|
const float bias_correction2 = 1.0f - beta2_pow;
|
||||||
|
|
||||||
|
const float m_hat = momentum[i] / bias_correction1;
|
||||||
|
const float v_hat = variance[i] / bias_correction2;
|
||||||
|
|
||||||
|
const float denominator = std::sqrt(v_hat) + eps;
|
||||||
|
|
||||||
|
// Rollback parameter update
|
||||||
|
const float update = lr * m_hat / denominator;
|
||||||
|
float new_param = (params[i] + update) / one_minus_lr_lambda;
|
||||||
|
|
||||||
|
// Handle numerical instability
|
||||||
|
if (!std::isfinite(new_param)) { new_param = 0.0f; }
|
||||||
|
|
||||||
|
params[i] = new_param;
|
||||||
|
|
||||||
|
const float grad = grads[i];
|
||||||
|
momentum[i] = (momentum[i] - one_minus_beta1 * grad) / beta1;
|
||||||
|
variance[i] = (variance[i] - one_minus_beta2 * grad * grad) / beta2;
|
||||||
|
}
|
||||||
|
|
||||||
|
--step_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ds_adam_rollback(int optimizer_id,
|
||||||
|
size_t step,
|
||||||
|
float lr,
|
||||||
|
float beta1,
|
||||||
|
float beta2,
|
||||||
|
float epsilon,
|
||||||
|
float weight_decay,
|
||||||
|
bool bias_correction,
|
||||||
|
torch::Tensor& params,
|
||||||
|
torch::Tensor& grads,
|
||||||
|
torch::Tensor& exp_avg,
|
||||||
|
torch::Tensor& exp_avg_sq)
|
||||||
|
{
|
||||||
|
try {
|
||||||
|
// Validate tensor types - rollback currently only supports float32
|
||||||
|
if (params.scalar_type() != torch::kFloat32 || grads.scalar_type() != torch::kFloat32 ||
|
||||||
|
exp_avg.scalar_type() != torch::kFloat32 ||
|
||||||
|
exp_avg_sq.scalar_type() != torch::kFloat32) {
|
||||||
|
printf("Error: Adam rollback currently only supports float32 tensors\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
float* params_ptr = params.data_ptr<float>();
|
||||||
|
const float* grads_ptr = grads.data_ptr<float>();
|
||||||
|
float* momentum_ptr = exp_avg.data_ptr<float>();
|
||||||
|
float* variance_ptr = exp_avg_sq.data_ptr<float>();
|
||||||
|
const size_t param_size = params.numel();
|
||||||
|
int step_count = static_cast<int>(step);
|
||||||
|
|
||||||
|
adamw_rollback_inplace(params_ptr,
|
||||||
|
grads_ptr,
|
||||||
|
momentum_ptr,
|
||||||
|
variance_ptr,
|
||||||
|
param_size,
|
||||||
|
lr,
|
||||||
|
beta1,
|
||||||
|
beta2,
|
||||||
|
epsilon,
|
||||||
|
weight_decay,
|
||||||
|
step_count);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
printf("Error in Adam rollback for optimizer #%d: %s\n", optimizer_id, e.what());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int destroy_adam_optimizer(int optimizer_id)
|
int destroy_adam_optimizer(int optimizer_id)
|
||||||
{
|
{
|
||||||
s_optimizers.erase(optimizer_id);
|
s_optimizers.erase(optimizer_id);
|
||||||
|
@ -217,4 +217,17 @@ int ds_adam_step(int optimizer_id,
|
|||||||
torch::Tensor& exp_avg,
|
torch::Tensor& exp_avg,
|
||||||
torch::Tensor& exp_avg_sq);
|
torch::Tensor& exp_avg_sq);
|
||||||
|
|
||||||
|
int ds_adam_rollback(int optimizer_id,
|
||||||
|
size_t step,
|
||||||
|
float lr,
|
||||||
|
float beta1,
|
||||||
|
float beta2,
|
||||||
|
float epsilon,
|
||||||
|
float weight_decay,
|
||||||
|
bool bias_correction,
|
||||||
|
torch::Tensor& params,
|
||||||
|
torch::Tensor& grads,
|
||||||
|
torch::Tensor& exp_avg,
|
||||||
|
torch::Tensor& exp_avg_sq);
|
||||||
|
|
||||||
int destroy_adam_optimizer(int optimizer_id);
|
int destroy_adam_optimizer(int optimizer_id);
|
||||||
|
@ -164,3 +164,86 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
|
|||||||
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||||
state['exp_avg'], state['exp_avg_sq'])
|
state['exp_avg'], state['exp_avg_sq'])
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step_subgroup(self, subgroup_id: int, closure=None):
|
||||||
|
"""Update the model parameters in a single subgroup (by index)."""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
# Intended device for step
|
||||||
|
device = torch.device('cpu')
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group['params']:
|
||||||
|
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
|
||||||
|
"sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."
|
||||||
|
|
||||||
|
state = self.state[subgroup_id]
|
||||||
|
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
|
||||||
|
state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
|
||||||
|
|
||||||
|
state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
|
||||||
|
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||||
|
state['exp_avg'], state['exp_avg_sq'])
|
||||||
|
return loss
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def rollback_subgroup(self, sub_group_id: int, closure=None):
|
||||||
|
"""
|
||||||
|
Rollback the optimizer state for a specific subgroup.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
# Intended device for step
|
||||||
|
device = torch.device('cpu')
|
||||||
|
|
||||||
|
# Validate subgroup state exists and is initialized
|
||||||
|
if sub_group_id not in self.state or len(self.state[sub_group_id]) == 0:
|
||||||
|
raise RuntimeError(f"Cannot rollback optimizer state for sub_group_id {sub_group_id} "
|
||||||
|
f"as it has not been initialized.")
|
||||||
|
|
||||||
|
subgroup_state = self.state[sub_group_id]
|
||||||
|
|
||||||
|
# Check if we can rollback (step count must be > 0)
|
||||||
|
if subgroup_state.get('step', 0) <= 0:
|
||||||
|
raise RuntimeError(f"Cannot rollback sub_group_id {sub_group_id}: "
|
||||||
|
f"step count is {subgroup_state.get('step', 0)}")
|
||||||
|
|
||||||
|
for _, group in enumerate(self.param_groups):
|
||||||
|
for _, param in enumerate(group['params']):
|
||||||
|
if param.grad is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert param.device == device, (
|
||||||
|
f"CPUAdam param is on {param.device} and must be 'cpu', "
|
||||||
|
f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.")
|
||||||
|
|
||||||
|
# Decrement step count
|
||||||
|
subgroup_state['step'] -= 1
|
||||||
|
|
||||||
|
# Extract hyperparameters
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
|
||||||
|
self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2,
|
||||||
|
group['eps'], group['weight_decay'], group['bias_correction'],
|
||||||
|
param.data, param.grad.data, subgroup_state['exp_avg'],
|
||||||
|
subgroup_state['exp_avg_sq'])
|
||||||
|
return loss
|
||||||
|
@ -886,6 +886,12 @@ class DeepSpeedEngine(Module):
|
|||||||
def zero_partial_offload(self):
|
def zero_partial_offload(self):
|
||||||
return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)
|
return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)
|
||||||
|
|
||||||
|
def super_offload(self):
|
||||||
|
return getattr(self._config.zero_config.offload_optimizer, "super_offload", False)
|
||||||
|
|
||||||
|
def cpuadam_cores_perc(self):
|
||||||
|
return getattr(self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9)
|
||||||
|
|
||||||
def zero_sub_group_size(self):
|
def zero_sub_group_size(self):
|
||||||
return self._config.zero_config.sub_group_size
|
return self._config.zero_config.sub_group_size
|
||||||
|
|
||||||
@ -1826,7 +1832,10 @@ class DeepSpeedEngine(Module):
|
|||||||
|
|
||||||
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
|
log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
|
||||||
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
|
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
|
||||||
optimizer = DeepSpeedZeroOptimizer_Stage3(
|
from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3
|
||||||
|
Stage3ZeroOptimizer = DeepSpeedZeroOptimizer_Stage3 if not self.super_offload(
|
||||||
|
) else SuperOffloadOptimizer_Stage3
|
||||||
|
optimizer = Stage3ZeroOptimizer(
|
||||||
self.module,
|
self.module,
|
||||||
optimizer,
|
optimizer,
|
||||||
timers=timers,
|
timers=timers,
|
||||||
@ -1864,6 +1873,7 @@ class DeepSpeedEngine(Module):
|
|||||||
zeropp_loco_param=self.zeropp_loco_param(),
|
zeropp_loco_param=self.zeropp_loco_param(),
|
||||||
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
|
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
|
||||||
enable_sanity_checks=self.is_sanity_checks_enabled(),
|
enable_sanity_checks=self.is_sanity_checks_enabled(),
|
||||||
|
cpuadam_cores_perc=self.cpuadam_cores_perc(),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
412
deepspeed/runtime/superoffload/superoffload_stage3.py
Normal file
412
deepspeed/runtime/superoffload/superoffload_stage3.py
Normal file
@ -0,0 +1,412 @@
|
|||||||
|
# Copyright (c) DeepSpeed Team.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# DeepSpeed Team
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, TaskKeys, ResultKeys, EventTypes
|
||||||
|
from deepspeed.runtime.zero.partition_parameters import Parameter, Tensor
|
||||||
|
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
|
||||||
|
from deepspeed.utils.nvtx import instrument_w_nvtx
|
||||||
|
from deepspeed.utils import logger
|
||||||
|
from deepspeed.accelerator import get_accelerator
|
||||||
|
|
||||||
|
OPTIMIZER_STEP_TIMER = 'optimizer_step'
|
||||||
|
|
||||||
|
|
||||||
|
class SuperOffloadOptimizer_Stage3(DeepSpeedZeroOptimizer_Stage3):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module,
|
||||||
|
init_optimizer,
|
||||||
|
timers,
|
||||||
|
ds_config,
|
||||||
|
static_loss_scale=1.0,
|
||||||
|
dynamic_loss_scale=False,
|
||||||
|
dynamic_loss_args=None,
|
||||||
|
verbose=True,
|
||||||
|
contiguous_gradients=True,
|
||||||
|
reduce_bucket_size=500000000,
|
||||||
|
prefetch_bucket_size=50000000,
|
||||||
|
max_reuse_distance=1000000000,
|
||||||
|
max_live_parameters=1000000000,
|
||||||
|
param_persistence_threshold=100000,
|
||||||
|
model_persistence_threshold=sys.maxsize,
|
||||||
|
dp_process_group=None,
|
||||||
|
reduce_scatter=True,
|
||||||
|
overlap_comm=False,
|
||||||
|
offload_optimizer_config=None,
|
||||||
|
offload_param_config=None,
|
||||||
|
sub_group_size=1000000000000,
|
||||||
|
offload_ratio=0.0,
|
||||||
|
mpu=None,
|
||||||
|
clip_grad=0.0,
|
||||||
|
gradient_accumulation_dtype=torch.float32,
|
||||||
|
communication_data_type=torch.float16,
|
||||||
|
postscale_gradients=True,
|
||||||
|
gradient_predivide_factor=1.0,
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
elastic_checkpoint=False,
|
||||||
|
aio_config=None,
|
||||||
|
all2all_process_group=None,
|
||||||
|
zero_hpz_partition_size=1,
|
||||||
|
zero_quantized_weights=False,
|
||||||
|
zero_quantized_nontrainable_weights=False,
|
||||||
|
zero_module_granularity_threshold=0,
|
||||||
|
zeropp_loco_param=None,
|
||||||
|
log_trace_cache_warnings=False,
|
||||||
|
enable_sanity_checks=False,
|
||||||
|
cpuadam_cores_perc=0.8,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.sub_group_to_param_num = {}
|
||||||
|
self.params_in_ipg_bucket_buffer = []
|
||||||
|
self._cur_bucket_index = -1
|
||||||
|
self.async_cpuadam_num = 0
|
||||||
|
self.max_grad_numel = 0
|
||||||
|
|
||||||
|
super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
|
||||||
|
dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
|
||||||
|
max_reuse_distance, max_live_parameters, param_persistence_threshold,
|
||||||
|
model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
|
||||||
|
offload_optimizer_config, offload_param_config, sub_group_size, offload_ratio, mpu, clip_grad,
|
||||||
|
gradient_accumulation_dtype, communication_data_type, postscale_gradients,
|
||||||
|
gradient_predivide_factor, gradient_accumulation_steps, elastic_checkpoint, aio_config,
|
||||||
|
all2all_process_group, zero_hpz_partition_size, zero_quantized_weights,
|
||||||
|
zero_quantized_nontrainable_weights, zero_module_granularity_threshold, zeropp_loco_param,
|
||||||
|
log_trace_cache_warnings, enable_sanity_checks)
|
||||||
|
|
||||||
|
optimizer_config = {
|
||||||
|
"lr": self.optimizer.param_groups[0]["lr"],
|
||||||
|
"betas": self.optimizer.param_groups[0]["betas"],
|
||||||
|
"eps": self.optimizer.param_groups[0]["eps"],
|
||||||
|
"weight_decay": self.optimizer.param_groups[0]["weight_decay"],
|
||||||
|
"amsgrad": self.optimizer.param_groups[0]["amsgrad"]
|
||||||
|
}
|
||||||
|
self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_config,
|
||||||
|
cpuadam_cores_perc=cpuadam_cores_perc,
|
||||||
|
max_grad_numel=self.max_grad_numel)
|
||||||
|
|
||||||
|
def _create_fp16_sub_groups(self, params_group):
|
||||||
|
|
||||||
|
params_group_numel = sum([param.partition_numel() for param in params_group])
|
||||||
|
sub_group_size = self.sub_group_size
|
||||||
|
|
||||||
|
if sub_group_size is None or sub_group_size >= params_group_numel:
|
||||||
|
return [params_group]
|
||||||
|
|
||||||
|
sub_groups = []
|
||||||
|
sub_group = []
|
||||||
|
local_sub_group_size = 0
|
||||||
|
|
||||||
|
for param in params_group:
|
||||||
|
sub_group.append(param)
|
||||||
|
local_sub_group_size += param.partition_numel()
|
||||||
|
|
||||||
|
if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]):
|
||||||
|
self.max_grad_numel = max(self.max_grad_numel, local_sub_group_size)
|
||||||
|
sub_groups.append(sub_group)
|
||||||
|
self.sub_group_to_param_num[len(sub_groups) - 1] = len(sub_group)
|
||||||
|
|
||||||
|
sub_group = []
|
||||||
|
local_sub_group_size = 0
|
||||||
|
|
||||||
|
return sub_groups
|
||||||
|
|
||||||
|
def _optimizer_step(self, sub_group_id):
|
||||||
|
param_group_id = self.sub_group_to_group_id[sub_group_id]
|
||||||
|
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
|
||||||
|
|
||||||
|
def step_with_gradscaler(optimizer):
|
||||||
|
if self.torch_autocast_gradscaler:
|
||||||
|
self.torch_autocast_gradscaler.step(optimizer)
|
||||||
|
self.torch_autocast_gradscaler.update()
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
cur_device = self.subgroup_to_device[sub_group_id]
|
||||||
|
if cur_device != 'cpu':
|
||||||
|
self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param]
|
||||||
|
step_with_gradscaler(self.backup_optimizer)
|
||||||
|
self.backup_optimizer.param_groups[param_group_id]['params'] = []
|
||||||
|
|
||||||
|
def reduce_independent_p_g_buckets_and_remove_grads(self, param):
|
||||||
|
comm_dtype = self.get_param_comm_dtype(param)
|
||||||
|
bucket = self.ipg_buckets[comm_dtype]
|
||||||
|
i, _, _ = self.grad_position[self.get_param_id(param)]
|
||||||
|
|
||||||
|
if len(bucket.params) == 0:
|
||||||
|
self._cur_bucket_index = i
|
||||||
|
if getattr(param, "ds_grad_is_ready", True):
|
||||||
|
self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)
|
||||||
|
|
||||||
|
# If this is a single-parameter sub-group, reduce immediately
|
||||||
|
if self.sub_group_to_param_num[self._cur_bucket_index] == 1:
|
||||||
|
self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)
|
||||||
|
|
||||||
|
elif i != self._cur_bucket_index:
|
||||||
|
# Parameter belongs to different sub-group, buffer it
|
||||||
|
self.params_in_ipg_bucket_buffer.append(param)
|
||||||
|
else:
|
||||||
|
# Parameter belongs to current bucket
|
||||||
|
if getattr(param, "ds_grad_is_ready", True):
|
||||||
|
self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)
|
||||||
|
|
||||||
|
# Check if bucket is complete
|
||||||
|
if self.sub_group_to_param_num[self._cur_bucket_index] == len(bucket.params):
|
||||||
|
self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)
|
||||||
|
|
||||||
|
# Process buffered parameters
|
||||||
|
while self.params_in_ipg_bucket_buffer:
|
||||||
|
buffered_param = self.params_in_ipg_bucket_buffer.pop(0)
|
||||||
|
ci, _, _ = self.grad_position[self.get_param_id(buffered_param)]
|
||||||
|
self._cur_bucket_index = ci
|
||||||
|
if getattr(buffered_param, "ds_grad_is_ready", True):
|
||||||
|
self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(buffered_param)
|
||||||
|
|
||||||
|
@instrument_w_nvtx
|
||||||
|
def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
|
||||||
|
if self.subgroup_to_device[sub_group_id] == 'cpu':
|
||||||
|
self._unflatten_partitioned_parameters(sub_group_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
|
||||||
|
self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
|
||||||
|
self.fp32_partitioned_groups_flat[sub_group_id].data)
|
||||||
|
self._unflatten_partitioned_parameters(sub_group_id)
|
||||||
|
else:
|
||||||
|
self._partitioned_params_swap_out(sub_group_id)
|
||||||
|
|
||||||
|
@instrument_w_nvtx
|
||||||
|
def _reassign_or_swap_out_partitioned_parameters_async(self, sub_group_id, updated_param):
|
||||||
|
"""Asynchronously update partitioned parameters with optimized values."""
|
||||||
|
self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(updated_param, non_blocking=True)
|
||||||
|
|
||||||
|
@instrument_w_nvtx
|
||||||
|
def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
|
||||||
|
# print("[DEBUG] partition_grads called")
|
||||||
|
buffers = []
|
||||||
|
device_buffers = {}
|
||||||
|
buffer_numel_min = {}
|
||||||
|
buffer_numel_max = {}
|
||||||
|
|
||||||
|
for param, grad_partition in zip(params_to_release, grad_partitions):
|
||||||
|
i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
|
||||||
|
|
||||||
|
if self.is_gradient_accumulation_boundary:
|
||||||
|
self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_partition)
|
||||||
|
|
||||||
|
buffer_numel = grad_partition.numel()
|
||||||
|
buffers.append(grad_partition)
|
||||||
|
|
||||||
|
if i not in device_buffers:
|
||||||
|
device_buffers[i] = []
|
||||||
|
device_buffers[i].append(grad_partition)
|
||||||
|
|
||||||
|
if i not in buffer_numel_min:
|
||||||
|
buffer_numel_min[i] = dest_offset
|
||||||
|
buffer_numel_max[i] = dest_offset + buffer_numel
|
||||||
|
else:
|
||||||
|
buffer_numel_min[i] = min(buffer_numel_min[i], dest_offset)
|
||||||
|
buffer_numel_max[i] = max(buffer_numel_max[i], dest_offset + buffer_numel)
|
||||||
|
|
||||||
|
if self.is_gradient_accumulation_boundary:
|
||||||
|
for i in buffer_numel_min.keys():
|
||||||
|
fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
|
||||||
|
0, buffer_numel_min[i], buffer_numel_max[i] - buffer_numel_min[i])
|
||||||
|
concatenated_buffer = torch.cat(device_buffers[i], dim=0).float()
|
||||||
|
|
||||||
|
if self.subgroup_to_device[i] == 'cpu':
|
||||||
|
# Trigger asynchronous CPU optimization
|
||||||
|
param_group_id = self.sub_group_to_group_id[i]
|
||||||
|
fp32_param = self.fp32_partitioned_groups_flat[i]
|
||||||
|
|
||||||
|
self.superoffload_cpu_optimizer.async_step(param_group_id, i, fp32_param.data,
|
||||||
|
concatenated_buffer.data)
|
||||||
|
self.async_cpuadam_num += 1
|
||||||
|
|
||||||
|
# Check for completed async operations
|
||||||
|
result = self.superoffload_cpu_optimizer.get_result()
|
||||||
|
if result is not None:
|
||||||
|
self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
|
||||||
|
result[ResultKeys.UPDATED_PARAM])
|
||||||
|
self.async_cpuadam_num -= 1
|
||||||
|
|
||||||
|
fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)
|
||||||
|
else:
|
||||||
|
fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)
|
||||||
|
|
||||||
|
# Clean up parameter gradients
|
||||||
|
for param in params_to_release:
|
||||||
|
if not get_accelerator().is_synchronized_device():
|
||||||
|
param.grad.record_stream(get_accelerator().current_stream())
|
||||||
|
param.grad = None
|
||||||
|
|
||||||
|
@instrument_w_nvtx
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""
|
||||||
|
Not supporting closure.
|
||||||
|
"""
|
||||||
|
# Wait for any pending asynchronous CPU optimizer operations
|
||||||
|
self._wait_for_async_operations()
|
||||||
|
|
||||||
|
self._pre_step()
|
||||||
|
self._partition_all_parameters()
|
||||||
|
|
||||||
|
if self._overflow_check_and_loss_scale_update():
|
||||||
|
self._handle_overflow_rollback()
|
||||||
|
return
|
||||||
|
|
||||||
|
norm_groups = self._get_norm_groups()
|
||||||
|
scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
|
||||||
|
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
|
||||||
|
|
||||||
|
timer_names = set()
|
||||||
|
timer_names.add(OPTIMIZER_STEP_TIMER)
|
||||||
|
self.timers(OPTIMIZER_STEP_TIMER).start()
|
||||||
|
|
||||||
|
if self.check_clip_grads(scaled_global_grad_norm):
|
||||||
|
self._handle_gradient_clipping(scaled_global_grad_norm)
|
||||||
|
|
||||||
|
for sub_group_id, group in enumerate(self.fp16_groups):
|
||||||
|
# Prepare optimizer states, gradients and fp32 parameters for update
|
||||||
|
self._prepare_sub_group(sub_group_id, timer_names)
|
||||||
|
|
||||||
|
# Scale the fp32 gradients
|
||||||
|
self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
|
||||||
|
|
||||||
|
# Apply the optimizer step on the sub group and copy fp32 parameters to fp16
|
||||||
|
self._optimizer_step(sub_group_id)
|
||||||
|
|
||||||
|
# Put fp16 parameters in appropriate location
|
||||||
|
self._reassign_or_swap_out_partitioned_parameters(sub_group_id)
|
||||||
|
|
||||||
|
# Release memory or swap out optimizer states of fp32 parameters
|
||||||
|
self._release_sub_group(sub_group_id, timer_names)
|
||||||
|
|
||||||
|
self.timers(OPTIMIZER_STEP_TIMER).stop()
|
||||||
|
self._post_step(timer_names)
|
||||||
|
|
||||||
|
def _wait_for_async_operations(self, timeout_seconds=60):
|
||||||
|
"""Wait for all pending asynchronous CPU optimizer operations to complete with timeout error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
|
||||||
|
"""
|
||||||
|
if self.async_cpuadam_num > 0:
|
||||||
|
logger.info(f"[INFO] {self.async_cpuadam_num} asynchronous CPU optimizer operations pending...")
|
||||||
|
if self.async_cpuadam_num == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
initial_pending_ops = self.async_cpuadam_num
|
||||||
|
|
||||||
|
while self.async_cpuadam_num > 0:
|
||||||
|
result = self.superoffload_cpu_optimizer.get_result()
|
||||||
|
if result is None:
|
||||||
|
current_time = time.time()
|
||||||
|
elapsed_time = current_time - start_time
|
||||||
|
|
||||||
|
# Throw error if we've been waiting longer than the timeout
|
||||||
|
if elapsed_time >= timeout_seconds:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
|
||||||
|
f"Still waiting for {self.async_cpuadam_num}/{initial_pending_ops} async operations to complete. "
|
||||||
|
f"This indicates a deadlock or critical performance issue in the CPU optimizer.")
|
||||||
|
|
||||||
|
time.sleep(0.001) # 1ms sleep
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
|
||||||
|
result[ResultKeys.UPDATED_PARAM])
|
||||||
|
self.async_cpuadam_num -= 1
|
||||||
|
|
||||||
|
def _wait_for_single_async_result(self, event_type: str, timeout_seconds=60):
|
||||||
|
"""Wait for a single asynchronous CPU-Adam optimizer operation with timeout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type (str): Type of operation expected ('adam_step' or 'rollback').
|
||||||
|
timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
result = self.superoffload_cpu_optimizer.get_result(expected_event_type=event_type)
|
||||||
|
if result is not None:
|
||||||
|
self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
|
||||||
|
result[ResultKeys.UPDATED_PARAM])
|
||||||
|
break
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
elapsed_time = current_time - start_time
|
||||||
|
|
||||||
|
# Throw error if we've been waiting longer than the timeout
|
||||||
|
if elapsed_time >= timeout_seconds:
|
||||||
|
raise RuntimeError(f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
|
||||||
|
f"This indicates a deadlock or critical performance issue in the CPU optimizer.")
|
||||||
|
|
||||||
|
time.sleep(0.001) # 1ms sleep
|
||||||
|
|
||||||
|
def _sync_cpu_optimizer_step(self,
|
||||||
|
param_group_id: int,
|
||||||
|
sub_group_id: int,
|
||||||
|
fp32_param_data,
|
||||||
|
fp32_grad_data,
|
||||||
|
rollback: bool = False,
|
||||||
|
timeout_seconds: int = 60):
|
||||||
|
event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
|
||||||
|
self.superoffload_cpu_optimizer.async_step(param_group_id,
|
||||||
|
sub_group_id,
|
||||||
|
fp32_param_data,
|
||||||
|
fp32_grad_data,
|
||||||
|
rollback=rollback)
|
||||||
|
# Wait for completion
|
||||||
|
self._wait_for_single_async_result(event_type, timeout_seconds)
|
||||||
|
|
||||||
|
def _handle_overflow_rollback(self):
|
||||||
|
"""Handle gradient overflow by rolling back CPU optimizer states."""
|
||||||
|
for sub_group_id, _ in enumerate(self.fp16_groups):
|
||||||
|
if self.subgroup_to_device[sub_group_id] == 'cpu':
|
||||||
|
param_group_id = self.sub_group_to_group_id[sub_group_id]
|
||||||
|
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
|
||||||
|
|
||||||
|
# Trigger rollback
|
||||||
|
self._sync_cpu_optimizer_step(param_group_id,
|
||||||
|
sub_group_id,
|
||||||
|
fp32_param.data,
|
||||||
|
fp32_param.grad.data,
|
||||||
|
rollback=True)
|
||||||
|
|
||||||
|
def _handle_gradient_clipping(self, scaled_global_grad_norm):
|
||||||
|
"""Handle gradient clipping with CPU optimizer rollback and re-optimization."""
|
||||||
|
for sub_group_id, _ in enumerate(self.fp16_groups):
|
||||||
|
if self.subgroup_to_device[sub_group_id] == 'cpu':
|
||||||
|
param_group_id = self.sub_group_to_group_id[sub_group_id]
|
||||||
|
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
|
||||||
|
|
||||||
|
# Rollback CPU optimizer states
|
||||||
|
self._sync_cpu_optimizer_step(param_group_id,
|
||||||
|
sub_group_id,
|
||||||
|
fp32_param.data,
|
||||||
|
fp32_param.grad.data,
|
||||||
|
rollback=True)
|
||||||
|
|
||||||
|
# Clip gradients and re-optimize
|
||||||
|
self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
|
||||||
|
|
||||||
|
self._sync_cpu_optimizer_step(param_group_id,
|
||||||
|
sub_group_id,
|
||||||
|
fp32_param.data,
|
||||||
|
fp32_param.grad.data,
|
||||||
|
rollback=False)
|
||||||
|
|
||||||
|
@instrument_w_nvtx
|
||||||
|
def check_clip_grads(self, total_norm):
|
||||||
|
"""Check if gradients need to be clipped based on the global norm."""
|
||||||
|
unscaled_norm = total_norm / self.loss_scale
|
||||||
|
return self.clip_grad and unscaled_norm > self.clip_grad
|
273
deepspeed/runtime/superoffload/superoffload_utils.py
Normal file
273
deepspeed/runtime/superoffload/superoffload_utils.py
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
# Copyright (c) DeepSpeed Team.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# DeepSpeed Team
|
||||||
|
"""
|
||||||
|
SuperOffload utilities for 1) running CPU optimizers in separate processes.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
from deepspeed.utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
class TaskKeys:
|
||||||
|
PARAM_DATA = "param_data"
|
||||||
|
PARAM_GRAD = "param_grad"
|
||||||
|
PARAM_GROUP_ID = "param_group_id"
|
||||||
|
SUB_GROUP_ID = "sub_group_id"
|
||||||
|
ROLLBACK = "rollback"
|
||||||
|
|
||||||
|
|
||||||
|
class ResultKeys:
|
||||||
|
UPDATED_PARAM = "updated_param"
|
||||||
|
EVENT_TYPE = "event_type"
|
||||||
|
|
||||||
|
|
||||||
|
class EventTypes:
|
||||||
|
ADAM_STEP = "adam_step"
|
||||||
|
ROLLBACK = "rollback"
|
||||||
|
|
||||||
|
|
||||||
|
def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp.SimpleQueue,
|
||||||
|
optimizer_config: Dict[str, Any], max_grad_numel: int) -> None:
|
||||||
|
"""
|
||||||
|
This function runs in a separate process and continuously processes optimization
|
||||||
|
tasks from the parameter queue. It creates a DeepSpeedCPUAdam optimizer and
|
||||||
|
applies optimization steps to parameters received from the main process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
param_queue: Queue for receiving optimization tasks
|
||||||
|
result_queue: Queue for sending back optimization results
|
||||||
|
optimizer_config: Configuration dictionary for the optimizer containing
|
||||||
|
lr, betas, eps, weight_decay, and amsgrad parameters
|
||||||
|
max_grad_numel: Maximum number of elements expected in gradient tensors
|
||||||
|
"""
|
||||||
|
# Initialize dummy parameter for optimizer creation
|
||||||
|
cpu_tensor = torch.randn(1, device="cpu")
|
||||||
|
cpu_param = torch.nn.Parameter(cpu_tensor)
|
||||||
|
|
||||||
|
try:
|
||||||
|
optimizer = DeepSpeedCPUAdam([cpu_param],
|
||||||
|
lr=optimizer_config["lr"],
|
||||||
|
betas=optimizer_config["betas"],
|
||||||
|
eps=optimizer_config["eps"],
|
||||||
|
weight_decay=optimizer_config["weight_decay"],
|
||||||
|
amsgrad=optimizer_config["amsgrad"])
|
||||||
|
except KeyError as e:
|
||||||
|
error_msg = f"Missing required optimizer config key: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
result_queue.put({"error": error_msg})
|
||||||
|
return
|
||||||
|
|
||||||
|
# Pre-allocate reusable pinned memory buffer for gradients
|
||||||
|
pinned_grad_buffer = torch.empty(max_grad_numel, dtype=torch.float32, device='cpu', pin_memory=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
task = param_queue.get()
|
||||||
|
|
||||||
|
if task is None:
|
||||||
|
logger.debug("Received termination signal, shutting down worker")
|
||||||
|
break
|
||||||
|
|
||||||
|
param_data = task[TaskKeys.PARAM_DATA]
|
||||||
|
param_grad = task[TaskKeys.PARAM_GRAD]
|
||||||
|
param_group_id = task[TaskKeys.PARAM_GROUP_ID]
|
||||||
|
sub_group_id = task[TaskKeys.SUB_GROUP_ID]
|
||||||
|
rollback = task.get(TaskKeys.ROLLBACK, False)
|
||||||
|
|
||||||
|
logger.debug(f"Processing param_group_id: {param_group_id}, sub_group_id: {sub_group_id}")
|
||||||
|
|
||||||
|
del task[TaskKeys.PARAM_DATA]
|
||||||
|
del task[TaskKeys.PARAM_GRAD]
|
||||||
|
task.clear()
|
||||||
|
|
||||||
|
grad_numel = param_grad.numel()
|
||||||
|
if grad_numel > max_grad_numel:
|
||||||
|
error_msg = (
|
||||||
|
f"Gradient size {grad_numel} exceeds pre-allocated buffer size {max_grad_numel}. "
|
||||||
|
f"This indicates insufficient buffer allocation. Please increase max_grad_numel parameter.")
|
||||||
|
result_queue.put({"error": error_msg})
|
||||||
|
break
|
||||||
|
|
||||||
|
param_grad_cpu = pinned_grad_buffer[:grad_numel].view_as(param_grad)
|
||||||
|
param_grad_cpu.copy_(param_grad, non_blocking=True)
|
||||||
|
|
||||||
|
fp32_param = torch.nn.Parameter(param_data)
|
||||||
|
fp32_param.grad = param_grad_cpu
|
||||||
|
|
||||||
|
optimizer.param_groups[param_group_id]['params'] = [fp32_param]
|
||||||
|
|
||||||
|
if rollback:
|
||||||
|
logger.debug(f"Rolling back optimizer state for sub_group_id: {sub_group_id}")
|
||||||
|
optimizer.rollback_subgroup(sub_group_id)
|
||||||
|
else:
|
||||||
|
optimizer.step_subgroup(sub_group_id)
|
||||||
|
|
||||||
|
# Send result back to main process
|
||||||
|
event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
|
||||||
|
result_queue.put({
|
||||||
|
TaskKeys.PARAM_GROUP_ID: param_group_id,
|
||||||
|
TaskKeys.SUB_GROUP_ID: sub_group_id,
|
||||||
|
ResultKeys.UPDATED_PARAM: fp32_param.data,
|
||||||
|
ResultKeys.EVENT_TYPE: event_type,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Clean up references to free memory
|
||||||
|
optimizer.param_groups[param_group_id]['params'] = []
|
||||||
|
del param_grad_cpu, fp32_param.grad, fp32_param, param_grad, param_data
|
||||||
|
|
||||||
|
except KeyError as e:
|
||||||
|
error_msg = f"Missing required task key: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
result_queue.put({"error": error_msg})
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Unexpected error in worker process: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
result_queue.put({"error": error_msg})
|
||||||
|
break
|
||||||
|
|
||||||
|
# Clean up pinned memory buffer
|
||||||
|
if 'pinned_grad_buffer' in locals():
|
||||||
|
del pinned_grad_buffer
|
||||||
|
logger.debug("Cleaned up pinned memory buffer")
|
||||||
|
|
||||||
|
logger.debug("Worker process terminated")
|
||||||
|
|
||||||
|
|
||||||
|
class SuperOffloadCPUOptimizer:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
optimizer_config: Dict[str, Any],
|
||||||
|
cpuadam_cores_perc: float = 0.8,
|
||||||
|
max_grad_numel: int = 1000000) -> None:
|
||||||
|
if not 0 < cpuadam_cores_perc <= 1:
|
||||||
|
raise ValueError("cpuadam_cores_perc must be between 0 and 1")
|
||||||
|
|
||||||
|
self.max_grad_numel = max_grad_numel
|
||||||
|
self.mp_context = mp.get_context('spawn')
|
||||||
|
self.param_queue = self.mp_context.SimpleQueue()
|
||||||
|
self.result_queue = self.mp_context.SimpleQueue()
|
||||||
|
|
||||||
|
self.cpuadam_process = self.mp_context.Process(
|
||||||
|
target=superoffload_optimizer_worker,
|
||||||
|
args=(self.param_queue, self.result_queue, optimizer_config, max_grad_numel),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
self.cpuadam_process.start()
|
||||||
|
|
||||||
|
# Set CPU affinity for better performance isolation
|
||||||
|
self._set_cpu_affinity(cpuadam_cores_perc)
|
||||||
|
|
||||||
|
def _set_cpu_affinity(self, cpuadam_cores_perc: float) -> None:
|
||||||
|
"""
|
||||||
|
Set CPU affinity for the main (Pytorch) process and worker (CPU Adam) process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cpuadam_cores_perc: Percentage of cores to allocate to the worker (CPU Adam) process
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
current_process = psutil.Process()
|
||||||
|
all_cores = current_process.cpu_affinity()
|
||||||
|
num_cores = len(all_cores)
|
||||||
|
|
||||||
|
split_idx = int((1 - cpuadam_cores_perc) * num_cores)
|
||||||
|
pt_cores = all_cores[:split_idx]
|
||||||
|
cpuadam_cores = all_cores[split_idx:]
|
||||||
|
|
||||||
|
# Set affinity for main process (PyTorch)
|
||||||
|
current_process.cpu_affinity(pt_cores)
|
||||||
|
|
||||||
|
# Set affinity for optimizer process (CPU Adam)
|
||||||
|
optimizer_process = psutil.Process(self.cpuadam_process.pid)
|
||||||
|
optimizer_process.cpu_affinity(cpuadam_cores)
|
||||||
|
|
||||||
|
logger.debug(f"Set CPU affinity - PyTorch cores: {pt_cores}, "
|
||||||
|
f"Optimizer cores: {cpuadam_cores}")
|
||||||
|
|
||||||
|
except (psutil.AccessDenied, psutil.NoSuchProcess, AttributeError) as e:
|
||||||
|
logger.debug(f"Could not set CPU affinities for superoffload optimizer process: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Unexpected error setting CPU affinity: {e}")
|
||||||
|
|
||||||
|
def async_step(self,
|
||||||
|
param_group_id: int,
|
||||||
|
sub_group_id: int,
|
||||||
|
fp32_param: torch.Tensor,
|
||||||
|
fp32_grad: torch.Tensor,
|
||||||
|
rollback: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Queue parameter for optimization in the worker process.
|
||||||
|
"""
|
||||||
|
if not self.cpuadam_process.is_alive():
|
||||||
|
raise RuntimeError("Worker process is not alive")
|
||||||
|
|
||||||
|
self.param_queue.put({
|
||||||
|
TaskKeys.PARAM_DATA: fp32_param,
|
||||||
|
TaskKeys.PARAM_GRAD: fp32_grad,
|
||||||
|
TaskKeys.PARAM_GROUP_ID: param_group_id,
|
||||||
|
TaskKeys.SUB_GROUP_ID: sub_group_id,
|
||||||
|
TaskKeys.ROLLBACK: rollback,
|
||||||
|
})
|
||||||
|
|
||||||
|
def get_result(self, expected_event_type: str = None) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get result from worker process with optional event type validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expected_event_type (str, optional): Expected event type ('adam_step' or 'rollback').
|
||||||
|
If provided, validates that the result matches.
|
||||||
|
"""
|
||||||
|
if self.result_queue.empty():
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = self.result_queue.get()
|
||||||
|
|
||||||
|
if "error" in result:
|
||||||
|
raise RuntimeError(f"Error in worker process: {result['error']}")
|
||||||
|
|
||||||
|
# Validate event type if expected_event_type is provided
|
||||||
|
if expected_event_type is not None:
|
||||||
|
result_event_type = result.get(ResultKeys.EVENT_TYPE)
|
||||||
|
if result_event_type != expected_event_type:
|
||||||
|
raise RuntimeError(f"Event type mismatch: expected '{expected_event_type}', got '{result_event_type}'")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Shutdown the worker process gracefully.
|
||||||
|
|
||||||
|
Sends termination signal to worker and waits for clean shutdown.
|
||||||
|
If the process doesn't terminate within the timeout, it will be forcefully killed.
|
||||||
|
"""
|
||||||
|
if not self.cpuadam_process.is_alive():
|
||||||
|
logger.debug("Worker process already terminated")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Send termination signal
|
||||||
|
self.param_queue.put(None)
|
||||||
|
|
||||||
|
# Wait for graceful shutdown
|
||||||
|
self.cpuadam_process.join(timeout=5)
|
||||||
|
|
||||||
|
if self.cpuadam_process.is_alive():
|
||||||
|
logger.warning("Optimizer process did not terminate cleanly within timeout, "
|
||||||
|
"forcefully terminating")
|
||||||
|
self.cpuadam_process.terminate()
|
||||||
|
self.cpuadam_process.join(timeout=2)
|
||||||
|
|
||||||
|
# Last resort: kill the process
|
||||||
|
if self.cpuadam_process.is_alive():
|
||||||
|
logger.error("Failed to terminate optimizer process, killing it")
|
||||||
|
self.cpuadam_process.kill()
|
||||||
|
self.cpuadam_process.join()
|
||||||
|
|
||||||
|
logger.debug("SuperOffload CPU optimizer closed successfully")
|
@ -93,6 +93,12 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
|
|||||||
ratio: float = Field(1.0, ge=0.0, le=1.0)
|
ratio: float = Field(1.0, ge=0.0, le=1.0)
|
||||||
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
|
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
|
||||||
|
|
||||||
|
super_offload: bool = False
|
||||||
|
""" Enable high performance CPU offloading for Superchips. Only valid with ZeRO Stage 3."""
|
||||||
|
|
||||||
|
cpuadam_cores_perc: float = Field(0.8, ge=0.0, le=1.0)
|
||||||
|
""" Percentage of CPU cores to use for CPU Adam. Only valid with ZeRO Stage 3 and super_offload=True."""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_pipeline(self):
|
def set_pipeline(self):
|
||||||
pipeline = self.pipeline_read or self.pipeline_write
|
pipeline = self.pipeline_read or self.pipeline_write
|
||||||
|
@ -178,6 +178,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
|
|||||||
zeropp_loco_param=None,
|
zeropp_loco_param=None,
|
||||||
log_trace_cache_warnings=False,
|
log_trace_cache_warnings=False,
|
||||||
enable_sanity_checks=False,
|
enable_sanity_checks=False,
|
||||||
|
cpuadam_cores_perc=0.8,
|
||||||
):
|
):
|
||||||
see_memory_usage("Stage 3 initialize beginning", force=True)
|
see_memory_usage("Stage 3 initialize beginning", force=True)
|
||||||
|
|
||||||
@ -873,7 +874,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
|
|||||||
sub_group_size = len(self.fp16_partitioned_groups_flat)
|
sub_group_size = len(self.fp16_partitioned_groups_flat)
|
||||||
# print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n")
|
# print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n")
|
||||||
for i in range(sub_group_size):
|
for i in range(sub_group_size):
|
||||||
if i < int(self.partial_offload * sub_group_size):
|
if i >= int((1 - self.partial_offload) * sub_group_size):
|
||||||
self.subgroup_to_device[i] = 'cpu'
|
self.subgroup_to_device[i] = 'cpu'
|
||||||
else:
|
else:
|
||||||
self.subgroup_to_device[i] = get_accelerator()._name
|
self.subgroup_to_device[i] = get_accelerator()._name
|
||||||
|
@ -132,3 +132,183 @@ class TestCPUAdamGPUError(DistributedTest):
|
|||||||
param.grad = torch.randn(model_size, device=device)
|
param.grad = torch.randn(model_size, device=device)
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCPUAdamSubgroup(DistributedTest):
|
||||||
|
world_size = 1
|
||||||
|
reuse_dist_env = True
|
||||||
|
requires_cuda_env = False
|
||||||
|
if not get_accelerator().is_available():
|
||||||
|
init_distributed = False
|
||||||
|
set_dist_env = False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
|
||||||
|
@pytest.mark.parametrize('model_size', [64, 128, 1024])
|
||||||
|
def test_step_subgroup_basic(self, dtype, model_size):
|
||||||
|
"""Test basic functionality of step_subgroup method."""
|
||||||
|
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
|
||||||
|
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
|
||||||
|
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
|
||||||
|
# Create parameters
|
||||||
|
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
param = torch.nn.Parameter(cpu_data)
|
||||||
|
optimizer = DeepSpeedCPUAdam([param])
|
||||||
|
|
||||||
|
# Set gradient
|
||||||
|
param.grad = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
|
||||||
|
# Store initial parameter values
|
||||||
|
initial_param = param.data.clone()
|
||||||
|
|
||||||
|
# Test step_subgroup with subgroup_id=0
|
||||||
|
subgroup_id = 0
|
||||||
|
optimizer.step_subgroup(subgroup_id)
|
||||||
|
|
||||||
|
# Verify parameter was updated
|
||||||
|
assert not torch.equal(param.data, initial_param), "Parameters should be updated after step_subgroup"
|
||||||
|
|
||||||
|
# Verify optimizer state was created for subgroup
|
||||||
|
assert subgroup_id in optimizer.state, "Optimizer state should be created for subgroup"
|
||||||
|
assert optimizer.state[subgroup_id]['step'] == 1, "Step count should be 1"
|
||||||
|
assert 'exp_avg' in optimizer.state[subgroup_id], "exp_avg should be in state"
|
||||||
|
assert 'exp_avg_sq' in optimizer.state[subgroup_id], "exp_avg_sq should be in state"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
|
||||||
|
def test_step_subgroup_multiple_calls(self, dtype):
|
||||||
|
"""Test multiple calls to step_subgroup increment step count correctly."""
|
||||||
|
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
|
||||||
|
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
|
||||||
|
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
|
||||||
|
model_size = 64
|
||||||
|
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
param = torch.nn.Parameter(cpu_data)
|
||||||
|
optimizer = DeepSpeedCPUAdam([param])
|
||||||
|
|
||||||
|
subgroup_id = 0
|
||||||
|
|
||||||
|
# Perform multiple steps
|
||||||
|
for step in range(1, 4):
|
||||||
|
param.grad = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
optimizer.step_subgroup(subgroup_id)
|
||||||
|
|
||||||
|
# Verify step count increments
|
||||||
|
assert optimizer.state[subgroup_id]['step'] == step, f"Step count should be {step}"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
|
||||||
|
def test_rollback_subgroup_basic(self, dtype):
|
||||||
|
"""Test basic functionality of rollback_subgroup method."""
|
||||||
|
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
|
||||||
|
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
|
||||||
|
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
|
||||||
|
model_size = 64
|
||||||
|
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
param = torch.nn.Parameter(cpu_data)
|
||||||
|
optimizer = DeepSpeedCPUAdam([param])
|
||||||
|
|
||||||
|
subgroup_id = 0
|
||||||
|
param.grad = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
|
||||||
|
# First, perform a step to initialize state
|
||||||
|
optimizer.step_subgroup(subgroup_id)
|
||||||
|
assert optimizer.state[subgroup_id]['step'] == 1
|
||||||
|
|
||||||
|
# Store parameter state after step
|
||||||
|
param_after_step = param.data.clone()
|
||||||
|
exp_avg_after_step = optimizer.state[subgroup_id]['exp_avg'].clone()
|
||||||
|
exp_avg_sq_after_step = optimizer.state[subgroup_id]['exp_avg_sq'].clone()
|
||||||
|
|
||||||
|
# Now rollback
|
||||||
|
optimizer.rollback_subgroup(subgroup_id)
|
||||||
|
|
||||||
|
# Verify step count decremented
|
||||||
|
assert optimizer.state[subgroup_id]['step'] == 0, "Step count should be decremented after rollback"
|
||||||
|
|
||||||
|
def test_rollback_subgroup_uninitialized_error(self):
|
||||||
|
"""Test that rollback_subgroup raises error for uninitialized subgroup."""
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
|
||||||
|
model_size = 64
|
||||||
|
param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
|
||||||
|
optimizer = DeepSpeedCPUAdam([param])
|
||||||
|
|
||||||
|
# Try to rollback uninitialized subgroup
|
||||||
|
with pytest.raises(RuntimeError, match="Cannot rollback optimizer state for sub_group_id 0"):
|
||||||
|
optimizer.rollback_subgroup(0)
|
||||||
|
|
||||||
|
def test_rollback_subgroup_zero_step_error(self):
|
||||||
|
"""Test that rollback_subgroup raises error when step count is already 0."""
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
|
||||||
|
model_size = 64
|
||||||
|
param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
|
||||||
|
optimizer = DeepSpeedCPUAdam([param])
|
||||||
|
|
||||||
|
subgroup_id = 0
|
||||||
|
param.grad = torch.randn(model_size, device='cpu')
|
||||||
|
|
||||||
|
# Initialize state by doing one step
|
||||||
|
optimizer.step_subgroup(subgroup_id)
|
||||||
|
|
||||||
|
# Rollback once (step should become 0)
|
||||||
|
optimizer.rollback_subgroup(subgroup_id)
|
||||||
|
assert optimizer.state[subgroup_id]['step'] == 0
|
||||||
|
|
||||||
|
# Try to rollback again - should raise error
|
||||||
|
with pytest.raises(RuntimeError, match="Cannot rollback sub_group_id 0: step count is 0"):
|
||||||
|
optimizer.rollback_subgroup(subgroup_id)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
|
||||||
|
def test_step_rollback_sequence(self, dtype):
|
||||||
|
"""Test sequence of step_subgroup and rollback_subgroup operations."""
|
||||||
|
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
|
||||||
|
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
|
||||||
|
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
|
||||||
|
model_size = 64
|
||||||
|
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
param = torch.nn.Parameter(cpu_data)
|
||||||
|
optimizer = DeepSpeedCPUAdam([param])
|
||||||
|
|
||||||
|
subgroup_id = 0
|
||||||
|
param.grad = torch.randn(model_size, device='cpu').to(dtype)
|
||||||
|
|
||||||
|
# Perform multiple steps
|
||||||
|
for step in range(1, 4):
|
||||||
|
optimizer.step_subgroup(subgroup_id)
|
||||||
|
assert optimizer.state[subgroup_id]['step'] == step
|
||||||
|
|
||||||
|
# Rollback steps one by one
|
||||||
|
for step in range(2, -1, -1):
|
||||||
|
optimizer.rollback_subgroup(subgroup_id)
|
||||||
|
assert optimizer.state[subgroup_id]['step'] == step
|
||||||
|
|
||||||
|
def test_multiple_subgroups(self):
|
||||||
|
"""Test that different subgroups maintain independent state."""
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam
|
||||||
|
|
||||||
|
model_size = 64
|
||||||
|
param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
|
||||||
|
optimizer = DeepSpeedCPUAdam([param])
|
||||||
|
|
||||||
|
param.grad = torch.randn(model_size, device='cpu')
|
||||||
|
|
||||||
|
# Step different subgroups
|
||||||
|
optimizer.step_subgroup(0)
|
||||||
|
optimizer.step_subgroup(1)
|
||||||
|
optimizer.step_subgroup(0) # Step subgroup 0 again
|
||||||
|
|
||||||
|
# Verify independent step counts
|
||||||
|
assert optimizer.state[0]['step'] == 2, "Subgroup 0 should have step count 2"
|
||||||
|
assert optimizer.state[1]['step'] == 1, "Subgroup 1 should have step count 1"
|
||||||
|
|
||||||
|
# Rollback subgroup 0 only
|
||||||
|
optimizer.rollback_subgroup(0)
|
||||||
|
assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
|
||||||
|
assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"
|
||||||
|
Reference in New Issue
Block a user