diff --git a/README.md b/README.md index b32cb4836..92a914a5b 100755 --- a/README.md +++ b/README.md @@ -284,6 +284,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information 33. Xinyu Lian, Sam Ade Jacobs, Lev Kurilenko, Masahiro Tanaka, Stas Bekman, Olatunji Ruwase, Minjia Zhang. (2024) Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training [arXiv:2406.18820](https://arxiv.org/abs/2406.18820) 34. Stas Bekman, Samyam Rajbhandari, Michael Wyatt, Jeff Rasley, Tunji Ruwase, Zhewei Yao, Aurick Qiao, Yuxiong He. (2025) Arctic Long Sequence Training: Scalable And Efficient Training For Multi-Million Token Sequences [arXiv:2506.13996](https://arxiv.org/abs/2506.13996) 35. Tingfeng Lan, Yusen Wu, Bin Ma, Zhaoyuan Su, Rui Yang, Tekin Bicer, Masahiro Tanaka, Olatunji Ruwase, Dong Li, Yue Cheng. (2025) ZenFlow: Enabling Stall-Free Offloading Training via Asynchronous Updates [arXiv:2505.12242](https://arxiv.org/abs/2505.12242) +36. Xinyu Lian, Masahiro Tanaka, Olatunji Ruwase, Minjia Zhang. (2026) SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips [ASPLOS 2026](https://www.asplos-conference.org/asplos2026) # Videos 1. DeepSpeed KDD 2020 Tutorial diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index 263c443cb..f4c242ff9 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -8,6 +8,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + m.def("adam_rollback", &ds_adam_rollback, "DeepSpeed CPU Adam rollback (C++)"); m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); } diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp index 465aae7b9..1f2b8cf0d 100644 --- a/csrc/adam/cpu_adam_impl.cpp +++ b/csrc/adam/cpu_adam_impl.cpp @@ -236,6 +236,102 @@ int ds_adam_step(int optimizer_id, return 0; } +void adamw_rollback_inplace(float* params, + const float* grads, + float* momentum, + float* variance, + size_t param_size, + float learning_rate, + float beta1, + float beta2, + float eps, + float weight_decay, + int& step_count) +{ + const float lr = learning_rate; + const float lambda = weight_decay; + const float beta1_pow = std::pow(beta1, step_count); + const float beta2_pow = std::pow(beta2, step_count); + const float one_minus_beta1 = 1.0f - beta1; + const float one_minus_beta2 = 1.0f - beta2; + const float lr_lambda = lr * lambda; + const float one_minus_lr_lambda = 1.0f - lr_lambda; + +#pragma omp parallel for + for (size_t i = 0; i < param_size; ++i) { + const float bias_correction1 = 1.0f - beta1_pow; + const float bias_correction2 = 1.0f - beta2_pow; + + const float m_hat = momentum[i] / bias_correction1; + const float v_hat = variance[i] / bias_correction2; + + const float denominator = std::sqrt(v_hat) + eps; + + // Rollback parameter update + const float update = lr * m_hat / denominator; + float new_param = (params[i] + update) / one_minus_lr_lambda; + + // Handle numerical instability + if (!std::isfinite(new_param)) { new_param = 0.0f; } + + params[i] = new_param; + + const float grad = grads[i]; + momentum[i] = (momentum[i] - one_minus_beta1 * grad) / beta1; + variance[i] = (variance[i] - one_minus_beta2 * grad * grad) / beta2; + } + + --step_count; +} + +int ds_adam_rollback(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + 
bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + try { + // Validate tensor types - rollback currently only supports float32 + if (params.scalar_type() != torch::kFloat32 || grads.scalar_type() != torch::kFloat32 || + exp_avg.scalar_type() != torch::kFloat32 || + exp_avg_sq.scalar_type() != torch::kFloat32) { + printf("Error: Adam rollback currently only supports float32 tensors\n"); + return -1; + } + + float* params_ptr = params.data_ptr<float>(); + const float* grads_ptr = grads.data_ptr<float>(); + float* momentum_ptr = exp_avg.data_ptr<float>(); + float* variance_ptr = exp_avg_sq.data_ptr<float>(); + const size_t param_size = params.numel(); + int step_count = static_cast<int>(step); + + adamw_rollback_inplace(params_ptr, + grads_ptr, + momentum_ptr, + variance_ptr, + param_size, + lr, + beta1, + beta2, + epsilon, + weight_decay, + step_count); + + return 0; + } catch (const std::exception& e) { + printf("Error in Adam rollback for optimizer #%d: %s\n", optimizer_id, e.what()); + return -1; + } +} + int destroy_adam_optimizer(int optimizer_id) { s_optimizers.erase(optimizer_id); diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index a7db6fda3..e4fae63ce 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -217,4 +217,17 @@ int ds_adam_step(int optimizer_id, torch::Tensor& exp_avg, torch::Tensor& exp_avg_sq); +int ds_adam_rollback(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq); + int destroy_adam_optimizer(int optimizer_id); diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index e0a72a494..d0974497b 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -164,3 +164,86 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer): group['weight_decay'], group['bias_correction'], p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq']) return loss + + @torch.no_grad() + def step_subgroup(self, subgroup_id: int, closure=None): + """Update the model parameters in a single subgroup (by index).""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Intended device for step + device = torch.device('cpu') + + for group in self.param_groups: + for p in group['params']: + + if p.grad is None: + continue + + assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \ "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config." + + state = self.state[subgroup_id] + + if len(state) == 0: + state['step'] = 0 + + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device) + + state['step'] += 1 + beta1, beta2 = group['betas'] + self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq']) + return loss + + @torch.no_grad() + def rollback_subgroup(self, sub_group_id: int, closure=None): + """ + Rollback the optimizer state for a specific subgroup.
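+
+        Reverses the most recent adam_update applied to this subgroup: parameters, exp_avg and
+        exp_avg_sq are restored by the C++ rollback kernel using the current gradient, and the
+        subgroup step count is decremented by one.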
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Intended device for step + device = torch.device('cpu') + + # Validate subgroup state exists and is initialized + if sub_group_id not in self.state or len(self.state[sub_group_id]) == 0: + raise RuntimeError(f"Cannot rollback optimizer state for sub_group_id {sub_group_id} " + f"as it has not been initialized.") + + subgroup_state = self.state[sub_group_id] + + # Check if we can rollback (step count must be > 0) + if subgroup_state.get('step', 0) <= 0: + raise RuntimeError(f"Cannot rollback sub_group_id {sub_group_id}: " + f"step count is {subgroup_state.get('step', 0)}") + + for _, group in enumerate(self.param_groups): + for _, param in enumerate(group['params']): + if param.grad is None: + continue + + assert param.device == device, ( + f"CPUAdam param is on {param.device} and must be 'cpu', " + f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.") + + # Decrement step count + subgroup_state['step'] -= 1 + + # Extract hyperparameters + beta1, beta2 = group['betas'] + + self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2, + group['eps'], group['weight_decay'], group['bias_correction'], + param.data, param.grad.data, subgroup_state['exp_avg'], + subgroup_state['exp_avg_sq']) + return loss diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 32c6a71cd..3d345adcb 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -886,6 +886,12 @@ class DeepSpeedEngine(Module): def zero_partial_offload(self): return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0) + def super_offload(self): + return getattr(self._config.zero_config.offload_optimizer, "super_offload", False) + + def cpuadam_cores_perc(self): + return getattr(self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9) + def zero_sub_group_size(self): return self._config.zero_config.sub_group_size @@ -1826,7 +1832,10 @@ class DeepSpeedEngine(Module): log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0]) from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 - optimizer = DeepSpeedZeroOptimizer_Stage3( + from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3 + Stage3ZeroOptimizer = DeepSpeedZeroOptimizer_Stage3 if not self.super_offload( + ) else SuperOffloadOptimizer_Stage3 + optimizer = Stage3ZeroOptimizer( self.module, optimizer, timers=timers, @@ -1864,6 +1873,7 @@ class DeepSpeedEngine(Module): zeropp_loco_param=self.zeropp_loco_param(), log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), enable_sanity_checks=self.is_sanity_checks_enabled(), + cpuadam_cores_perc=self.cpuadam_cores_perc(), ) else: diff --git a/deepspeed/runtime/superoffload/superoffload_stage3.py b/deepspeed/runtime/superoffload/superoffload_stage3.py new file mode 100644 index 000000000..c9a6c0478 --- /dev/null +++ b/deepspeed/runtime/superoffload/superoffload_stage3.py @@ -0,0 +1,412 @@ +# Copyright (c) DeepSpeed Team. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import sys +import time +import torch +from typing import List + +from deepspeed.runtime.superoffload.superoffload_utils import SuperOffloadCPUOptimizer, TaskKeys, ResultKeys, EventTypes +from deepspeed.runtime.zero.partition_parameters import Parameter, Tensor +from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 +from deepspeed.utils.nvtx import instrument_w_nvtx +from deepspeed.utils import logger +from deepspeed.accelerator import get_accelerator + +OPTIMIZER_STEP_TIMER = 'optimizer_step' + + +class SuperOffloadOptimizer_Stage3(DeepSpeedZeroOptimizer_Stage3): + + def __init__( + self, + module, + init_optimizer, + timers, + ds_config, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + model_persistence_threshold=sys.maxsize, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + sub_group_size=1000000000000, + offload_ratio=0.0, + mpu=None, + clip_grad=0.0, + gradient_accumulation_dtype=torch.float32, + communication_data_type=torch.float16, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None, + all2all_process_group=None, + zero_hpz_partition_size=1, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + zero_module_granularity_threshold=0, + zeropp_loco_param=None, + log_trace_cache_warnings=False, + enable_sanity_checks=False, + cpuadam_cores_perc=0.8, + ): + + self.sub_group_to_param_num = {} + self.params_in_ipg_bucket_buffer = [] + self._cur_bucket_index = -1 + self.async_cpuadam_num = 0 + self.max_grad_numel = 0 + + super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale, + dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size, + max_reuse_distance, max_live_parameters, param_persistence_threshold, + model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm, + offload_optimizer_config, offload_param_config, sub_group_size, offload_ratio, mpu, clip_grad, + gradient_accumulation_dtype, communication_data_type, postscale_gradients, + gradient_predivide_factor, gradient_accumulation_steps, elastic_checkpoint, aio_config, + all2all_process_group, zero_hpz_partition_size, zero_quantized_weights, + zero_quantized_nontrainable_weights, zero_module_granularity_threshold, zeropp_loco_param, + log_trace_cache_warnings, enable_sanity_checks) + + optimizer_config = { + "lr": self.optimizer.param_groups[0]["lr"], + "betas": self.optimizer.param_groups[0]["betas"], + "eps": self.optimizer.param_groups[0]["eps"], + "weight_decay": self.optimizer.param_groups[0]["weight_decay"], + "amsgrad": self.optimizer.param_groups[0]["amsgrad"] + } + self.superoffload_cpu_optimizer = SuperOffloadCPUOptimizer(optimizer_config=optimizer_config, + cpuadam_cores_perc=cpuadam_cores_perc, + max_grad_numel=self.max_grad_numel) + + def _create_fp16_sub_groups(self, params_group): + + params_group_numel = sum([param.partition_numel() for param in params_group]) + sub_group_size = self.sub_group_size + + if sub_group_size is None or sub_group_size >= params_group_numel: + return [params_group] + + sub_groups = [] + 
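+        # The largest sub-group size found below is recorded in self.max_grad_numel so that
+        # SuperOffloadCPUOptimizer can pre-allocate a pinned gradient buffer large enough for
+        # any sub-group.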
sub_group = [] + local_sub_group_size = 0 + + for param in params_group: + sub_group.append(param) + local_sub_group_size += param.partition_numel() + + if local_sub_group_size >= sub_group_size or id(param) == id(params_group[-1]): + self.max_grad_numel = max(self.max_grad_numel, local_sub_group_size) + sub_groups.append(sub_group) + self.sub_group_to_param_num[len(sub_groups) - 1] = len(sub_group) + + sub_group = [] + local_sub_group_size = 0 + + return sub_groups + + def _optimizer_step(self, sub_group_id): + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + + def step_with_gradscaler(optimizer): + if self.torch_autocast_gradscaler: + self.torch_autocast_gradscaler.step(optimizer) + self.torch_autocast_gradscaler.update() + else: + optimizer.step() + + cur_device = self.subgroup_to_device[sub_group_id] + if cur_device != 'cpu': + self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param] + step_with_gradscaler(self.backup_optimizer) + self.backup_optimizer.param_groups[param_group_id]['params'] = [] + + def reduce_independent_p_g_buckets_and_remove_grads(self, param): + comm_dtype = self.get_param_comm_dtype(param) + bucket = self.ipg_buckets[comm_dtype] + i, _, _ = self.grad_position[self.get_param_id(param)] + + if len(bucket.params) == 0: + self._cur_bucket_index = i + if getattr(param, "ds_grad_is_ready", True): + self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param) + + # If this is a single-parameter sub-group, reduce immediately + if self.sub_group_to_param_num[self._cur_bucket_index] == 1: + self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype) + + elif i != self._cur_bucket_index: + # Parameter belongs to different sub-group, buffer it + self.params_in_ipg_bucket_buffer.append(param) + else: + # Parameter belongs to current bucket + if getattr(param, "ds_grad_is_ready", True): + self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(param) + + # Check if bucket is complete + if self.sub_group_to_param_num[self._cur_bucket_index] == len(bucket.params): + self._DeepSpeedZeroOptimizer_Stage3__reduce_and_partition_ipg_grads(comm_dtype) + + # Process buffered parameters + while self.params_in_ipg_bucket_buffer: + buffered_param = self.params_in_ipg_bucket_buffer.pop(0) + ci, _, _ = self.grad_position[self.get_param_id(buffered_param)] + self._cur_bucket_index = ci + if getattr(buffered_param, "ds_grad_is_ready", True): + self._DeepSpeedZeroOptimizer_Stage3__add_grad_to_ipg_bucket(buffered_param) + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): + if self.subgroup_to_device[sub_group_id] == 'cpu': + self._unflatten_partitioned_parameters(sub_group_id) + return + + if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + self._unflatten_partitioned_parameters(sub_group_id) + else: + self._partitioned_params_swap_out(sub_group_id) + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters_async(self, sub_group_id, updated_param): + """Asynchronously update partitioned parameters with optimized values.""" + self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(updated_param, non_blocking=True) + + @instrument_w_nvtx + def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + # print("[DEBUG] partition_grads called") + 
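+        # Copy reduced gradient partitions into the fp32 gradient buffers; for sub-groups that
+        # live on CPU, also hand the flattened gradient to the SuperOffload worker so the CPU
+        # Adam step runs asynchronously with the remaining backward work.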
buffers = [] + device_buffers = {} + buffer_numel_min = {} + buffer_numel_max = {} + + for param, grad_partition in zip(params_to_release, grad_partitions): + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(grad_partition) + + buffer_numel = grad_partition.numel() + buffers.append(grad_partition) + + if i not in device_buffers: + device_buffers[i] = [] + device_buffers[i].append(grad_partition) + + if i not in buffer_numel_min: + buffer_numel_min[i] = dest_offset + buffer_numel_max[i] = dest_offset + buffer_numel + else: + buffer_numel_min[i] = min(buffer_numel_min[i], dest_offset) + buffer_numel_max[i] = max(buffer_numel_max[i], dest_offset + buffer_numel) + + if self.is_gradient_accumulation_boundary: + for i in buffer_numel_min.keys(): + fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow( + 0, buffer_numel_min[i], buffer_numel_max[i] - buffer_numel_min[i]) + concatenated_buffer = torch.cat(device_buffers[i], dim=0).float() + + if self.subgroup_to_device[i] == 'cpu': + # Trigger asynchronous CPU optimization + param_group_id = self.sub_group_to_group_id[i] + fp32_param = self.fp32_partitioned_groups_flat[i] + + self.superoffload_cpu_optimizer.async_step(param_group_id, i, fp32_param.data, + concatenated_buffer.data) + self.async_cpuadam_num += 1 + + # Check for completed async operations + result = self.superoffload_cpu_optimizer.get_result() + if result is not None: + self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID], + result[ResultKeys.UPDATED_PARAM]) + self.async_cpuadam_num -= 1 + + fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True) + else: + fp32_grad_tensor.copy_(concatenated_buffer, non_blocking=True) + + # Clean up parameter gradients + for param in params_to_release: + if not get_accelerator().is_synchronized_device(): + param.grad.record_stream(get_accelerator().current_stream()) + param.grad = None + + @instrument_w_nvtx + def step(self, closure=None): + """ + Not supporting closure. 
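+
+        Waits for pending asynchronous CPU optimizer work, rolls back CPU-side optimizer state
+        on overflow, redoes CPU-side updates with clipped gradients when clipping triggers, and
+        then steps the remaining accelerator-resident sub-groups.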
+ """ + # Wait for any pending asynchronous CPU optimizer operations + self._wait_for_async_operations() + + self._pre_step() + self._partition_all_parameters() + + if self._overflow_check_and_loss_scale_update(): + self._handle_overflow_rollback() + return + + norm_groups = self._get_norm_groups() + scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups)) + self._global_grad_norm = scaled_global_grad_norm / self.loss_scale + + timer_names = set() + timer_names.add(OPTIMIZER_STEP_TIMER) + self.timers(OPTIMIZER_STEP_TIMER).start() + + if self.check_clip_grads(scaled_global_grad_norm): + self._handle_gradient_clipping(scaled_global_grad_norm) + + for sub_group_id, group in enumerate(self.fp16_groups): + # Prepare optimizer states, gradients and fp32 parameters for update + self._prepare_sub_group(sub_group_id, timer_names) + + # Scale the fp32 gradients + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + + # Apply the optimizer step on the sub group and copy fp32 parameters to fp16 + self._optimizer_step(sub_group_id) + + # Put fp16 parameters in appropriate location + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + + # Release memory or swap out optimizer states of fp32 parameters + self._release_sub_group(sub_group_id, timer_names) + + self.timers(OPTIMIZER_STEP_TIMER).stop() + self._post_step(timer_names) + + def _wait_for_async_operations(self, timeout_seconds=60): + """Wait for all pending asynchronous CPU optimizer operations to complete with timeout error. + + Args: + timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds. + """ + if self.async_cpuadam_num > 0: + logger.info(f"[INFO] {self.async_cpuadam_num} asynchronous CPU optimizer operations pending...") + if self.async_cpuadam_num == 0: + return + + start_time = time.time() + initial_pending_ops = self.async_cpuadam_num + + while self.async_cpuadam_num > 0: + result = self.superoffload_cpu_optimizer.get_result() + if result is None: + current_time = time.time() + elapsed_time = current_time - start_time + + # Throw error if we've been waiting longer than the timeout + if elapsed_time >= timeout_seconds: + raise RuntimeError( + f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. " + f"Still waiting for {self.async_cpuadam_num}/{initial_pending_ops} async operations to complete. " + f"This indicates a deadlock or critical performance issue in the CPU optimizer.") + + time.sleep(0.001) # 1ms sleep + continue + + self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID], + result[ResultKeys.UPDATED_PARAM]) + self.async_cpuadam_num -= 1 + + def _wait_for_single_async_result(self, event_type: str, timeout_seconds=60): + """Wait for a single asynchronous CPU-Adam optimizer operation with timeout. + + Args: + event_type (str): Type of operation expected ('adam_step' or 'rollback'). + timeout_seconds (int): Maximum time to wait before throwing an error. Default is 60 seconds. 
+ """ + start_time = time.time() + + while True: + result = self.superoffload_cpu_optimizer.get_result(expected_event_type=event_type) + if result is not None: + self._reassign_or_swap_out_partitioned_parameters_async(result[TaskKeys.SUB_GROUP_ID], + result[ResultKeys.UPDATED_PARAM]) + break + + current_time = time.time() + elapsed_time = current_time - start_time + + # Throw error if we've been waiting longer than the timeout + if elapsed_time >= timeout_seconds: + raise RuntimeError(f"SuperOffload CPU optimizer timeout after {elapsed_time:.1f} seconds. " + f"This indicates a deadlock or critical performance issue in the CPU optimizer.") + + time.sleep(0.001) # 1ms sleep + + def _sync_cpu_optimizer_step(self, + param_group_id: int, + sub_group_id: int, + fp32_param_data, + fp32_grad_data, + rollback: bool = False, + timeout_seconds: int = 60): + event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP + self.superoffload_cpu_optimizer.async_step(param_group_id, + sub_group_id, + fp32_param_data, + fp32_grad_data, + rollback=rollback) + # Wait for completion + self._wait_for_single_async_result(event_type, timeout_seconds) + + def _handle_overflow_rollback(self): + """Handle gradient overflow by rolling back CPU optimizer states.""" + for sub_group_id, _ in enumerate(self.fp16_groups): + if self.subgroup_to_device[sub_group_id] == 'cpu': + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + + # Trigger rollback + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + rollback=True) + + def _handle_gradient_clipping(self, scaled_global_grad_norm): + """Handle gradient clipping with CPU optimizer rollback and re-optimization.""" + for sub_group_id, _ in enumerate(self.fp16_groups): + if self.subgroup_to_device[sub_group_id] == 'cpu': + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + + # Rollback CPU optimizer states + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + rollback=True) + + # Clip gradients and re-optimize + self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm) + + self._sync_cpu_optimizer_step(param_group_id, + sub_group_id, + fp32_param.data, + fp32_param.grad.data, + rollback=False) + + @instrument_w_nvtx + def check_clip_grads(self, total_norm): + """Check if gradients need to be clipped based on the global norm.""" + unscaled_norm = total_norm / self.loss_scale + return self.clip_grad and unscaled_norm > self.clip_grad diff --git a/deepspeed/runtime/superoffload/superoffload_utils.py b/deepspeed/runtime/superoffload/superoffload_utils.py new file mode 100644 index 000000000..e023730bd --- /dev/null +++ b/deepspeed/runtime/superoffload/superoffload_utils.py @@ -0,0 +1,273 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +""" +SuperOffload utilities for 1) running CPU optimizers in separate processes. 
+ +""" + +from typing import Dict, Optional, Any +import torch +import torch.multiprocessing as mp +import psutil + +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.utils import logger + + +class TaskKeys: + PARAM_DATA = "param_data" + PARAM_GRAD = "param_grad" + PARAM_GROUP_ID = "param_group_id" + SUB_GROUP_ID = "sub_group_id" + ROLLBACK = "rollback" + + +class ResultKeys: + UPDATED_PARAM = "updated_param" + EVENT_TYPE = "event_type" + + +class EventTypes: + ADAM_STEP = "adam_step" + ROLLBACK = "rollback" + + +def superoffload_optimizer_worker(param_queue: mp.SimpleQueue, result_queue: mp.SimpleQueue, + optimizer_config: Dict[str, Any], max_grad_numel: int) -> None: + """ + This function runs in a separate process and continuously processes optimization + tasks from the parameter queue. It creates a DeepSpeedCPUAdam optimizer and + applies optimization steps to parameters received from the main process. + + Args: + param_queue: Queue for receiving optimization tasks + result_queue: Queue for sending back optimization results + optimizer_config: Configuration dictionary for the optimizer containing + lr, betas, eps, weight_decay, and amsgrad parameters + max_grad_numel: Maximum number of elements expected in gradient tensors + """ + # Initialize dummy parameter for optimizer creation + cpu_tensor = torch.randn(1, device="cpu") + cpu_param = torch.nn.Parameter(cpu_tensor) + + try: + optimizer = DeepSpeedCPUAdam([cpu_param], + lr=optimizer_config["lr"], + betas=optimizer_config["betas"], + eps=optimizer_config["eps"], + weight_decay=optimizer_config["weight_decay"], + amsgrad=optimizer_config["amsgrad"]) + except KeyError as e: + error_msg = f"Missing required optimizer config key: {e}" + logger.error(error_msg) + result_queue.put({"error": error_msg}) + return + + # Pre-allocate reusable pinned memory buffer for gradients + pinned_grad_buffer = torch.empty(max_grad_numel, dtype=torch.float32, device='cpu', pin_memory=True) + + while True: + try: + task = param_queue.get() + + if task is None: + logger.debug("Received termination signal, shutting down worker") + break + + param_data = task[TaskKeys.PARAM_DATA] + param_grad = task[TaskKeys.PARAM_GRAD] + param_group_id = task[TaskKeys.PARAM_GROUP_ID] + sub_group_id = task[TaskKeys.SUB_GROUP_ID] + rollback = task.get(TaskKeys.ROLLBACK, False) + + logger.debug(f"Processing param_group_id: {param_group_id}, sub_group_id: {sub_group_id}") + + del task[TaskKeys.PARAM_DATA] + del task[TaskKeys.PARAM_GRAD] + task.clear() + + grad_numel = param_grad.numel() + if grad_numel > max_grad_numel: + error_msg = ( + f"Gradient size {grad_numel} exceeds pre-allocated buffer size {max_grad_numel}. " + f"This indicates insufficient buffer allocation. 
Please increase max_grad_numel parameter.") + result_queue.put({"error": error_msg}) + break + + param_grad_cpu = pinned_grad_buffer[:grad_numel].view_as(param_grad) + param_grad_cpu.copy_(param_grad, non_blocking=True) + + fp32_param = torch.nn.Parameter(param_data) + fp32_param.grad = param_grad_cpu + + optimizer.param_groups[param_group_id]['params'] = [fp32_param] + + if rollback: + logger.debug(f"Rolling back optimizer state for sub_group_id: {sub_group_id}") + optimizer.rollback_subgroup(sub_group_id) + else: + optimizer.step_subgroup(sub_group_id) + + # Send result back to main process + event_type = EventTypes.ROLLBACK if rollback else EventTypes.ADAM_STEP + result_queue.put({ + TaskKeys.PARAM_GROUP_ID: param_group_id, + TaskKeys.SUB_GROUP_ID: sub_group_id, + ResultKeys.UPDATED_PARAM: fp32_param.data, + ResultKeys.EVENT_TYPE: event_type, + }) + + # Clean up references to free memory + optimizer.param_groups[param_group_id]['params'] = [] + del param_grad_cpu, fp32_param.grad, fp32_param, param_grad, param_data + + except KeyError as e: + error_msg = f"Missing required task key: {e}" + logger.error(error_msg) + result_queue.put({"error": error_msg}) + break + except Exception as e: + error_msg = f"Unexpected error in worker process: {e}" + logger.error(error_msg) + result_queue.put({"error": error_msg}) + break + + # Clean up pinned memory buffer + if 'pinned_grad_buffer' in locals(): + del pinned_grad_buffer + logger.debug("Cleaned up pinned memory buffer") + + logger.debug("Worker process terminated") + + +class SuperOffloadCPUOptimizer: + + def __init__(self, + optimizer_config: Dict[str, Any], + cpuadam_cores_perc: float = 0.8, + max_grad_numel: int = 1000000) -> None: + if not 0 < cpuadam_cores_perc <= 1: + raise ValueError("cpuadam_cores_perc must be between 0 and 1") + + self.max_grad_numel = max_grad_numel + self.mp_context = mp.get_context('spawn') + self.param_queue = self.mp_context.SimpleQueue() + self.result_queue = self.mp_context.SimpleQueue() + + self.cpuadam_process = self.mp_context.Process( + target=superoffload_optimizer_worker, + args=(self.param_queue, self.result_queue, optimizer_config, max_grad_numel), + daemon=True, + ) + self.cpuadam_process.start() + + # Set CPU affinity for better performance isolation + self._set_cpu_affinity(cpuadam_cores_perc) + + def _set_cpu_affinity(self, cpuadam_cores_perc: float) -> None: + """ + Set CPU affinity for the main (Pytorch) process and worker (CPU Adam) process. 
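+        The visible cores are split so that roughly a (1 - cpuadam_cores_perc) fraction stays
+        with the PyTorch process and the remaining cores are pinned to the CPU Adam worker.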
+ + Args: + cpuadam_cores_perc: Percentage of cores to allocate to the worker (CPU Adam) process + """ + try: + current_process = psutil.Process() + all_cores = current_process.cpu_affinity() + num_cores = len(all_cores) + + split_idx = int((1 - cpuadam_cores_perc) * num_cores) + pt_cores = all_cores[:split_idx] + cpuadam_cores = all_cores[split_idx:] + + # Set affinity for main process (PyTorch) + current_process.cpu_affinity(pt_cores) + + # Set affinity for optimizer process (CPU Adam) + optimizer_process = psutil.Process(self.cpuadam_process.pid) + optimizer_process.cpu_affinity(cpuadam_cores) + + logger.debug(f"Set CPU affinity - PyTorch cores: {pt_cores}, " + f"Optimizer cores: {cpuadam_cores}") + + except (psutil.AccessDenied, psutil.NoSuchProcess, AttributeError) as e: + logger.debug(f"Could not set CPU affinities for superoffload optimizer process: {e}") + except Exception as e: + logger.warning(f"Unexpected error setting CPU affinity: {e}") + + def async_step(self, + param_group_id: int, + sub_group_id: int, + fp32_param: torch.Tensor, + fp32_grad: torch.Tensor, + rollback: bool = False) -> None: + """ + Queue parameter for optimization in the worker process. + """ + if not self.cpuadam_process.is_alive(): + raise RuntimeError("Worker process is not alive") + + self.param_queue.put({ + TaskKeys.PARAM_DATA: fp32_param, + TaskKeys.PARAM_GRAD: fp32_grad, + TaskKeys.PARAM_GROUP_ID: param_group_id, + TaskKeys.SUB_GROUP_ID: sub_group_id, + TaskKeys.ROLLBACK: rollback, + }) + + def get_result(self, expected_event_type: str = None) -> Optional[Dict[str, Any]]: + """ + Get result from worker process with optional event type validation. + + Args: + expected_event_type (str, optional): Expected event type ('adam_step' or 'rollback'). + If provided, validates that the result matches. + """ + if self.result_queue.empty(): + return None + + result = self.result_queue.get() + + if "error" in result: + raise RuntimeError(f"Error in worker process: {result['error']}") + + # Validate event type if expected_event_type is provided + if expected_event_type is not None: + result_event_type = result.get(ResultKeys.EVENT_TYPE) + if result_event_type != expected_event_type: + raise RuntimeError(f"Event type mismatch: expected '{expected_event_type}', got '{result_event_type}'") + + return result + + def close(self) -> None: + """ + Shutdown the worker process gracefully. + + Sends termination signal to worker and waits for clean shutdown. + If the process doesn't terminate within the timeout, it will be forcefully killed. 
+ """ + if not self.cpuadam_process.is_alive(): + logger.debug("Worker process already terminated") + return + + # Send termination signal + self.param_queue.put(None) + + # Wait for graceful shutdown + self.cpuadam_process.join(timeout=5) + + if self.cpuadam_process.is_alive(): + logger.warning("Optimizer process did not terminate cleanly within timeout, " + "forcefully terminating") + self.cpuadam_process.terminate() + self.cpuadam_process.join(timeout=2) + + # Last resort: kill the process + if self.cpuadam_process.is_alive(): + logger.error("Failed to terminate optimizer process, killing it") + self.cpuadam_process.kill() + self.cpuadam_process.join() + + logger.debug("SuperOffload CPU optimizer closed successfully") diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index ca35d7a7d..ac88d3226 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -93,6 +93,12 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel): ratio: float = Field(1.0, ge=0.0, le=1.0) """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3.""" + super_offload: bool = False + """ Enable high performance CPU offloading for Superchips. Only valid with ZeRO Stage 3.""" + + cpuadam_cores_perc: float = Field(0.8, ge=0.0, le=1.0) + """ Percentage of CPU cores to use for CPU Adam. Only valid with ZeRO Stage 3 and super_offload=True.""" + @model_validator(mode="after") def set_pipeline(self): pipeline = self.pipeline_read or self.pipeline_write diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 7e13ad9fc..8b51781df 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -178,6 +178,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer): zeropp_loco_param=None, log_trace_cache_warnings=False, enable_sanity_checks=False, + cpuadam_cores_perc=0.8, ): see_memory_usage("Stage 3 initialize beginning", force=True) @@ -873,7 +874,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer): sub_group_size = len(self.fp16_partitioned_groups_flat) # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n") for i in range(sub_group_size): - if i < int(self.partial_offload * sub_group_size): + if i >= int((1 - self.partial_offload) * sub_group_size): self.subgroup_to_device[i] = 'cpu' else: self.subgroup_to_device[i] = get_accelerator()._name diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 851485440..d83b1732e 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -132,3 +132,183 @@ class TestCPUAdamGPUError(DistributedTest): param.grad = torch.randn(model_size, device=device) with pytest.raises(AssertionError): optimizer.step() + + +class TestCPUAdamSubgroup(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + @pytest.mark.parametrize('model_size', [64, 128, 1024]) + def test_step_subgroup_basic(self, dtype, model_size): + """Test basic functionality of step_subgroup method.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + # Create parameters + cpu_data = 
torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + # Set gradient + param.grad = torch.randn(model_size, device='cpu').to(dtype) + + # Store initial parameter values + initial_param = param.data.clone() + + # Test step_subgroup with subgroup_id=0 + subgroup_id = 0 + optimizer.step_subgroup(subgroup_id) + + # Verify parameter was updated + assert not torch.equal(param.data, initial_param), "Parameters should be updated after step_subgroup" + + # Verify optimizer state was created for subgroup + assert subgroup_id in optimizer.state, "Optimizer state should be created for subgroup" + assert optimizer.state[subgroup_id]['step'] == 1, "Step count should be 1" + assert 'exp_avg' in optimizer.state[subgroup_id], "exp_avg should be in state" + assert 'exp_avg_sq' in optimizer.state[subgroup_id], "exp_avg_sq should be in state" + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + def test_step_subgroup_multiple_calls(self, dtype): + """Test multiple calls to step_subgroup increment step count correctly.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + + # Perform multiple steps + for step in range(1, 4): + param.grad = torch.randn(model_size, device='cpu').to(dtype) + optimizer.step_subgroup(subgroup_id) + + # Verify step count increments + assert optimizer.state[subgroup_id]['step'] == step, f"Step count should be {step}" + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + def test_rollback_subgroup_basic(self, dtype): + """Test basic functionality of rollback_subgroup method.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + param.grad = torch.randn(model_size, device='cpu').to(dtype) + + # First, perform a step to initialize state + optimizer.step_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == 1 + + # Store parameter state after step + param_after_step = param.data.clone() + exp_avg_after_step = optimizer.state[subgroup_id]['exp_avg'].clone() + exp_avg_sq_after_step = optimizer.state[subgroup_id]['exp_avg_sq'].clone() + + # Now rollback + optimizer.rollback_subgroup(subgroup_id) + + # Verify step count decremented + assert optimizer.state[subgroup_id]['step'] == 0, "Step count should be decremented after rollback" + + def test_rollback_subgroup_uninitialized_error(self): + """Test that rollback_subgroup raises error for uninitialized subgroup.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + param = torch.nn.Parameter(torch.randn(model_size, device='cpu')) + optimizer = DeepSpeedCPUAdam([param]) + + # Try to rollback uninitialized subgroup + with pytest.raises(RuntimeError, match="Cannot rollback optimizer state for sub_group_id 0"): + optimizer.rollback_subgroup(0) + + def test_rollback_subgroup_zero_step_error(self): + """Test that rollback_subgroup raises 
error when step count is already 0.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + param = torch.nn.Parameter(torch.randn(model_size, device='cpu')) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + param.grad = torch.randn(model_size, device='cpu') + + # Initialize state by doing one step + optimizer.step_subgroup(subgroup_id) + + # Rollback once (step should become 0) + optimizer.rollback_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == 0 + + # Try to rollback again - should raise error + with pytest.raises(RuntimeError, match="Cannot rollback sub_group_id 0: step count is 0"): + optimizer.rollback_subgroup(subgroup_id) + + @pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16], ids=["fp16", "bf16"]) + def test_step_rollback_sequence(self, dtype): + """Test sequence of step_subgroup and rollback_subgroup operations.""" + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + cpu_data = torch.randn(model_size, device='cpu').to(dtype) + param = torch.nn.Parameter(cpu_data) + optimizer = DeepSpeedCPUAdam([param]) + + subgroup_id = 0 + param.grad = torch.randn(model_size, device='cpu').to(dtype) + + # Perform multiple steps + for step in range(1, 4): + optimizer.step_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == step + + # Rollback steps one by one + for step in range(2, -1, -1): + optimizer.rollback_subgroup(subgroup_id) + assert optimizer.state[subgroup_id]['step'] == step + + def test_multiple_subgroups(self): + """Test that different subgroups maintain independent state.""" + from deepspeed.ops.adam import DeepSpeedCPUAdam + + model_size = 64 + param = torch.nn.Parameter(torch.randn(model_size, device='cpu')) + optimizer = DeepSpeedCPUAdam([param]) + + param.grad = torch.randn(model_size, device='cpu') + + # Step different subgroups + optimizer.step_subgroup(0) + optimizer.step_subgroup(1) + optimizer.step_subgroup(0) # Step subgroup 0 again + + # Verify independent step counts + assert optimizer.state[0]['step'] == 2, "Subgroup 0 should have step count 2" + assert optimizer.state[1]['step'] == 1, "Subgroup 1 should have step count 1" + + # Rollback subgroup 0 only + optimizer.rollback_subgroup(0) + assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented" + assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"
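For reference, a minimal sketch of how the new `offload_optimizer` fields introduced above (`super_offload`, `cpuadam_cores_perc`) might be enabled. The surrounding keys follow the usual DeepSpeed ZeRO stage 3 config layout, and all values are illustrative rather than prescriptive:

```python
# Illustrative DeepSpeed config enabling SuperOffload (values are examples only).
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            # New in this PR: route ZeRO-3 to SuperOffloadOptimizer_Stage3.
            "super_offload": True,
            # New in this PR: fraction of CPU cores pinned to the CPU Adam worker process.
            "cpuadam_cores_perc": 0.8,
        },
    },
    "bf16": {"enabled": True},
}
```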