Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
SuperOffload Release (#7559)
This PR introduces **SuperOffload**, an optimizer designed for Superchips (NVIDIA GH200 & GB200, AMD MI300A) with high CPU-GPU bandwidth. It enables **full fine-tuning** of **GPT-OSS-20B, Qwen3-14B, and Phi-4** on a single GH200 GPU at up to **~500 TFLOPS**, using Hugging Face Transformers and DeepSpeed, with no custom modeling code required. SuperOffload extends ZeRO-Offload with fine-grained control and CPUAdam rollback utilities, allowing GPU execution to overlap with CPUAdam. This reduces GPU idle time and improves overall efficiency.

Key changes:
- New `SuperOffloadOptimizer_Stage3` optimizer.
- C++/CUDA binding for `adam_rollback` to revert one optimization step.
- Config additions including `super_offload` and `cpuadam_cores_perc`.

A detailed blog and tutorial will be available soon.

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
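For reference, this is roughly how the new config surface might look in a ZeRO stage 3 config. The two `offload_optimizer` fields are the ones added by this PR; every other value here is an illustrative placeholder, not a recommended setting:

```python
# Hypothetical DeepSpeed config sketch; only `super_offload` and
# `cpuadam_cores_perc` are new in this PR, the rest is illustrative.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "sub_group_size": 100_000_000,
        "offload_optimizer": {
            "device": "cpu",
            "super_offload": True,       # route stage-3 offload through SuperOffloadOptimizer_Stage3
            "cpuadam_cores_perc": 0.9,   # fraction of CPU cores reserved for the CPU Adam worker
        },
    },
}

# engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config, ...)
```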
@@ -284,6 +284,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information

33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820)
34. Stas Bekman, Samyam Rajbhandari, Michael Wyatt, Jeff Rasley, Tunji Ruwase, Zhewei Yao, Aurick Qiao, Yuxiong He. (2025) Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences [arXiv:2506.13996](https://arxiv.org/abs/2506.13996)
35. Tingfeng Lan, Yusen Wu, Bin Ma, Zhaoyuan Su, Rui Yang, Tekin Bicer, Masahiro Tanaka, Olatunji Ruwase, Dong Li, Yue Cheng. (2025) ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates [arXiv:2505.12242](https://arxiv.org/abs/2505.12242)
36. Xinyu Lian, Masahiro Tanaka, Olatunji Ruwase, Minjia Zhang. (2026) SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips [ASPLOS 2026](https://www.asplos-conference.org/asplos2026)

# Videos

1. DeepSpeed KDD 2020 Tutorial
@@ -8,6 +8,7 @@

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
    m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)");
    m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
    m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}
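For orientation, a sketch of how these bindings are reached from Python. The loader path via `CPUAdamBuilder().load()` is DeepSpeed's usual op-builder mechanism, and the `adam_update`/`adam_rollback` argument order follows the declarations shown later in this diff; the `create_adam` argument order is an assumption about the existing extension, and all tensor values are illustrative:

```python
import torch
from deepspeed.ops.op_builder import CPUAdamBuilder

ds_opt_adam = CPUAdamBuilder().load()  # JIT-builds/loads the C++ extension

opt_id, step = 0, 1
p, g = torch.randn(16), torch.randn(16)                 # fp32 CPU tensors, as ds_adam_rollback requires
exp_avg, exp_avg_sq = torch.zeros(16), torch.zeros(16)

# Assumed signature: (id, lr, beta1, beta2, eps, weight_decay, adamw_mode, should_log)
ds_opt_adam.create_adam(opt_id, 1e-3, 0.9, 0.999, 1e-8, 0.01, True, False)

ds_opt_adam.adam_update(opt_id, step, 1e-3, 0.9, 0.999, 1e-8, 0.01, True,
                        p, g, exp_avg, exp_avg_sq)
# Revert that step with the binding added in this PR:
ds_opt_adam.adam_rollback(opt_id, step, 1e-3, 0.9, 0.999, 1e-8, 0.01, True,
                          p, g, exp_avg, exp_avg_sq)
ds_opt_adam.destroy_adam(opt_id)
```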
@@ -236,6 +236,102 @@ int ds_adam_step(int optimizer_id,

    return 0;
}

void adamw_rollback_inplace(float* params,
                            const float* grads,
                            float* momentum,
                            float* variance,
                            size_t param_size,
                            float learning_rate,
                            float beta1,
                            float beta2,
                            float eps,
                            float weight_decay,
                            int& step_count)
{
    const float lr = learning_rate;
    const float lambda = weight_decay;
    const float beta1_pow = std::pow(beta1, step_count);
    const float beta2_pow = std::pow(beta2, step_count);
    const float one_minus_beta1 = 1.0f - beta1;
    const float one_minus_beta2 = 1.0f - beta2;
    const float lr_lambda = lr * lambda;
    const float one_minus_lr_lambda = 1.0f - lr_lambda;

#pragma omp parallel for
    for (size_t i = 0; i < param_size; ++i) {
        const float bias_correction1 = 1.0f - beta1_pow;
        const float bias_correction2 = 1.0f - beta2_pow;

        const float m_hat = momentum[i] / bias_correction1;
        const float v_hat = variance[i] / bias_correction2;

        const float denominator = std::sqrt(v_hat) + eps;

        // Rollback parameter update
        const float update = lr * m_hat / denominator;
        float new_param = (params[i] + update) / one_minus_lr_lambda;

        // Handle numerical instability
        if (!std::isfinite(new_param)) { new_param = 0.0f; }

        params[i] = new_param;

        const float grad = grads[i];
        momentum[i] = (momentum[i] - one_minus_beta1 * grad) / beta1;
        variance[i] = (variance[i] - one_minus_beta2 * grad * grad) / beta2;
    }

    --step_count;
}

int ds_adam_rollback(int optimizer_id,
                     size_t step,
                     float lr,
                     float beta1,
                     float beta2,
                     float epsilon,
                     float weight_decay,
                     bool bias_correction,
                     torch::Tensor& params,
                     torch::Tensor& grads,
                     torch::Tensor& exp_avg,
                     torch::Tensor& exp_avg_sq)
{
    try {
        // Validate tensor types - rollback currently only supports float32
        if (params.scalar_type() != torch::kFloat32 || grads.scalar_type() != torch::kFloat32 ||
            exp_avg.scalar_type() != torch::kFloat32 ||
            exp_avg_sq.scalar_type() != torch::kFloat32) {
            printf("Error: Adam rollback currently only supports float32 tensors\n");
            return -1;
        }

        float* params_ptr = params.data_ptr<float>();
        const float* grads_ptr = grads.data_ptr<float>();
        float* momentum_ptr = exp_avg.data_ptr<float>();
        float* variance_ptr = exp_avg_sq.data_ptr<float>();
        const size_t param_size = params.numel();
        int step_count = static_cast<int>(step);

        adamw_rollback_inplace(params_ptr,
                               grads_ptr,
                               momentum_ptr,
                               variance_ptr,
                               param_size,
                               lr,
                               beta1,
                               beta2,
                               epsilon,
                               weight_decay,
                               step_count);

        return 0;
    } catch (const std::exception& e) {
        printf("Error in Adam rollback for optimizer #%d: %s\n", optimizer_id, e.what());
        return -1;
    }
}

int destroy_adam_optimizer(int optimizer_id)
{
    s_optimizers.erase(optimizer_id);
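To see why `adamw_rollback_inplace` inverts an AdamW step, here is a small NumPy sketch (not part of the PR) that applies one bias-corrected AdamW update and then the rollback recurrences above; using the same step count `t` in both directions, the parameters and both moments return to their starting values up to floating-point rounding:

```python
import numpy as np

rng = np.random.default_rng(0)
p = rng.standard_normal(1024)
g = rng.standard_normal(1024)
m, v = np.zeros_like(p), np.zeros_like(p)
lr, beta1, beta2, eps, wd, t = 1e-3, 0.9, 0.999, 1e-8, 0.01, 1
p0, m0, v0 = p.copy(), m.copy(), v.copy()

# Forward AdamW step: moment updates, bias correction, decoupled weight decay.
m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g * g
m_hat, v_hat = m / (1 - beta1**t), v / (1 - beta2**t)
p = p * (1 - lr * wd) - lr * m_hat / (np.sqrt(v_hat) + eps)

# Rollback, mirroring adamw_rollback_inplace: undo the parameter update
# using the post-step moments, then invert the moment recurrences.
m_hat, v_hat = m / (1 - beta1**t), v / (1 - beta2**t)
p = (p + lr * m_hat / (np.sqrt(v_hat) + eps)) / (1 - lr * wd)
m = (m - (1 - beta1) * g) / beta1
v = (v - (1 - beta2) * g * g) / beta2

assert np.allclose(p, p0) and np.allclose(m, m0) and np.allclose(v, v0)
```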
@@ -217,4 +217,17 @@ int ds_adam_step(int optimizer_id,

                 torch::Tensor& exp_avg,
                 torch::Tensor& exp_avg_sq);

int ds_adam_rollback(int optimizer_id,
                     size_t step,
                     float lr,
                     float beta1,
                     float beta2,
                     float epsilon,
                     float weight_decay,
                     bool bias_correction,
                     torch::Tensor& params,
                     torch::Tensor& grads,
                     torch::Tensor& exp_avg,
                     torch::Tensor& exp_avg_sq);

int destroy_adam_optimizer(int optimizer_id);
@@ -164,3 +164,86 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):

                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

    @torch.no_grad()
    def step_subgroup(self, subgroup_id: int, closure=None):
        """Update the model parameters in a single subgroup (by index)."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Intended device for step
        device = torch.device('cpu')

        for group in self.param_groups:
            for p in group['params']:

                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[subgroup_id]

                if len(state) == 0:
                    state['step'] = 0

                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)

                state['step'] += 1
                beta1, beta2 = group['betas']
                self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                             group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                             state['exp_avg'], state['exp_avg_sq'])
        return loss

    @torch.no_grad()
    def rollback_subgroup(self, sub_group_id: int, closure=None):
        """
        Rollback the optimizer state for a specific subgroup.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Intended device for step
        device = torch.device('cpu')

        # Validate subgroup state exists and is initialized
        if sub_group_id not in self.state or len(self.state[sub_group_id]) == 0:
            raise RuntimeError(f"Cannot rollback optimizer state for sub_group_id {sub_group_id} "
                               f"as it has not been initialized.")

        subgroup_state = self.state[sub_group_id]

        # Check if we can rollback (step count must be > 0)
        if subgroup_state.get('step', 0) <= 0:
            raise RuntimeError(f"Cannot rollback sub_group_id {sub_group_id}: "
                               f"step count is {subgroup_state.get('step', 0)}")

        for _, group in enumerate(self.param_groups):
            for _, param in enumerate(group['params']):
                if param.grad is None:
                    continue

                assert param.device == device, (
                    f"CPUAdam param is on {param.device} and must be 'cpu', "
                    f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.")

                # Decrement step count
                subgroup_state['step'] -= 1

                # Extract hyperparameters
                beta1, beta2 = group['betas']

                self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2,
                                               group['eps'], group['weight_decay'], group['bias_correction'],
                                               param.data, param.grad.data, subgroup_state['exp_avg'],
                                               subgroup_state['exp_avg_sq'])
        return loss
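A short usage sketch of the two new methods, mirroring the unit tests added later in this PR (sizes and values are illustrative):

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

param = torch.nn.Parameter(torch.randn(64, device='cpu'))
optimizer = DeepSpeedCPUAdam([param])
param.grad = torch.randn(64, device='cpu')

optimizer.step_subgroup(0)        # per-subgroup Adam step; state is keyed by subgroup id
assert optimizer.state[0]['step'] == 1

optimizer.rollback_subgroup(0)    # revert that step, e.g. after overflow or re-clipping
assert optimizer.state[0]['step'] == 0
```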
@@ -886,6 +886,12 @@ class DeepSpeedEngine(Module):
    def zero_partial_offload(self):
        return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)

    def super_offload(self):
        return getattr(self._config.zero_config.offload_optimizer, "super_offload", False)

    def cpuadam_cores_perc(self):
        return getattr(self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9)

    def zero_sub_group_size(self):
        return self._config.zero_config.sub_group_size

@@ -1826,7 +1832,10 @@ class DeepSpeedEngine(Module):

            log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
            from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
-           optimizer = DeepSpeedZeroOptimizer_Stage3(
+           from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3
+           Stage3ZeroOptimizer = DeepSpeedZeroOptimizer_Stage3 if not self.super_offload(
+           ) else SuperOffloadOptimizer_Stage3
+           optimizer = Stage3ZeroOptimizer(
                self.module,
                optimizer,
                timers=timers,

@@ -1864,6 +1873,7 @@ class DeepSpeedEngine(Module):

                zeropp_loco_param=self.zeropp_loco_param(),
                log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
                enable_sanity_checks=self.is_sanity_checks_enabled(),
+               cpuadam_cores_perc=self.cpuadam_cores_perc(),
            )

        else:
deepspeed/runtime/superoffload/superoffload_stage3.py (new file, 412 lines)
@@ -0,0 +1,412 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import sys
import time
import torch
from typing import List

from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, TaskKeys, ResultKeys, EventTypes
from deepspeed.runtime.zero.partition_parameters import Parameter, Tensor
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.utils import logger
from deepspeed.accelerator import get_accelerator

OPTIMIZER_STEP_TIMER = 'optimizer_step'


class SuperOffloadOptimizer_Stage3(DeepSpeedZeroOptimizer_Stage3):

    def __init__(
        self,
        module,
        init_optimizer,
        timers,
        ds_config,
        static_loss_scale=1.0,
        dynamic_loss_scale=False,
        dynamic_loss_args=None,
        verbose=True,
        contiguous_gradients=True,
        reduce_bucket_size=500000000,
        prefetch_bucket_size=50000000,
        max_reuse_distance=1000000000,
        max_live_parameters=1000000000,
        param_persistence_threshold=100000,
        model_persistence_threshold=sys.maxsize,
        dp_process_group=None,
        reduce_scatter=True,
        overlap_comm=False,
        offload_optimizer_config=None,
        offload_param_config=None,
        sub_group_size=1000000000000,
        offload_ratio=0.0,
        mpu=None,
        clip_grad=0.0,
        gradient_accumulation_dtype=torch.float32,
        communication_data_type=torch.float16,
        postscale_gradients=True,
        gradient_predivide_factor=1.0,
        gradient_accumulation_steps=1,
        elastic_checkpoint=False,
        aio_config=None,
        all2all_process_group=None,
        zero_hpz_partition_size=1,
        zero_quantized_weights=False,
        zero_quantized_nontrainable_weights=False,
        zero_module_granularity_threshold=0,
        zeropp_loco_param=None,
        log_trace_cache_warnings=False,
        enable_sanity_checks=False,
        cpuadam_cores_perc=0.8,
    ):

        self.sub_group_to_param_num = {}
        self.params_in_ipg_bucket_buffer = []
        self._cur_bucket_index = -1
        self.async_cpuadam_num = 0
        self.max_grad_numel = 0

        super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
                         dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
                         max_reuse_distance, max_live_parameters, param_persistence_threshold,
                         model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
                         offload_optimizer_config, offload_param_config, sub_group_size, offload_ratio, mpu, clip_grad,
                         gradient_accumulation_dtype, communication_data_type, postscale_gradients,
                         gradient_predivide_factor, gradient_accumulation_steps, elastic_checkpoint, aio_config,
                         all2all_process_group, zero_hpz_partition_size, zero_quantized_weights,
                         zero_quantized_nontrainable_weights, zero_module_granularity_threshold, zeropp_loco_param,
                         log_trace_cache_warnings, enable_sanity_checks)

        optimizer_config = {
            "lr": self.optimizer.param_groups[0]["lr"],
            "betas": self.optimizer.param_groups[0]["betas"],
            "eps": self.optimizer.param_groups[0]["eps"],
            "weight_decay": self.optimizer.param_groups[0]["weight_decay"],
            "amsgrad": self.optimizer.param_groups[0]["amsgrad"]
        }
        self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_config,
                                                                   cpuadam_cores_perc=cpuadam_cores_perc,
                                                                   max_grad_numel=self.max_grad_numel)

    def _create_fp16_sub_groups(self, params_group):

        params_group_numel = sum([param.partition_numel() for param in params_group])
        sub_group_size = self.sub_group_size

        if sub_group_size is None or sub_group_size >= params_group_numel:
            return [params_group]

        sub_groups = []
        sub_group = []
        local_sub_group_size = 0

        for param in params_group:
            sub_group.append(param)
            local_sub_group_size += param.partition_numel()

            if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]):
                self.max_grad_numel = max(self.max_grad_numel, local_sub_group_size)
                sub_groups.append(sub_group)
                self.sub_group_to_param_num[len(sub_groups) - 1] = len(sub_group)

                sub_group = []
                local_sub_group_size = 0

        return sub_groups
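A worked example of the sub-grouping above, on illustrative numbers: with `sub_group_size=8` and partition sizes `[3, 5, 2, 7]`, the first two parameters close one sub-group (3 + 5 >= 8) and the remaining two form a second, closed because the last parameter was reached; `max_grad_numel` ends up as the largest sub-group size (9 here), which later sizes the worker's pinned gradient buffer:

```python
def split_into_sub_groups(partition_numels, sub_group_size):
    """Standalone mirror of the grouping logic in _create_fp16_sub_groups."""
    sub_groups, current, current_size, max_grad_numel = [], [], 0, 0
    for idx, numel in enumerate(partition_numels):
        current.append(numel)
        current_size += numel
        if current_size >= sub_group_size or idx == len(partition_numels) - 1:
            max_grad_numel = max(max_grad_numel, current_size)
            sub_groups.append(current)
            current, current_size = [], 0
    return sub_groups, max_grad_numel

groups, max_numel = split_into_sub_groups([3, 5, 2, 7], sub_group_size=8)
assert groups == [[3, 5], [2, 7]] and max_numel == 9
```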
    def _optimizer_step(self, sub_group_id):
        param_group_id = self.sub_group_to_group_id[sub_group_id]
        fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]

        def step_with_gradscaler(optimizer):
            if self.torch_autocast_gradscaler:
                self.torch_autocast_gradscaler.step(optimizer)
                self.torch_autocast_gradscaler.update()
            else:
                optimizer.step()

        cur_device = self.subgroup_to_device[sub_group_id]
        if cur_device != 'cpu':
            self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param]
            step_with_gradscaler(self.backup_optimizer)
            self.backup_optimizer.param_groups[param_group_id]['params'] = []

    def reduce_independent_p_g_buckets_and_remove_grads(self, param):
        comm_dtype = self.get_param_comm_dtype(param)
        bucket = self.ipg_buckets[comm_dtype]
        i, _, _ = self.grad_position[self.get_param_id(param)]

        if len(bucket.params) == 0:
            self._cur_bucket_index = i
            if getattr(param, "ds_grad_is_ready", True):
                self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)

            # If this is a single-parameter sub-group, reduce immediately
            if self.sub_group_to_param_num[self._cur_bucket_index] == 1:
                self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)

        elif i != self._cur_bucket_index:
            # Parameter belongs to different sub-group, buffer it
            self.params_in_ipg_bucket_buffer.append(param)
        else:
            # Parameter belongs to current bucket
            if getattr(param, "ds_grad_is_ready", True):
                self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param)

            # Check if bucket is complete
            if self.sub_group_to_param_num[self._cur_bucket_index] == len(bucket.params):
                self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype)

                # Process buffered parameters
                while self.params_in_ipg_bucket_buffer:
                    buffered_param = self.params_in_ipg_bucket_buffer.pop(0)
                    ci, _, _ = self.grad_position[self.get_param_id(buffered_param)]
                    self._cur_bucket_index = ci
                    if getattr(buffered_param, "ds_grad_is_ready", True):
                        self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(buffered_param)

    @instrument_w_nvtx
    def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id):
        if self.subgroup_to_device[sub_group_id] == 'cpu':
            self._unflatten_partitioned_parameters(sub_group_id)
            return

        if self.fp16_partitioned_groups_flat[sub_group_id] is not None:
            self.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
                self.fp32_partitioned_groups_flat[sub_group_id].data)
            self._unflatten_partitioned_parameters(sub_group_id)
        else:
            self._partitioned_params_swap_out(sub_group_id)

    @instrument_w_nvtx
    def _reassign_or_swap_out_partitioned_parameters_async(self, sub_group_id, updated_param):
        """Asynchronously update partitioned parameters with optimized values."""
        self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(updated_param, non_blocking=True)

    @instrument_w_nvtx
    def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
        buffers = []
        device_buffers = {}
        buffer_numel_min = {}
        buffer_numel_max = {}

        for param, grad_partition in zip(params_to_release, grad_partitions):
            i, dest_offset, _ = self.grad_position[self.get_param_id(param)]

            if self.is_gradient_accumulation_boundary:
                self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_partition)

            buffer_numel = grad_partition.numel()
            buffers.append(grad_partition)

            if i not in device_buffers:
                device_buffers[i] = []
            device_buffers[i].append(grad_partition)

            if i not in buffer_numel_min:
                buffer_numel_min[i] = dest_offset
                buffer_numel_max[i] = dest_offset + buffer_numel
            else:
                buffer_numel_min[i] = min(buffer_numel_min[i], dest_offset)
                buffer_numel_max[i] = max(buffer_numel_max[i], dest_offset + buffer_numel)

        if self.is_gradient_accumulation_boundary:
            for i in buffer_numel_min.keys():
                fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
                    0, buffer_numel_min[i], buffer_numel_max[i] - buffer_numel_min[i])
                concatenated_buffer = torch.cat(device_buffers[i], dim=0).float()

                if self.subgroup_to_device[i] == 'cpu':
                    # Trigger asynchronous CPU optimization
                    param_group_id = self.sub_group_to_group_id[i]
                    fp32_param = self.fp32_partitioned_groups_flat[i]

                    self.superoffload_cpu_optimizer.async_step(param_group_id, i, fp32_param.data,
                                                               concatenated_buffer.data)
                    self.async_cpuadam_num += 1

                    # Check for completed async operations
                    result = self.superoffload_cpu_optimizer.get_result()
                    if result is not None:
                        self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
                                                                                result[ResultKeys.UPDATED_PARAM])
                        self.async_cpuadam_num -= 1

                    fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)
                else:
                    fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True)

        # Clean up parameter gradients
        for param in params_to_release:
            if not get_accelerator().is_synchronized_device():
                param.grad.record_stream(get_accelerator().current_stream())
            param.grad = None

    @instrument_w_nvtx
    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # Wait for any pending asynchronous CPU optimizer operations
        self._wait_for_async_operations()

        self._pre_step()
        self._partition_all_parameters()

        if self._overflow_check_and_loss_scale_update():
            self._handle_overflow_rollback()
            return

        norm_groups = self._get_norm_groups()
        scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale

        timer_names = set()
        timer_names.add(OPTIMIZER_STEP_TIMER)
        self.timers(OPTIMIZER_STEP_TIMER).start()

        if self.check_clip_grads(scaled_global_grad_norm):
            self._handle_gradient_clipping(scaled_global_grad_norm)

        for sub_group_id, group in enumerate(self.fp16_groups):
            # Prepare optimizer states, gradients and fp32 parameters for update
            self._prepare_sub_group(sub_group_id, timer_names)

            # Scale the fp32 gradients
            self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)

            # Apply the optimizer step on the sub group and copy fp32 parameters to fp16
            self._optimizer_step(sub_group_id)

            # Put fp16 parameters in appropriate location
            self._reassign_or_swap_out_partitioned_parameters(sub_group_id)

            # Release memory or swap out optimizer states of fp32 parameters
            self._release_sub_group(sub_group_id, timer_names)

        self.timers(OPTIMIZER_STEP_TIMER).stop()
        self._post_step(timer_names)

    def _wait_for_async_operations(self, timeout_seconds=60):
        """Wait for all pending asynchronous CPU optimizer operations to complete, with a timeout error.

        Args:
            timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
        """
        if self.async_cpuadam_num > 0:
            logger.info(f"[INFO] {self.async_cpuadam_num} asynchronous CPU optimizer operations pending...")
        if self.async_cpuadam_num == 0:
            return

        start_time = time.time()
        initial_pending_ops = self.async_cpuadam_num

        while self.async_cpuadam_num > 0:
            result = self.superoffload_cpu_optimizer.get_result()
            if result is None:
                current_time = time.time()
                elapsed_time = current_time - start_time

                # Throw error if we've been waiting longer than the timeout
                if elapsed_time >= timeout_seconds:
                    raise RuntimeError(
                        f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
                        f"Still waiting for {self.async_cpuadam_num}/{initial_pending_ops} async operations to complete. "
                        f"This indicates a deadlock or critical performance issue in the CPU optimizer.")

                time.sleep(0.001)  # 1ms sleep
                continue

            self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
                                                                    result[ResultKeys.UPDATED_PARAM])
            self.async_cpuadam_num -= 1

    def _wait_for_single_async_result(self, event_type: str, timeout_seconds=60):
        """Wait for a single asynchronous CPU-Adam optimizer operation, with a timeout.

        Args:
            event_type (str): Type of operation expected ('adam_step' or 'rollback').
            timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds.
        """
        start_time = time.time()

        while True:
            result = self.superoffload_cpu_optimizer.get_result(expected_event_type=event_type)
            if result is not None:
                self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID],
                                                                        result[ResultKeys.UPDATED_PARAM])
                break

            current_time = time.time()
            elapsed_time = current_time - start_time

            # Throw error if we've been waiting longer than the timeout
            if elapsed_time >= timeout_seconds:
                raise RuntimeError(f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. "
                                   f"This indicates a deadlock or critical performance issue in the CPU optimizer.")

            time.sleep(0.001)  # 1ms sleep

    def _sync_cpu_optimizer_step(self,
                                 param_group_id: int,
                                 sub_group_id: int,
                                 fp32_param_data,
                                 fp32_grad_data,
                                 rollback: bool = False,
                                 timeout_seconds: int = 60):
        event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
        self.superoffload_cpu_optimizer.async_step(param_group_id,
                                                   sub_group_id,
                                                   fp32_param_data,
                                                   fp32_grad_data,
                                                   rollback=rollback)
        # Wait for completion
        self._wait_for_single_async_result(event_type, timeout_seconds)

    def _handle_overflow_rollback(self):
        """Handle gradient overflow by rolling back CPU optimizer states."""
        for sub_group_id, _ in enumerate(self.fp16_groups):
            if self.subgroup_to_device[sub_group_id] == 'cpu':
                param_group_id = self.sub_group_to_group_id[sub_group_id]
                fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]

                # Trigger rollback
                self._sync_cpu_optimizer_step(param_group_id,
                                              sub_group_id,
                                              fp32_param.data,
                                              fp32_param.grad.data,
                                              rollback=True)

    def _handle_gradient_clipping(self, scaled_global_grad_norm):
        """Handle gradient clipping with CPU optimizer rollback and re-optimization."""
        for sub_group_id, _ in enumerate(self.fp16_groups):
            if self.subgroup_to_device[sub_group_id] == 'cpu':
                param_group_id = self.sub_group_to_group_id[sub_group_id]
                fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]

                # Rollback CPU optimizer states
                self._sync_cpu_optimizer_step(param_group_id,
                                              sub_group_id,
                                              fp32_param.data,
                                              fp32_param.grad.data,
                                              rollback=True)

                # Clip gradients and re-optimize
                self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)

                self._sync_cpu_optimizer_step(param_group_id,
                                              sub_group_id,
                                              fp32_param.data,
                                              fp32_param.grad.data,
                                              rollback=False)

    @instrument_w_nvtx
    def check_clip_grads(self, total_norm):
        """Check if gradients need to be clipped based on the global norm."""
        unscaled_norm = total_norm / self.loss_scale
        return self.clip_grad and unscaled_norm > self.clip_grad
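The rollback path is what lets SuperOffload run CPU Adam optimistically during the backward pass: if loss scaling later detects overflow, or the global norm turns out to need clipping, each CPU sub-group's step is undone and, in the clipping case, reapplied on clipped gradients. A minimal single-parameter sketch of that pattern using the new `DeepSpeedCPUAdam` methods (the clip value is illustrative; the engine uses its configured `clip_grad`):

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

param = torch.nn.Parameter(torch.randn(64, device='cpu'))
opt = DeepSpeedCPUAdam([param])
param.grad = torch.randn(64, device='cpu')

opt.step_subgroup(0)                 # optimistic step; the engine overlaps this with backward

clip_value = 1.0
grad_norm = param.grad.norm()
if grad_norm > clip_value:
    opt.rollback_subgroup(0)         # undo the optimistic step
    param.grad.mul_(clip_value / (grad_norm + 1e-6))
    opt.step_subgroup(0)             # re-apply on clipped gradients
```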
deepspeed/runtime/superoffload/superoffload_utils.py (new file, 273 lines)
@@ -0,0 +1,273 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
SuperOffload utilities for running CPU optimizers in separate processes.
"""

from typing import Dict, Optional, Any
import torch
import torch.multiprocessing as mp
import psutil

from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger


class TaskKeys:
    PARAM_DATA = "param_data"
    PARAM_GRAD = "param_grad"
    PARAM_GROUP_ID = "param_group_id"
    SUB_GROUP_ID = "sub_group_id"
    ROLLBACK = "rollback"


class ResultKeys:
    UPDATED_PARAM = "updated_param"
    EVENT_TYPE = "event_type"


class EventTypes:
    ADAM_STEP = "adam_step"
    ROLLBACK = "rollback"


def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp.SimpleQueue,
                                  optimizer_config: Dict[str, Any], max_grad_numel: int) -> None:
    """
    This function runs in a separate process and continuously processes optimization
    tasks from the parameter queue. It creates a DeepSpeedCPUAdam optimizer and
    applies optimization steps to parameters received from the main process.

    Args:
        param_queue: Queue for receiving optimization tasks
        result_queue: Queue for sending back optimization results
        optimizer_config: Configuration dictionary for the optimizer containing
            lr, betas, eps, weight_decay, and amsgrad parameters
        max_grad_numel: Maximum number of elements expected in gradient tensors
    """
    # Initialize dummy parameter for optimizer creation
    cpu_tensor = torch.randn(1, device="cpu")
    cpu_param = torch.nn.Parameter(cpu_tensor)

    try:
        optimizer = DeepSpeedCPUAdam([cpu_param],
                                     lr=optimizer_config["lr"],
                                     betas=optimizer_config["betas"],
                                     eps=optimizer_config["eps"],
                                     weight_decay=optimizer_config["weight_decay"],
                                     amsgrad=optimizer_config["amsgrad"])
    except KeyError as e:
        error_msg = f"Missing required optimizer config key: {e}"
        logger.error(error_msg)
        result_queue.put({"error": error_msg})
        return

    # Pre-allocate reusable pinned memory buffer for gradients
    pinned_grad_buffer = torch.empty(max_grad_numel, dtype=torch.float32, device='cpu', pin_memory=True)

    while True:
        try:
            task = param_queue.get()

            if task is None:
                logger.debug("Received termination signal, shutting down worker")
                break

            param_data = task[TaskKeys.PARAM_DATA]
            param_grad = task[TaskKeys.PARAM_GRAD]
            param_group_id = task[TaskKeys.PARAM_GROUP_ID]
            sub_group_id = task[TaskKeys.SUB_GROUP_ID]
            rollback = task.get(TaskKeys.ROLLBACK, False)

            logger.debug(f"Processing param_group_id: {param_group_id}, sub_group_id: {sub_group_id}")

            del task[TaskKeys.PARAM_DATA]
            del task[TaskKeys.PARAM_GRAD]
            task.clear()

            grad_numel = param_grad.numel()
            if grad_numel > max_grad_numel:
                error_msg = (
                    f"Gradient size {grad_numel} exceeds pre-allocated buffer size {max_grad_numel}. "
                    f"This indicates insufficient buffer allocation. Please increase max_grad_numel parameter.")
                result_queue.put({"error": error_msg})
                break

            param_grad_cpu = pinned_grad_buffer[:grad_numel].view_as(param_grad)
            param_grad_cpu.copy_(param_grad, non_blocking=True)

            fp32_param = torch.nn.Parameter(param_data)
            fp32_param.grad = param_grad_cpu

            optimizer.param_groups[param_group_id]['params'] = [fp32_param]

            if rollback:
                logger.debug(f"Rolling back optimizer state for sub_group_id: {sub_group_id}")
                optimizer.rollback_subgroup(sub_group_id)
            else:
                optimizer.step_subgroup(sub_group_id)

            # Send result back to main process
            event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP
            result_queue.put({
                TaskKeys.PARAM_GROUP_ID: param_group_id,
                TaskKeys.SUB_GROUP_ID: sub_group_id,
                ResultKeys.UPDATED_PARAM: fp32_param.data,
                ResultKeys.EVENT_TYPE: event_type,
            })

            # Clean up references to free memory
            optimizer.param_groups[param_group_id]['params'] = []
            del param_grad_cpu, fp32_param.grad, fp32_param, param_grad, param_data

        except KeyError as e:
            error_msg = f"Missing required task key: {e}"
            logger.error(error_msg)
            result_queue.put({"error": error_msg})
            break
        except Exception as e:
            error_msg = f"Unexpected error in worker process: {e}"
            logger.error(error_msg)
            result_queue.put({"error": error_msg})
            break

    # Clean up pinned memory buffer
    if 'pinned_grad_buffer' in locals():
        del pinned_grad_buffer
        logger.debug("Cleaned up pinned memory buffer")

    logger.debug("Worker process terminated")


class SuperOffloadCPUOptimizer:

    def __init__(self,
                 optimizer_config: Dict[str, Any],
                 cpuadam_cores_perc: float = 0.8,
                 max_grad_numel: int = 1000000) -> None:
        if not 0 < cpuadam_cores_perc <= 1:
            raise ValueError("cpuadam_cores_perc must be between 0 and 1")

        self.max_grad_numel = max_grad_numel
        self.mp_context = mp.get_context('spawn')
        self.param_queue = self.mp_context.SimpleQueue()
        self.result_queue = self.mp_context.SimpleQueue()

        self.cpuadam_process = self.mp_context.Process(
            target=superoffload_optimizer_worker,
            args=(self.param_queue, self.result_queue, optimizer_config, max_grad_numel),
            daemon=True,
        )
        self.cpuadam_process.start()

        # Set CPU affinity for better performance isolation
        self._set_cpu_affinity(cpuadam_cores_perc)

    def _set_cpu_affinity(self, cpuadam_cores_perc: float) -> None:
        """
        Set CPU affinity for the main (PyTorch) process and worker (CPU Adam) process.

        Args:
            cpuadam_cores_perc: Percentage of cores to allocate to the worker (CPU Adam) process
        """
        try:
            current_process = psutil.Process()
            all_cores = current_process.cpu_affinity()
            num_cores = len(all_cores)

            split_idx = int((1 - cpuadam_cores_perc) * num_cores)
            pt_cores = all_cores[:split_idx]
            cpuadam_cores = all_cores[split_idx:]

            # Set affinity for main process (PyTorch)
            current_process.cpu_affinity(pt_cores)

            # Set affinity for optimizer process (CPU Adam)
            optimizer_process = psutil.Process(self.cpuadam_process.pid)
            optimizer_process.cpu_affinity(cpuadam_cores)

            logger.debug(f"Set CPU affinity - PyTorch cores: {pt_cores}, "
                         f"Optimizer cores: {cpuadam_cores}")

        except (psutil.AccessDenied, psutil.NoSuchProcess, AttributeError) as e:
            logger.debug(f"Could not set CPU affinities for superoffload optimizer process: {e}")
        except Exception as e:
            logger.warning(f"Unexpected error setting CPU affinity: {e}")

    def async_step(self,
                   param_group_id: int,
                   sub_group_id: int,
                   fp32_param: torch.Tensor,
                   fp32_grad: torch.Tensor,
                   rollback: bool = False) -> None:
        """
        Queue parameter for optimization in the worker process.
        """
        if not self.cpuadam_process.is_alive():
            raise RuntimeError("Worker process is not alive")

        self.param_queue.put({
            TaskKeys.PARAM_DATA: fp32_param,
            TaskKeys.PARAM_GRAD: fp32_grad,
            TaskKeys.PARAM_GROUP_ID: param_group_id,
            TaskKeys.SUB_GROUP_ID: sub_group_id,
            TaskKeys.ROLLBACK: rollback,
        })

    def get_result(self, expected_event_type: str = None) -> Optional[Dict[str, Any]]:
        """
        Get result from worker process with optional event type validation.

        Args:
            expected_event_type (str, optional): Expected event type ('adam_step' or 'rollback').
                If provided, validates that the result matches.
        """
        if self.result_queue.empty():
            return None

        result = self.result_queue.get()

        if "error" in result:
            raise RuntimeError(f"Error in worker process: {result['error']}")

        # Validate event type if expected_event_type is provided
        if expected_event_type is not None:
            result_event_type = result.get(ResultKeys.EVENT_TYPE)
            if result_event_type != expected_event_type:
                raise RuntimeError(f"Event type mismatch: expected '{expected_event_type}', got '{result_event_type}'")

        return result

    def close(self) -> None:
        """
        Shutdown the worker process gracefully.

        Sends termination signal to worker and waits for clean shutdown.
        If the process doesn't terminate within the timeout, it will be forcefully killed.
        """
        if not self.cpuadam_process.is_alive():
            logger.debug("Worker process already terminated")
            return

        # Send termination signal
        self.param_queue.put(None)

        # Wait for graceful shutdown
        self.cpuadam_process.join(timeout=5)

        if self.cpuadam_process.is_alive():
            logger.warning("Optimizer process did not terminate cleanly within timeout, "
                           "forcefully terminating")
            self.cpuadam_process.terminate()
            self.cpuadam_process.join(timeout=2)

            # Last resort: kill the process
            if self.cpuadam_process.is_alive():
                logger.error("Failed to terminate optimizer process, killing it")
                self.cpuadam_process.kill()
                self.cpuadam_process.join()

        logger.debug("SuperOffload CPU optimizer closed successfully")
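A standalone usage sketch of this helper, outside the engine (hyperparameters and sizes are illustrative):

```python
import time
import torch
from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, ResultKeys

cpu_opt = SuperOffloadCPUOptimizer(
    optimizer_config={"lr": 1e-3, "betas": (0.9, 0.999), "eps": 1e-8,
                      "weight_decay": 0.01, "amsgrad": False},
    cpuadam_cores_perc=0.8,
    max_grad_numel=1 << 20)

fp32_param = torch.randn(4096)
fp32_grad = torch.randn(4096)

# Fire-and-forget: the worker process applies CPU Adam while the caller keeps working.
cpu_opt.async_step(param_group_id=0, sub_group_id=0, fp32_param=fp32_param, fp32_grad=fp32_grad)

result = None
while result is None:        # the engine polls opportunistically instead of blocking
    result = cpu_opt.get_result()
    time.sleep(0.001)
updated_param = result[ResultKeys.UPDATED_PARAM]

cpu_opt.close()
```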
@@ -93,6 +93,12 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):

    ratio: float = Field(1.0, ge=0.0, le=1.0)
    """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""

    super_offload: bool = False
    """ Enable high performance CPU offloading for Superchips. Only valid with ZeRO Stage 3."""

    cpuadam_cores_perc: float = Field(0.8, ge=0.0, le=1.0)
    """ Percentage of CPU cores to use for CPU Adam. Only valid with ZeRO Stage 3 and super_offload=True."""

    @model_validator(mode="after")
    def set_pipeline(self):
        pipeline = self.pipeline_read or self.pipeline_write
@@ -178,6 +178,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):

        zeropp_loco_param=None,
        log_trace_cache_warnings=False,
        enable_sanity_checks=False,
        cpuadam_cores_perc=0.8,
    ):
        see_memory_usage("Stage 3 initialize beginning", force=True)
@@ -873,7 +874,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):

        sub_group_size = len(self.fp16_partitioned_groups_flat)
        # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n")
        for i in range(sub_group_size):
-           if i < int(self.partial_offload * sub_group_size):
+           if i >= int((1 - self.partial_offload) * sub_group_size):
                self.subgroup_to_device[i] = 'cpu'
            else:
                self.subgroup_to_device[i] = get_accelerator()._name
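The effect of the changed condition, on illustrative numbers: with `ratio=0.75` and 8 sub-groups, the old test offloaded sub-groups 0-5 (the first 75%) to CPU, while the new one offloads sub-groups 2-7 (the last 75%). Since backward produces gradients for the last sub-groups first, placing those on CPU presumably gives the asynchronous CPU Adam steps the most time to overlap with the rest of the backward pass:

```python
def subgroup_devices(num_sub_groups, offload_ratio, old=False):
    """Device placement per sub-group index, mirroring the two conditions above."""
    devices = []
    for i in range(num_sub_groups):
        if old:
            on_cpu = i < int(offload_ratio * num_sub_groups)
        else:
            on_cpu = i >= int((1 - offload_ratio) * num_sub_groups)
        devices.append('cpu' if on_cpu else 'accelerator')
    return devices

assert subgroup_devices(8, 0.75, old=True) == ['cpu'] * 6 + ['accelerator'] * 2
assert subgroup_devices(8, 0.75, old=False) == ['accelerator'] * 2 + ['cpu'] * 6
```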
@@ -132,3 +132,183 @@ class TestCPUAdamGPUError(DistributedTest):

        param.grad = torch.randn(model_size, device=device)
        with pytest.raises(AssertionError):
            optimizer.step()


class TestCPUAdamSubgroup(DistributedTest):
    world_size = 1
    reuse_dist_env = True
    requires_cuda_env = False
    if not get_accelerator().is_available():
        init_distributed = False
        set_dist_env = False

    @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
    @pytest.mark.parametrize('model_size', [64, 128, 1024])
    def test_step_subgroup_basic(self, dtype, model_size):
        """Test basic functionality of step_subgroup method."""
        if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
            pytest.skip("cpu-adam with half precision not supported on AMD CPUs")

        from deepspeed.ops.adam import DeepSpeedCPUAdam

        # Create parameters
        cpu_data = torch.randn(model_size, device='cpu').to(dtype)
        param = torch.nn.Parameter(cpu_data)
        optimizer = DeepSpeedCPUAdam([param])

        # Set gradient
        param.grad = torch.randn(model_size, device='cpu').to(dtype)

        # Store initial parameter values
        initial_param = param.data.clone()

        # Test step_subgroup with subgroup_id=0
        subgroup_id = 0
        optimizer.step_subgroup(subgroup_id)

        # Verify parameter was updated
        assert not torch.equal(param.data, initial_param), "Parameters should be updated after step_subgroup"

        # Verify optimizer state was created for subgroup
        assert subgroup_id in optimizer.state, "Optimizer state should be created for subgroup"
        assert optimizer.state[subgroup_id]['step'] == 1, "Step count should be 1"
        assert 'exp_avg' in optimizer.state[subgroup_id], "exp_avg should be in state"
        assert 'exp_avg_sq' in optimizer.state[subgroup_id], "exp_avg_sq should be in state"

    @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
    def test_step_subgroup_multiple_calls(self, dtype):
        """Test multiple calls to step_subgroup increment step count correctly."""
        if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
            pytest.skip("cpu-adam with half precision not supported on AMD CPUs")

        from deepspeed.ops.adam import DeepSpeedCPUAdam

        model_size = 64
        cpu_data = torch.randn(model_size, device='cpu').to(dtype)
        param = torch.nn.Parameter(cpu_data)
        optimizer = DeepSpeedCPUAdam([param])

        subgroup_id = 0

        # Perform multiple steps
        for step in range(1, 4):
            param.grad = torch.randn(model_size, device='cpu').to(dtype)
            optimizer.step_subgroup(subgroup_id)

            # Verify step count increments
            assert optimizer.state[subgroup_id]['step'] == step, f"Step count should be {step}"

    @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
    def test_rollback_subgroup_basic(self, dtype):
        """Test basic functionality of rollback_subgroup method."""
        if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
            pytest.skip("cpu-adam with half precision not supported on AMD CPUs")

        from deepspeed.ops.adam import DeepSpeedCPUAdam

        model_size = 64
        cpu_data = torch.randn(model_size, device='cpu').to(dtype)
        param = torch.nn.Parameter(cpu_data)
        optimizer = DeepSpeedCPUAdam([param])

        subgroup_id = 0
        param.grad = torch.randn(model_size, device='cpu').to(dtype)

        # First, perform a step to initialize state
        optimizer.step_subgroup(subgroup_id)
        assert optimizer.state[subgroup_id]['step'] == 1

        # Store parameter state after step
        param_after_step = param.data.clone()
        exp_avg_after_step = optimizer.state[subgroup_id]['exp_avg'].clone()
        exp_avg_sq_after_step = optimizer.state[subgroup_id]['exp_avg_sq'].clone()

        # Now rollback
        optimizer.rollback_subgroup(subgroup_id)

        # Verify step count decremented
        assert optimizer.state[subgroup_id]['step'] == 0, "Step count should be decremented after rollback"

    def test_rollback_subgroup_uninitialized_error(self):
        """Test that rollback_subgroup raises error for uninitialized subgroup."""
        from deepspeed.ops.adam import DeepSpeedCPUAdam

        model_size = 64
        param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
        optimizer = DeepSpeedCPUAdam([param])

        # Try to rollback uninitialized subgroup
        with pytest.raises(RuntimeError, match="Cannot rollback optimizer state for sub_group_id 0"):
            optimizer.rollback_subgroup(0)

    def test_rollback_subgroup_zero_step_error(self):
        """Test that rollback_subgroup raises error when step count is already 0."""
        from deepspeed.ops.adam import DeepSpeedCPUAdam

        model_size = 64
        param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
        optimizer = DeepSpeedCPUAdam([param])

        subgroup_id = 0
        param.grad = torch.randn(model_size, device='cpu')

        # Initialize state by doing one step
        optimizer.step_subgroup(subgroup_id)

        # Rollback once (step should become 0)
        optimizer.rollback_subgroup(subgroup_id)
        assert optimizer.state[subgroup_id]['step'] == 0

        # Try to rollback again - should raise error
        with pytest.raises(RuntimeError, match="Cannot rollback sub_group_id 0: step count is 0"):
            optimizer.rollback_subgroup(subgroup_id)

    @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"])
    def test_step_rollback_sequence(self, dtype):
        """Test sequence of step_subgroup and rollback_subgroup operations."""
        if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
            pytest.skip("cpu-adam with half precision not supported on AMD CPUs")

        from deepspeed.ops.adam import DeepSpeedCPUAdam

        model_size = 64
        cpu_data = torch.randn(model_size, device='cpu').to(dtype)
        param = torch.nn.Parameter(cpu_data)
        optimizer = DeepSpeedCPUAdam([param])

        subgroup_id = 0
        param.grad = torch.randn(model_size, device='cpu').to(dtype)

        # Perform multiple steps
        for step in range(1, 4):
            optimizer.step_subgroup(subgroup_id)
            assert optimizer.state[subgroup_id]['step'] == step

        # Rollback steps one by one
        for step in range(2, -1, -1):
            optimizer.rollback_subgroup(subgroup_id)
            assert optimizer.state[subgroup_id]['step'] == step

    def test_multiple_subgroups(self):
        """Test that different subgroups maintain independent state."""
        from deepspeed.ops.adam import DeepSpeedCPUAdam

        model_size = 64
        param = torch.nn.Parameter(torch.randn(model_size, device='cpu'))
        optimizer = DeepSpeedCPUAdam([param])

        param.grad = torch.randn(model_size, device='cpu')

        # Step different subgroups
        optimizer.step_subgroup(0)
        optimizer.step_subgroup(1)
        optimizer.step_subgroup(0)  # Step subgroup 0 again

        # Verify independent step counts
        assert optimizer.state[0]['step'] == 2, "Subgroup 0 should have step count 2"
        assert optimizer.state[1]['step'] == 1, "Subgroup 1 should have step count 1"

        # Rollback subgroup 0 only
        optimizer.rollback_subgroup(0)
        assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
        assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"