Add ZenFlow code for Stage 3 (#7516)

This PR completes the ZenFlow integration for DeepSpeed ZeRO Stage 3. 

Highlights:

- ZenFlowSelectiveAdamW_stage3: Optimizer with importance-aware
selective parameter updates for ZeRO Stage 3.
- ZenFlowZeroOptimizer_Stage3: Full Stage 3 optimizer integration with
partitioned parameters and CPU offload.
- Configurable via ZenFlowConfig, fully integrated with
DeepSpeedZeroConfig for Stage 3 (see the config sketch below).
- Unit tests for Stage 3 cases ensuring correctness and compatibility.
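
A minimal configuration sketch, assuming the `zenflow` section sits under `zero_optimization` as in the Stage 1 & 2 integration (#7391); the field names are those consumed by `configure_zenflow()` in this PR, and `model` is a stand-in for an existing module:

```python
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        # key placement is an assumption based on the Stage 1 & 2 integration
        "zenflow": {
            "topk_ratio": 0.1,           # fraction of channels selected as important
            "select_strategy": "epoch",  # "auto" also resolves to per-epoch selection
            "select_interval": 1,
            "update_interval": 4,        # or "auto" for adaptive full updates
            "full_warm_up_rounds": 0,
            "offload": True,             # keep selective optimizer states on CPU
            "overlap_step": True,        # run the CPU optimizer step in a side process
        },
    },
}

# model is assumed to be an existing torch.nn.Module
engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
```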

Note: Integration with ZeRO Stages 1 and 2 was introduced in #7391

---------

Signed-off-by: Yusen Wu <xrn4ub@virginia.edu>
Co-authored-by: Ma, Guokai <guokai.ma@intel.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Tingfeng Lan <erc8gx@virginia.edu>
Author: JoshWoo2003
Committed: 2025-10-13 12:19:18 -04:00 (by GitHub)
Parent: b7cd78f096
Commit: 7cb1b88ec4
10 changed files with 1218 additions and 234 deletions

View File

@@ -6,4 +6,4 @@
from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam
from .zenflow_cpu_adam import ZenFlowCPUAdam
from .zenflow_torch_adam import ZenFlowSelectiveAdamW
from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3

View File

@@ -53,30 +53,20 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
if offload:
self.step = self._step_with_offload
self.temp_copy_param = self._temp_copy_param_with_offload
self.group_step = self._group_step_with_offload
self.bucket_size = bucket_size
else:
self.step = self._step_without_offload
self.temp_copy_param = self._temp_copy_param_without_offload
self.group_step = self._group_step_without_offload
@torch.no_grad()
def _temp_copy_param_with_offload(self, group_to_paramlist):
def temp_copy_param(self, group_to_paramlist):
for group_id, params in group_to_paramlist.items():
for param in params:
if hasattr(param, "selected_grad"):
temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len(
param.shape) != 1 else param.data.clone().detach()
param.temp_selected_param = temp_selected_param.cpu()
@torch.no_grad()
def _temp_copy_param_without_offload(self, group_to_paramlist):
for group_id, params in group_to_paramlist.items():
for param in params:
if hasattr(param, "selected_grad"):
param.temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len(
param.shape) != 1 else param.data.clone().detach()
if self.offload:
param.temp_selected_param = temp_selected_param.cpu()
else:
param.temp_selected_param = temp_selected_param
def copy_mv_from_cpu(self, params):
for param in params:
@@ -167,6 +157,13 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
@torch.no_grad()
def _step_with_offload(self):
"""
Performs parameter updates in offload mode.
In this mode, group_step() calls adamw() on each pre-partitioned param bucket,
so memory can be released after each bucket update to reduce GPU overhead.
Without offload, adamw() is called directly for speed.
"""
for group_id, group in enumerate(self.param_groups):
params = group["params"]
@@ -197,10 +194,127 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
flush_bucket()
@torch.no_grad()
def _group_step_without_offload(self, group_to_paramlist):
def group_step(self, group_to_paramlist):
for group_id, params in group_to_paramlist.items():
group = self.param_groups[group_id]
if self.offload:
self.copy_mv_from_cpu(params)
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
amsgrad: bool = group["amsgrad"]
beta1, beta2 = cast(Tuple[float, float], group["betas"])
for param in params:
if hasattr(param, "selected_grad"):
is_2d = (len(param.shape) != 1)
selected_param = param.data[:, param.selected_indices] if is_2d else param.data
state = self.state.setdefault(param, {})
if len(state) == 0:
state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device)
if amsgrad:
state["max_exp_avg_sq"] = torch.zeros_like(selected_param)
if not self.offload:
state["exp_avg"] = torch.zeros_like(selected_param)
state["exp_avg_sq"] = torch.zeros_like(selected_param)
if self.offload:
exp_avg_t = param.exp_avg.view_as(selected_param)
exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param)
else:
exp_avg_t = state["exp_avg"]
exp_avg_sq_t = state["exp_avg_sq"]
params_with_grad.append(selected_param)
grads.append(param.selected_grad)
exp_avgs.append(exp_avg_t)
exp_avg_sqs.append(exp_avg_sq_t)
if amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=False,
)
for i, param in enumerate(params):
if hasattr(param, "selected_grad") and len(param.shape) != 1:
param.data[:, param.selected_indices] = params_with_grad[i]
if self.offload:
self.copy_mv_to_cpu(params)
for param in params:
param.selected_grad = None
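# ZenFlowSelectiveAdamW_stage3 mirrors ZenFlowSelectiveAdamW but works on the local
# ZeRO-3 partition: 2-D params are addressed through param.ds_tensor in ZenFlow's
# transposed layout via (complete_column_offset, complete_numel) and are row-indexed
# by selected_indices instead of column-indexed on the full tensor.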
class ZenFlowSelectiveAdamW_stage3(torch.optim.AdamW):
def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs):
super(ZenFlowSelectiveAdamW_stage3, self).__init__(*args, **kwargs)
self.offload = offload
if offload:
self.step = self._step_with_offload
self.bucket_size = bucket_size
else:
self.step = self._step_without_offload
@torch.no_grad()
def temp_copy_param(self, paramlist):
for param in paramlist:
if hasattr(param, "selected_grad"):
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view(
param.complete_numel // num_row, num_row)
temp_selected_param = param_2d[param.selected_indices, :].clone().detach()
else:
temp_selected_param = param.ds_tensor.data.clone().detach()
if self.offload:
param.temp_selected_param = temp_selected_param.cpu()
else:
param.temp_selected_param = temp_selected_param
def clear_selected_mv(self):
print("Zenflow: clearing selective optimizer states...")
for group in self.param_groups:
for param in group['params']:
state = self.state.setdefault(param, {})
if len(state) == 0:
continue
if self.offload:
param.exp_avg_cpu_data.zero_()
param.exp_avg_sq_cpu_data.zero_()
else:
state["exp_avg"].zero_()
state["exp_avg_sq"].zero_()
@torch.no_grad()
def _step_without_offload(self):
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
@@ -209,10 +323,18 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
state_steps: List[Tensor] = []
amsgrad: bool = group["amsgrad"]
beta1, beta2 = cast(Tuple[float, float], group["betas"])
for param in params:
for param in group["params"]:
if hasattr(param, "selected_grad"):
selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
selected_param = param_2d[param.selected_indices, :]
else:
selected_param = param.ds_tensor.data
if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None:
selected_param.copy_(param.temp_selected_param)
state = self.state.setdefault(param, {})
if len(state) == 0:
@@ -229,7 +351,6 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
if amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
adamw(
params_with_grad,
grads,
@@ -245,44 +366,91 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
eps=group["eps"],
maximize=False,
)
for i, param in enumerate(params):
for i, param in enumerate(group["params"]):
if hasattr(param, "selected_grad"):
if len(param.shape) != 1:
param.data[:, param.selected_indices] = params_with_grad[i]
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
param_2d[param.selected_indices, :] = params_with_grad[i]
for param in params:
param.selected_grad = None
for param in group["params"]:
if hasattr(param, "temp_selected_param"):
param.temp_selected_param = None
param.selected_grad = None
def copy_mv_from_cpu(self, params):
for param in params:
param.exp_avg = param.exp_avg_cpu_data.to(param.device, non_blocking=True)
param.exp_avg_sq = param.exp_avg_sq_cpu_data.to(param.device, non_blocking=True)
def copy_mv_to_cpu(self, params):
for param in params:
param.exp_avg_cpu_data.copy_(param.exp_avg.data, non_blocking=True)
param.exp_avg_sq_cpu_data.copy_(param.exp_avg_sq.data, non_blocking=True)
param.exp_avg = None
param.exp_avg_sq = None
@torch.no_grad()
def _group_step_with_offload(self, group_to_paramlist):
for group_id, params in group_to_paramlist.items():
def group_step(self, paramlist):
group_to_paramlist = {}
for param in paramlist:
group_id = param.group_id
if group_id not in group_to_paramlist:
group_to_paramlist[group_id] = []
group_to_paramlist[group_id].append(param)
for group_id in sorted(group_to_paramlist.keys()):
params = group_to_paramlist[group_id]
group = self.param_groups[group_id]
self.copy_mv_from_cpu(params)
if self.offload:
self.copy_mv_from_cpu(params)
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
amsgrad: bool = group["amsgrad"]
beta1, beta2 = cast(Tuple[float, float], group["betas"])
for param in params:
if hasattr(param, "selected_grad"):
selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
selected_param = param_2d[param.selected_indices, :]
else:
selected_param = param.ds_tensor.data
state = self.state.setdefault(param, {})
if len(state) == 0:
state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device)
if amsgrad:
state["max_exp_avg_sq"] = torch.zeros_like(selected_param)
if not self.offload:
state["exp_avg"] = torch.zeros_like(selected_param)
state["exp_avg_sq"] = torch.zeros_like(selected_param)
if self.offload:
exp_avg_t = param.exp_avg.view_as(selected_param)
exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param)
else:
exp_avg_t = state["exp_avg"]
exp_avg_sq_t = state["exp_avg_sq"]
params_with_grad.append(selected_param)
grads.append(param.selected_grad)
exp_avgs.append(param.exp_avg.view_as(selected_param))
exp_avg_sqs.append(param.exp_avg_sq.view_as(selected_param))
exp_avgs.append(exp_avg_t)
exp_avg_sqs.append(exp_avg_sq_t)
if amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
@@ -305,14 +473,64 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
for i, param in enumerate(params):
if hasattr(param, "selected_grad"):
if len(param.shape) != 1:
param.data[:, param.selected_indices] = params_with_grad[i]
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
param_2d[param.selected_indices, :] = params_with_grad[i]
self.copy_mv_to_cpu(params)
if self.offload:
self.copy_mv_to_cpu(params)
for param in params:
param.selected_grad = None
@torch.no_grad()
def _step_with_offload(self):
"""
Performs parameter updates in offload mode.
In this mode, group_step() calls adamw() on each pre-partitioned param bucket,
so memory can be released after each bucket update to reduce GPU overhead.
Without offload, adamw() is called directly for speed.
"""
for group_id, group in enumerate(self.param_groups):
params = group["params"]
bucket = []
bucket_numel = 0
def flush_bucket():
if not bucket:
return
for param in bucket:
if hasattr(param, "temp_selected_param") and param.temp_selected_param is not None:
temp_selected_param = param.temp_selected_param.to(param.device, non_blocking=True)
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row != 1:
param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
param.complete_numel).view(
param.complete_numel // num_row, num_row)
param_2d[param.selected_indices, :] = temp_selected_param
else:
param.ds_tensor.data.copy_(temp_selected_param)
param.temp_selected_param = None
self.group_step(bucket)
bucket.clear()
for param in params:
if hasattr(param, "selected_grad"):
bucket.append(param)
bucket_numel += param.numel()
if bucket_numel >= self.bucket_size:
flush_bucket()
bucket_numel = 0
flush_bucket()
def _single_tensor_adamw(
params: List[Tensor],

View File

@@ -1868,6 +1868,7 @@ class DeepSpeedEngine(Module):
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
zenflow_config=self.zenflow_config(),
sub_group_size=self.zero_sub_group_size(),
offload_ratio=self.zero_partial_offload(),
mpu=self.mpu,

View File

@@ -0,0 +1,641 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.zero.partition_parameters import *
import torch
import math
from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import ZenFlowSelectiveAdamW_stage3
from deepspeed.runtime.utils import see_memory_usage
from typing import List
from deepspeed.accelerator import get_accelerator
from typing import TYPE_CHECKING
from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process
if TYPE_CHECKING:
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
OPTIMIZER_SWAP_IN_STATE_TIMER = 'optimizer_swap_in_state'
INIT_OPTIMIZER_TIMER = 'init_optimizer_state'
OPTIMIZER_SWAP_OUT_STATE_TIMER = 'optimizer_swap_out_state'
OPTIMIZER_STEP_TIMER = 'optimizer_step'
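# configure_zenflow copies the selection/update settings from ZenFlowConfig onto the
# Stage 3 optimizer: it resolves the 'auto' select strategy to per-epoch selection,
# derives the select/update intervals, and sets up the bookkeeping used by the
# adaptive ('auto') update detection.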
def configure_zenflow(optimizer_z3, zenflow_config):
optimizer_z3.select_strategy = zenflow_config.select_strategy
if optimizer_z3.select_strategy == 'auto':
optimizer_z3.select_strategy = "epoch"
if isinstance(zenflow_config.select_interval, int):
logger.warning(
"With the auto select strategy, select_strategy is set to 'epoch' and select_interval is forced to 1; any user-provided select_interval is overwritten.")
optimizer_z3.select_interval = 1
else:
if isinstance(zenflow_config.select_interval, str):
raise ValueError("If don't use auto select strategy, select_interval must be a number.")
optimizer_z3.select_interval = int(zenflow_config.select_interval)
if isinstance(zenflow_config.update_interval, str):
optimizer_z3.auto_update = True
optimizer_z3.update_interval = 0
else:
optimizer_z3.auto_update = False
optimizer_z3.update_interval = int(zenflow_config.update_interval)
if optimizer_z3.select_strategy == 'epoch':
if zenflow_config.steps_per_epoch is not None:
optimizer_z3.select_interval = optimizer_z3.select_interval * zenflow_config.steps_per_epoch
else:
optimizer_z3.select_interval = 0
if not optimizer_z3.auto_update and optimizer_z3.select_interval != 0 and optimizer_z3.select_interval < optimizer_z3.update_interval:
raise ValueError("Select interval must be greater or equal to update interval")
optimizer_z3.topk_ratio = zenflow_config.topk_ratio
optimizer_z3.param_id_grad_sum_buffer_offset = {}
optimizer_z3.zf_stage3 = True
if optimizer_z3.auto_update:
optimizer_z3.param_id_sum_buffer_offset = {}
optimizer_z3.auto_ratio = zenflow_config.auto_ratio
optimizer_z3.zenflow_need_update = [False, False]
optimizer_z3.zenflow_state = 0
optimizer_z3.num_need_update = 0
def _initialize_zenflow_stage3_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3",
module,
zenflow_config: dict = None):
optimizer_z3.zenflow = True if zenflow_config is not None else False
if not optimizer_z3.zenflow:
return
optimizer_z3.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc
for p in module.parameters():
p.data = p.data.t().contiguous() if len(p.shape) != 1 else p.data
def _initialize_zenflow_stage3_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3",
zenflow_config: dict = None,
overlap_comm: bool = False):
if not optimizer_z3.zenflow:
return
optimizer_z3.micro_step = -1
optimizer_z3.full_warm_up_rounds = zenflow_config.full_warm_up_rounds
optimizer_z3.offload_selective_optimizer = zenflow_config.offload
optimizer_z3.zenflow_overlap_step = zenflow_config.overlap_step
if optimizer_z3.offload_selective_optimizer:
assert overlap_comm, "offload selective optimizer should be used with overlap_comm"
if optimizer_z3.zenflow_overlap_step:
optimizer_z3.process_optimizer_established = False
optimizer_z3.first_update_round_after_warmup = True
optimizer_z3.initialize_optimizer_states = lambda: initialize_optimizer_states(optimizer_z3)
optimizer_z3.step = lambda closure=None: step(optimizer_z3, closure)
optimizer_z3.zenflow_cpu_optimizer_overlap_step = lambda now_state, scaled_global_grad_norm: zenflow_cpu_optimizer_overlap_step(
optimizer_z3, now_state, scaled_global_grad_norm)
optimizer_z3.wait_last_update_and_copy = lambda timer_names: wait_last_update_and_copy(
optimizer_z3, timer_names)
optimizer_z3.partition_grads = lambda params_to_release, grad_partitions: partition_grads(
optimizer_z3, params_to_release, grad_partitions)
optimizer_z3.get_overlap_step_state = lambda: get_overlap_step_state(optimizer_z3)
optimizer_z3.start_optimizer_process = lambda: start_optimizer_process(optimizer_z3)
optimizer_z3.unscale_and_clip_grads = lambda sub_group_id, total_norm, now_state: unscale_and_clip_grads(
optimizer_z3, sub_group_id, total_norm, now_state)
configure_zenflow(optimizer_z3, zenflow_config)
optimizer_z3.selective_optimizer = ZenFlowSelectiveAdamW_stage3([{
k: v
for k, v in group.items() if k != "params"
} | {
"params": group["params"]
} for group in optimizer_z3.optimizer.param_groups],
offload=optimizer_z3.offload_selective_optimizer)
optimizer_z3.num_total_param = sum(
sum(1 for param in group["params"] if len(param.ds_shape) != 1)
for group in optimizer_z3.optimizer.param_groups)
def zenflow_cpu_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
return optimizer_z3.optimizer.step(step_id=optimizer_z3.micro_step + 1)
def _sync_selective_optimizer_lr(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
for group_selected, group in zip(optimizer_z3.selective_optimizer.param_groups,
optimizer_z3.optimizer.param_groups):
group_selected["lr"] = group["lr"]
def selective_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
optimizer_z3.selective_optimizer.step()
def is_zenflow_select_boundary(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> bool:
return optimizer_z3.zenflow and (optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) >= 0 and (
(optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) == 0 or
(optimizer_z3.select_interval != 0 and optimizer_z3.micro_step % optimizer_z3.select_interval == 0))
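# update_selected_channels recomputes the set of important channels at each selection
# boundary: every rank writes the absolute gradient sums of its local columns into a
# flat buffer, the buffers are all-gathered, and a global top-k over the per-channel
# magnitudes is mapped back to indices local to this rank's partition.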
def update_selected_channels(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", params_to_update, grad_partitions):
src_rk = dist.get_rank(optimizer_z3.dp_process_group)
total_rk = dist.get_world_size(optimizer_z3.dp_process_group)
total_chunk_size = 0
param_local_offset = [0 for _ in range(total_rk)]
for param, grad_partition in zip(params_to_update, grad_partitions):
param_max_chunk_size = 0
param_rk_offset = 0
for rk in range(total_rk):
contains_real_data = param.partition_numel() * rk < param.ds_numel
if not contains_real_data:
param.grad = None
continue
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row == 1:
continue
partition_size = param.partition_numel()
start = partition_size * rk
end = min(start + partition_size, param.ds_numel)
start_idx = math.ceil(start / num_row)
end_idx = end // num_row
num_cols = end_idx - start_idx
if param.ds_id not in optimizer_z3.param_id_grad_sum_buffer_offset:
optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = []
optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id].append(
(param_local_offset[rk], num_cols, param_rk_offset))
param_max_chunk_size = max(param_max_chunk_size, num_cols)
param_rk_offset += num_cols
param_local_offset[rk] += num_cols
total_chunk_size += param_max_chunk_size
optimizer_z3.grad_sum_buffer = torch.zeros(total_chunk_size, dtype=optimizer_z3.dtype, device='cuda')
for param, grad_partition in zip(params_to_update, grad_partitions):
contains_real_data = param.partition_numel() * src_rk < param.ds_numel
if not contains_real_data:
# this grad partition is empty - don't need to do anything
param.grad = None
continue
# ds_shape is the transposed shape; it is not the same as param.shape
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row == 1:
continue
partition_size = param.partition_numel()
start = partition_size * src_rk
end = min(start + partition_size, param.ds_numel)
start_idx = math.ceil(start / num_row)
end_idx = end // num_row
num_elements = (end_idx - start_idx) * num_row
param.complete_column_offset = start_idx * num_row - start
param.complete_numel = (end_idx - start_idx) * num_row
sum_per_column = grad_partition.narrow(0, param.complete_column_offset, num_elements)
sum_per_column = sum_per_column.view(end_idx - start_idx, num_row)
sum_array = sum_per_column.abs().sum(dim=1)
offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk]
optimizer_z3.grad_sum_buffer.narrow(0, offset, length).copy_(sum_array)
gathered_chunks = [torch.zeros_like(optimizer_z3.grad_sum_buffer) for _ in range(total_rk)]
dist.all_gather(gathered_chunks, optimizer_z3.grad_sum_buffer, group=optimizer_z3.dp_process_group)
for param in params_to_update:
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row == 1:
continue
param_column_sum = []
for rk in range(total_rk):
offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][rk]
param_column_sum.append(gathered_chunks[rk].narrow(0, offset, length))
global_param_column_sum = torch.cat(param_column_sum, dim=0)
num_select = max(1, int(global_param_column_sum.numel() * optimizer_z3.topk_ratio))
_, global_topk_indices = torch.topk(global_param_column_sum, num_select, largest=True)
_, length, rk_offset = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk]
local_indices = [(idx.item() - rk_offset) for idx in global_topk_indices
if rk_offset <= idx < rk_offset + length]
param.selected_indices = torch.tensor(local_indices, device='cuda')
optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = []
optimizer_z3.grad_sum_buffer = None
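# _process_selected_fp32_groups_grad slices each grad partition down to the selected
# channels (param.selected_grad). Under auto_update it also all-gathers per-param
# gradient mass to decide whether the accumulated non-critical gradient outweighs the
# critical one, which schedules a full update round.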
def _process_selected_fp32_groups_grad(optimizer_z3, params_to_update, grad_partitions):
if optimizer_z3.auto_update:
optimizer_z3.sum_buffer = torch.zeros(optimizer_z3.num_total_param, dtype=optimizer_z3.dtype, device='cuda')
optimizer_z3.critic_sum_buffer = torch.zeros(optimizer_z3.num_total_param,
dtype=optimizer_z3.dtype,
device='cuda')
curr_buffer_idx = 0
for param, grad_partition in zip(params_to_update, grad_partitions):
rk = dist.get_rank(optimizer_z3.dp_process_group)
contains_real_data = param.partition_numel() * rk < param.ds_numel
if not contains_real_data:
# this grad partition is empty - don't need to do anything
param.grad = None
continue
# ds_shape is the transposed shape; it is not the same as param.shape
num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
if num_row == 1:
param.selected_grad = grad_partition.clone().detach()
else:
grad_2d = grad_partition.narrow(0, param.complete_column_offset,
param.complete_numel).view(param.complete_numel // num_row, num_row)
param.selected_grad = grad_2d[param.selected_indices, :].clone().detach()
if optimizer_z3.auto_update:
optimizer_z3.sum_buffer[curr_buffer_idx] = grad_partition.abs().sum()
optimizer_z3.critic_sum_buffer[curr_buffer_idx] = param.selected_grad.abs().sum()
curr_buffer_idx += 1
if optimizer_z3.offload_selective_optimizer and not hasattr(param, 'exp_avg_cpu_data'):
buffer = torch.zeros(param.selected_grad.numel(), dtype=param.dtype, device=optimizer_z3.device)
param.exp_avg_cpu_data = get_accelerator().pin_memory(
buffer) if optimizer_z3.offload_optimizer_pin_memory else buffer
param.exp_avg_sq_cpu_data = get_accelerator().pin_memory(
buffer.clone()) if optimizer_z3.offload_optimizer_pin_memory else buffer.clone()
if optimizer_z3.auto_update:
total_rk = dist.get_world_size(optimizer_z3.dp_process_group)
sum_gather_list = [torch.zeros_like(optimizer_z3.sum_buffer) for _ in range(total_rk)]
critic_gather_list = [torch.zeros_like(optimizer_z3.critic_sum_buffer) for _ in range(total_rk)]
curr_buffer_idx = 0
dist.all_gather(sum_gather_list, optimizer_z3.sum_buffer, group=optimizer_z3.dp_process_group)
dist.all_gather(critic_gather_list, optimizer_z3.critic_sum_buffer, group=optimizer_z3.dp_process_group)
for param in params_to_update:
if len(param.ds_shape) == 1:
continue
if not hasattr(param, 'non_critic_sum'):
param.non_critic_sum = 0
if not hasattr(param, 'avg_critic_sum'):
param.avg_critic_sum = 0
grad_total_sum = sum(sum_gather_list[rk][curr_buffer_idx] for rk in range(total_rk))
grad_critic_sum = sum(critic_gather_list[rk][curr_buffer_idx] for rk in range(total_rk))
param.avg_critic_sum = (param.avg_critic_sum * (optimizer_z3.update_interval - 1) +
grad_critic_sum) / optimizer_z3.update_interval / (optimizer_z3.topk_ratio * 10)
param.non_critic_sum += (grad_total_sum - grad_critic_sum) / ((1 - optimizer_z3.topk_ratio) * 10)
if param.non_critic_sum >= param.avg_critic_sum:
optimizer_z3.num_need_update += 1
if optimizer_z3.num_need_update >= int(optimizer_z3.auto_ratio * optimizer_z3.num_total_param):
optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = True
curr_buffer_idx += 1
if not optimizer_z3.is_gradient_accumulation_boundary:
optimizer_z3.selective_optimizer.group_step(params_to_update)
else:
optimizer_z3.selective_optimizer.temp_copy_param(params_to_update)
if optimizer_z3.auto_update:
optimizer_z3.sum_buffer = None
optimizer_z3.critic_sum_buffer = None
def sync_fp32_param_from_gpu(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
if optimizer_z3.micro_step == 0:
return
for fp16_partitions, fp32_partition in zip(optimizer_z3.fp16_partitioned_groups_flat,
optimizer_z3.fp32_partitioned_groups_flat):
fp32_partition.data.copy_(fp16_partitions.data)
def zenflow_backward_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
optimizer_z3.micro_step += 1
if optimizer_z3.auto_update:
optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = False
optimizer_z3.num_need_update = 0
if optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state ^ 1]:
optimizer_z3.update_interval = 0
for group in optimizer_z3.fp16_groups:
for p in group:
p.non_critic_sum = 0
optimizer_z3.update_interval += 1
if optimizer_z3.is_zenflow_select_boundary():
sync_fp32_param_from_gpu(optimizer_z3)
optimizer_z3.selective_optimizer.clear_selected_mv()
def zenflow_backward_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
optimizer_z3._partition_all_parameters()
def log_selective_optimizer_timers(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
pass
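# initialize_optimizer_states allocates the fp32 gradient buffers. Under ZenFlow
# overlap, each offloaded subgroup gets a double-buffered overlap_grad pair instead
# of a single .grad, so the background CPU optimizer can step on one buffer while the
# next micro-steps accumulate into the other.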
def initialize_optimizer_states(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
num_subgroups = len(optimizer_z3.fp16_groups)
largest_numel = max([sum([p.ds_numel for p in psg]) for psg in optimizer_z3.fp16_partitioned_groups])
gradient_dtype = optimizer_z3.fp32_partitioned_groups_flat[0].dtype
gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=optimizer_z3.device)
timer_names = set()
# State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
# which do lazy initialization of the state at the first call to step.
is_adagrad = isinstance(optimizer_z3.optimizer, torch.optim.Adagrad)
if optimizer_z3.swap_optimizer:
optimizer_z3.optimizer_swapper.init_timers()
timer_names.add(INIT_OPTIMIZER_TIMER)
optimizer_z3.timers(INIT_OPTIMIZER_TIMER).start()
for i, group in enumerate(optimizer_z3.fp16_groups):
swappable_optimizer_subgroup = optimizer_z3._swappable_optimizer_subgroup(i)
swappable_param_subgroup = optimizer_z3.fp16_partitioned_groups_flat[i] is None
num_elements = int(optimizer_z3.fp16_partitioned_groups_flat_numel[i])
see_memory_usage(
f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
force=False)
if swappable_optimizer_subgroup:
optimizer_z3._optimizer_states_and_gradient_swap_in(i, timer_names)
if optimizer_z3.offload_optimizer and not swappable_optimizer_subgroup:
subgroup_gradient_buffer = torch.zeros(num_elements, dtype=gradient_dtype, device=optimizer_z3.device)
if optimizer_z3.offload_optimizer_pin_memory:
subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer)
optimizer_z3.fp32_partitioned_groups_flat[i].grad = None
optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad = [
subgroup_gradient_buffer.to(optimizer_z3.subgroup_to_device[i]),
subgroup_gradient_buffer.clone().to(optimizer_z3.subgroup_to_device[i])
]
else:
optimizer_z3.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements)
if swappable_param_subgroup:
optimizer_z3._partitioned_params_swap_out(i)
if swappable_optimizer_subgroup:
optimizer_z3._optimizer_states_and_gradient_swap_out(i, timer_names)
see_memory_usage(
f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
force=False)
# Initialize the optimizer states with the flattened fp32 partition.
if is_adagrad:
optimizer_z3.optimizer = torch.optim.Adagrad(optimizer_z3.fp32_partitioned_groups_flat,
**optimizer_z3.optimizer.defaults)
optimizer_z3.timers(INIT_OPTIMIZER_TIMER).stop()
optimizer_z3.timers.log(timer_names)
if optimizer_z3.swap_optimizer:
optimizer_z3.optimizer_swapper.log_timers()
if not optimizer_z3.offload_optimizer:
for group in optimizer_z3.fp32_partitioned_groups_flat:
group.grad = None
# Reset steps
return
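# get_overlap_step_state picks which half of the double-buffered overlap_grad the
# current step uses: it alternates every micro-step during warm-up, every update
# interval afterwards, or follows zenflow_state when auto_update is enabled.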
def get_overlap_step_state(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> int:
if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds:
return optimizer_z3.micro_step & 1
else:
if not optimizer_z3.auto_update:
return (optimizer_z3.micro_step // optimizer_z3.update_interval) & 1
else:
return optimizer_z3.zenflow_state
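# partition_grads moves or accumulates each reduced grad partition into its
# persistent grad buffer and, with optimizer offload, copies the fp32 gradient into
# the overlap_grad half chosen by get_overlap_step_state().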
@instrument_w_nvtx
def partition_grads(optimizer_z3, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
offload_fp32_gradients = {}
offload_fp32_offsets = {}
buffers = []
for param, grad_partition in zip(params_to_release, grad_partitions):
contains_real_data = param.partition_numel() * dist.get_rank(optimizer_z3.dp_process_group) < param.ds_numel
if not contains_real_data:
# this grad partition is empty - don't need to do anything
param.grad = None
continue
# move or accumulate gradient partition to target buffer
param_id_to_grad_partition = getattr(optimizer_z3,
f"_{optimizer_z3.__class__.__name__}__param_id_to_grad_partition")
grad_buffer = param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel())
buffers.append(grad_buffer)
if optimizer_z3.micro_step_id == 0: # don't accumulate
grad_buffer.copy_(grad_partition, non_blocking=True)
# ensure grad buffer is a CUDA buffer to speed up the next few
# operations and so it can be used asynchronously
grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True)
elif get_accelerator().on_accelerator(grad_buffer):
grad_buffer.add_(grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(grad_buffer.shape))
else:
# if dst is CPU, copy first to src device, do the addition
# there, then move back to dst. adding directly to cpu is very slow
cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True)
cuda_grad_buffer.add_(
grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(cuda_grad_buffer.shape))
grad_buffer.copy_(cuda_grad_buffer, non_blocking=True)
# ensure grad buffer is a CUDA buffer to speed up the next few
# operations and so it can be used asynchronously
grad_buffer = cuda_grad_buffer
# offload the gradient partition if applicable
if optimizer_z3.offload_optimizer:
i, dest_offset, _ = optimizer_z3.grad_position[optimizer_z3.get_param_id(param)]
now_state = optimizer_z3.get_overlap_step_state()
if optimizer_z3.is_gradient_accumulation_boundary:
optimizer_z3.norm_for_param_grads[optimizer_z3.get_param_id(
param)] = optimizer_z3._constant_buffered_norm2(grad_buffer)
if optimizer_z3._swappable_optimizer_subgroup(i):
if not i in offload_fp32_gradients.keys():
offload_fp32_gradients[i] = []
offload_fp32_offsets[i] = []
offload_fp32_gradients[i].append(grad_buffer.float())
offload_fp32_offsets[i].append(dest_offset)
else:
fp32_grad_tensor = optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad[now_state].narrow(
0, dest_offset, grad_buffer.numel())
fp32_grad_tensor.copy_(grad_buffer.float())
# free the gradient
if not get_accelerator().is_synchronized_device():
if param.grad is not None:
param.grad.record_stream(get_accelerator().current_stream())
param.grad = None
if optimizer_z3.offload_optimizer and optimizer_z3.swap_optimizer:
for i in offload_fp32_gradients.keys():
optimizer_z3.optimizer_swapper.swap_out_gradients(parameter=optimizer_z3.fp32_partitioned_groups_flat[i],
gradient_offsets=offload_fp32_offsets[i],
gradient_tensors=offload_fp32_gradients[i])
return buffers
@instrument_w_nvtx
def unscale_and_clip_grads(self, sub_group_id, total_norm, now_state):
# compute combined scale factor for this group
combined_scale = self.loss_scale
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
clip = torch.clamp(clip, min=1.0)
combined_scale = clip * self.loss_scale
self.fp32_partitioned_groups_flat[sub_group_id].overlap_grad[now_state].mul_(1. / combined_scale)
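# zenflow_cpu_optimizer_overlap_step unscales and clips the active gradient buffers,
# then sends a "step" command over the pipe so the background optimizer process runs
# the CPU Adam step asynchronously; wait_last_update_and_copy collects the result.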
def zenflow_cpu_optimizer_overlap_step(optimizer_z3, now_state, scaled_global_grad_norm):
if not optimizer_z3.process_optimizer_established:
optimizer_z3.start_optimizer_process()
group_infos = []
for group_no, group in enumerate(optimizer_z3.fp16_groups):
optimizer_z3.unscale_and_clip_grads(group_no, scaled_global_grad_norm, now_state)
param_group_id = optimizer_z3.sub_group_to_group_id[group_no]
group_info = {
"lr": optimizer_z3.optimizer.param_groups[param_group_id]["lr"],
"betas": optimizer_z3.optimizer.param_groups[param_group_id]["betas"],
"eps": optimizer_z3.optimizer.param_groups[param_group_id]["eps"],
"weight_decay": optimizer_z3.optimizer.param_groups[param_group_id]["weight_decay"],
"bias_correction": optimizer_z3.optimizer.param_groups[param_group_id]["bias_correction"],
}
group_infos.append(group_info)
optimizer_z3.parent_conn.send({
"type": "step",
"now_state": now_state,
"micro_step": optimizer_z3.micro_step,
"group_infos": group_infos
})
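# wait_last_update_and_copy blocks on the "done" reply from the optimizer process,
# then copies the updated fp32 partitions (exposed through the shared stale_param
# tensors) back into the fp16 partitions used by the next forward pass.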
def wait_last_update_and_copy(optimizer_z3, timer_names):
if not hasattr(optimizer_z3, 'parent_conn'):
return
if optimizer_z3.micro_step + 1 > optimizer_z3.full_warm_up_rounds and optimizer_z3.first_update_round_after_warmup:
optimizer_z3.first_update_round_after_warmup = False
return
msg = optimizer_z3.parent_conn.recv()
assert msg["type"] == "done", "Optimizer process did not finish stepping correctly."
for sub_group_id, group in enumerate(optimizer_z3.fp16_groups):
if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None:
optimizer_z3.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
optimizer_z3.fp32_partitioned_groups_flat[sub_group_id].stale_param.data)
#unflatten fp16 parameter subgroup
optimizer_z3._unflatten_partitioned_parameters(sub_group_id)
else:
optimizer_z3._partitioned_params_swap_out(sub_group_id)
optimizer_z3._post_step(timer_names)
# warn user about caching allocator flushes
memory_stats = get_accelerator().memory_stats()
alloc_retries = memory_stats.get("num_alloc_retries")
if alloc_retries is None:
alloc_retries = 0
if alloc_retries > optimizer_z3.n_caching_allocator_flushes:
if dist.get_rank() == 0:
logger.warning(
"%d pytorch allocator cache flushes since last step. this happens "
"when there is high memory pressure and is detrimental to "
"performance. if this is happening frequently consider adjusting "
"settings to reduce memory consumption. If you are unable to "
"make the cache flushes go away consider adding "
"get_accelerator().empty_cache() calls in your training loop to ensure "
"that all ranks flush their caches at the same time",
alloc_retries - optimizer_z3.n_caching_allocator_flushes)
optimizer_z3.n_caching_allocator_flushes = alloc_retries
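# Overlapped Stage 3 step: during warm-up the CPU step is launched before waiting so
# the pipeline fills up; after warm-up the previous asynchronous step is awaited and
# copied back first, then the next CPU step is dispatched.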
@instrument_w_nvtx
def step(optimizer_z3, closure=None):
"""
Not supporting closure.
"""
optimizer_z3._pre_step()
optimizer_z3._partition_all_parameters()
#checks for overflow, adjust the loss scale accordingly
if optimizer_z3._overflow_check_and_loss_scale_update():
if optimizer_z3.swap_optimizer:
optimizer_z3.optimizer_swapper.log_timers()
return
norm_groups = optimizer_z3._get_norm_groups()
scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
# Stash unscaled gradient norm
optimizer_z3._global_grad_norm = scaled_global_grad_norm / optimizer_z3.loss_scale
if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds:
optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm)
timer_names = set()
timer_names.add(OPTIMIZER_STEP_TIMER)
optimizer_z3.wait_last_update_and_copy(timer_names)
if optimizer_z3.micro_step >= optimizer_z3.full_warm_up_rounds:
optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm)
return

View File

@@ -3,14 +3,11 @@
# DeepSpeed Team
import os
import math
import psutil
import torch
from deepspeed import comm as dist
import torch.multiprocessing as mp
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process
from deepspeed.runtime.utils import (see_memory_usage)
from deepspeed.ops.adam import ZenFlowSelectiveAdamW
@@ -97,6 +94,8 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds
self.offload_selective_optimizer = zenflow_config.offload
self.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc
self.start_optimizer_process = lambda: start_optimizer_process(self)
self.zf_stage3 = False
if self.offload_selective_optimizer:
assert overlap_comm, "offload selective optimizer should be used with overlap_comm"
@@ -636,64 +635,10 @@ class ZenFlowZeroOptimizerSequential(ZenFlowZeroOptimizer):
self.optimizer.step(step_id=self.micro_step + 1)
def disable_accelerator():
accelerator = get_accelerator()
accelerator.is_available = lambda: False
accelerator.device_count = lambda: 0
accelerator.current_device = lambda: -1
# Optionally mark it as initialized if needed
if hasattr(accelerator, "_initialized"):
accelerator._initialized = True
def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map,
shared_stale_param_map, zf_affinity):
disable_accelerator()
current_process = psutil.Process()
current_process.cpu_affinity(zf_affinity)
os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))
from deepspeed.ops.adam import ZenFlowCPUAdam
optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
pipe.send({"type": "ready"})
# TODO: replace this with rpc
while True:
cmd = pipe.recv()
if cmd["type"] == "step":
now_state = cmd["now_state"]
micro_step = cmd["micro_step"]
group_infos = cmd["group_infos"]
for group_no, group_info in enumerate(group_infos):
original_param_groups = optimizer.param_groups
optimizer.param_groups = [original_param_groups[group_no]]
group = optimizer.param_groups[0]
for param_idx, param in enumerate(group["params"]):
key = (group_no, param_idx)
if key in shared_overlap_grad_map:
param.overlap_grad = shared_overlap_grad_map[key]
if key in shared_stale_param_map:
param.stale_param = shared_stale_param_map[key]
optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
optimizer.param_groups = original_param_groups
pipe.send({"type": "done"})
elif cmd["type"] == "exit":
break
class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
def __init__(self, *args, **kwargs):
super(ZenFlowZeroOptimizerParallel, self).__init__(*args, **kwargs)
self.process_pool = mp.Pool(1)
self.process_optimizer_established = False
self.first_update_round_after_warmup = True
@@ -759,85 +704,6 @@ class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
dest_tensor.copy_(src_tensor, non_blocking=True)
param.grad = None #offload only
# check if all tensors in the list are equal to each other
def all_tensors_equal(self, tensor_list):
first_tensor = tensor_list[0]
for tensor in tensor_list[1:]:
if not torch.equal(first_tensor, tensor):
return False
return True
def start_optimizer_process(self):
from multiprocessing import Pipe, get_context, Manager
ctx = get_context("spawn")
self.parent_conn, self.child_conn = Pipe()
manager = Manager()
self.shared_overlap_grad_map = manager.dict()
self.shared_stale_param_map = manager.dict()
for group_no, group in enumerate(self.optimizer.param_groups):
for param_idx, param in enumerate(group['params']):
param.data.share_memory_()
if not hasattr(param, 'stale_param'):
param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device)
param.stale_param.data.share_memory_()
key = (group_no, param_idx)
self.shared_stale_param_map[key] = param.stale_param
if param.overlap_grad is not None:
param.overlap_grad[0].data.share_memory_()
param.overlap_grad[1].data.share_memory_()
key = (group_no, param_idx)
self.shared_overlap_grad_map[key] = param.overlap_grad
param_groups_data = self.optimizer.param_groups
curr_rank = dist.get_rank()
total_rank = dist.get_world_size()
current_process = psutil.Process()
current_affinity = current_process.cpu_affinity()
all_affinities = [
torch.zeros(len(current_affinity),
dtype=type(current_affinity[0]),
device=get_accelerator().current_device_name()) for _ in range(total_rank)
]
dist.all_gather(
all_affinities,
torch.tensor(current_affinity,
dtype=type(current_affinity[0]),
device=get_accelerator().current_device_name()))
# When the affinity is the same across all ranks, the workers are not bound to specific cores. Do a soft bind here
if self.all_tensors_equal(all_affinities):
num_phy_cores = psutil.cpu_count(logical=False)
available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
num_available_phy_cores = len(available_phy_cores)
my_rank = curr_rank
my_size = total_rank
cores_per_rank = num_available_phy_cores // my_size
current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
pt_num_cores = math.ceil(self.pt_reserved_cores_perc * len(current_affinity))
if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
zf_affinity = current_affinity[pt_num_cores:]
pt_affinity = current_affinity[:pt_num_cores]
else:
zf_affinity = current_affinity
pt_affinity = current_affinity
self.process = ctx.Process(
target=zenflow_optimizer_process,
args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map,
self.shared_stale_param_map, zf_affinity),
)
self.process.daemon = True
self.process.start()
current_process.cpu_affinity(pt_affinity)
os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))
msg = self.parent_conn.recv()
assert msg["type"] == "ready", "Optimizer process did not initialize correctly."
self.process_optimizer_established = True
def wait_last_update_and_copy(self):
if not hasattr(self, 'parent_conn'):

View File

@@ -3,7 +3,12 @@
# DeepSpeed Team
import os
import math
import torch
import psutil
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
def _flatten_dense_tensors(tensors):
@@ -40,3 +45,147 @@ def _unflatten_dense_tensors(flat, tensors):
transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors]
unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors)
return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat]
def disable_accelerator():
accelerator = get_accelerator()
accelerator.is_available = lambda: False
accelerator.device_count = lambda: 0
accelerator.current_device = lambda: -1
# Optionally mark it as initialized if needed
if hasattr(accelerator, "_initialized"):
accelerator._initialized = True
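# zenflow_optimizer_process is the entry point of the spawned CPU optimizer worker:
# pinned to its own cores with the accelerator disabled, it loops on the pipe,
# attaches the shared overlap_grad/stale_param tensors before each ZenFlowCPUAdam
# step, and replies "done" when a step completes.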
def zenflow_optimizer_process(pipe, param_groups, shared_overlap_grad_map, shared_stale_param_map, zf_affinity):
disable_accelerator()
current_process = psutil.Process()
current_process.cpu_affinity(zf_affinity)
os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))
from deepspeed.ops.adam import ZenFlowCPUAdam
optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
pipe.send({"type": "ready"})
# TODO: replace this with rpc
while True:
cmd = pipe.recv()
if cmd["type"] == "step":
now_state = cmd["now_state"]
micro_step = cmd["micro_step"]
group_infos = cmd["group_infos"]
for group_no, group_info in enumerate(group_infos):
original_param_groups = optimizer.param_groups
optimizer.param_groups = [original_param_groups[group_no]]
group = optimizer.param_groups[0]
for param_idx, param in enumerate(group["params"]):
key = (group_no, param_idx)
if key in shared_overlap_grad_map:
param.overlap_grad = shared_overlap_grad_map[key]
if key in shared_stale_param_map:
param.stale_param = shared_stale_param_map[key]
optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
optimizer.param_groups = original_param_groups
pipe.send({"type": "done"})
elif cmd["type"] == "exit":
break
def all_tensors_equal(tensor_list):
first_tensor = tensor_list[0]
for tensor in tensor_list[1:]:
if not torch.equal(first_tensor, tensor):
return False
return True
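# start_optimizer_process shares the parameter, gradient, and stale-parameter storage
# with the worker via shared memory, splits the physical cores between the training
# process (pt_affinity) and the worker (zf_affinity), then spawns the worker and
# waits for its "ready" handshake.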
def start_optimizer_process(zf_optimizer):
from multiprocessing import Pipe, get_context, Manager
ctx = get_context("spawn")
zf_optimizer.parent_conn, zf_optimizer.child_conn = Pipe()
manager = Manager()
zf_optimizer.shared_overlap_grad_map = manager.dict()
zf_optimizer.shared_stale_param_map = manager.dict()
if zf_optimizer.zf_stage3:
params_iter = [((group_no, 0), param)
for group_no, param in enumerate(zf_optimizer.fp32_partitioned_groups_flat)]
else:
params_iter = [((group_no, param_idx), param)
for group_no, group in enumerate(zf_optimizer.optimizer.param_groups)
for param_idx, param in enumerate(group["params"])]
for key, param in params_iter:
param.data.share_memory_()
if not hasattr(param, "stale_param"):
param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device)
param.stale_param.data.share_memory_()
zf_optimizer.shared_stale_param_map[key] = param.stale_param
if getattr(param, "overlap_grad", None) is not None:
param.overlap_grad[0].data.share_memory_()
param.overlap_grad[1].data.share_memory_()
zf_optimizer.shared_overlap_grad_map[key] = param.overlap_grad
param_groups_data = ([{
"params": [param]
} for param in zf_optimizer.fp32_partitioned_groups_flat]
if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups)
curr_rank = dist.get_rank()
total_rank = dist.get_world_size()
current_process = psutil.Process()
current_affinity = current_process.cpu_affinity()
all_affinities = [
torch.zeros(len(current_affinity),
dtype=type(current_affinity[0]),
device=get_accelerator().current_device_name()) for _ in range(total_rank)
]
dist.all_gather(
all_affinities,
torch.tensor(current_affinity, dtype=type(current_affinity[0]),
device=get_accelerator().current_device_name()))
# When the affinity is the same across all ranks, the workers are not bound to specific cores. Do a soft bind here
if all_tensors_equal(all_affinities):
num_phy_cores = psutil.cpu_count(logical=False)
available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
num_available_phy_cores = len(available_phy_cores)
my_rank = curr_rank
my_size = total_rank
cores_per_rank = num_available_phy_cores // my_size
current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity))
if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
zf_affinity = current_affinity[pt_num_cores:]
pt_affinity = current_affinity[:pt_num_cores]
else:
zf_affinity = current_affinity
pt_affinity = current_affinity
zf_optimizer.process = ctx.Process(
target=zenflow_optimizer_process,
args=(zf_optimizer.child_conn, param_groups_data, zf_optimizer.shared_overlap_grad_map,
zf_optimizer.shared_stale_param_map, zf_affinity),
)
zf_optimizer.process.daemon = True
zf_optimizer.process.start()
current_process.cpu_affinity(pt_affinity)
os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))
msg = zf_optimizer.parent_conn.recv()
assert msg["type"] == "ready", "Optimizer process did not initialize correctly."
zf_optimizer.process_optimizer_established = True

View File

@@ -93,6 +93,7 @@ class DeepSpeedZeRoOffload(object):
module,
timers,
ds_config,
zenflow=False,
overlap_comm=True,
prefetch_bucket_size=50000000,
max_reuse_distance=1000000000,
@@ -115,6 +116,7 @@ class DeepSpeedZeRoOffload(object):
self.module = module
self.timers = timers
self.zenflow = zenflow
self.dtype = list(module.parameters())[0].dtype
self.dp_process_group = dp_process_group
self.offload_device = None
@@ -472,6 +474,11 @@ class DeepSpeedZeRoOffload(object):
param_coordinator.record_module(sub_module)
param_coordinator.fetch_sub_module(sub_module, forward=True)
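# ZenFlow stores 2-D params transposed so that channel selection is contiguous;
# transpose them back to the compute layout around each submodule call.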
if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data
see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
@torch.no_grad()
@@ -480,6 +487,11 @@ class DeepSpeedZeRoOffload(object):
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)
if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data
param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module, forward=True)
@@ -496,6 +508,11 @@ class DeepSpeedZeRoOffload(object):
param_coordinator.record_module(sub_module)
param_coordinator.fetch_sub_module(sub_module, forward=False)
if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data
@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
@@ -503,6 +520,11 @@ class DeepSpeedZeRoOffload(object):
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)
if self.zenflow:
params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data
self.get_param_coordinator().release_sub_module(sub_module, forward=False)
see_memory_usage(

View File

@@ -26,6 +26,7 @@ from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer
from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -160,6 +161,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
overlap_comm=False,
offload_optimizer_config=None,
offload_param_config=None,
zenflow_config=None,
sub_group_size=1000000000000,
offload_ratio=0.0,
mpu=None,
@@ -226,6 +228,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.partial_offload = offload_ratio
self.enable_sanity_checks = enable_sanity_checks
self.create_zenflow_hooks()
self._initialize_zenflow_stage3_prologue(module, zenflow_config)
#num of ranks in a ZeRO param partitioning group
self.zero_hpz_partition_size = zero_hpz_partition_size
@@ -241,6 +246,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
module=module,
timers=timers,
ds_config=ds_config,
zenflow=self.zenflow,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
@@ -276,6 +282,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
for i in range(1, len(self.optimizer.param_groups)):
self.backup_optimizer.add_param_group(self.optimizer.param_groups[i])
self._initialize_zenflow_stage3_epilogue(zenflow_config, overlap_comm)
self.module = module
self.elastic_checkpoint = elastic_checkpoint
@@ -476,11 +484,32 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
print_rank_0("Removed grad acc hooks", force=False)
self.ipg_buckets.clear()
def create_zenflow_hooks(self):
from functools import partial
hook_names = [
"_initialize_zenflow_stage3_prologue",
"_initialize_zenflow_stage3_epilogue",
"zenflow_cpu_optimizer_step",
"_sync_selective_optimizer_lr",
"selective_optimizer_step",
"is_zenflow_select_boundary",
"update_selected_channels",
"_process_selected_fp32_groups_grad",
"zenflow_backward_prologue",
"zenflow_backward_epilogue",
"log_selective_optimizer_timers",
]
for name in hook_names:
fn = getattr(zf_engine_stage3, name)
setattr(self, name, partial(fn, self))
def initialize_ds_offload(
self,
module,
timers,
ds_config,
zenflow,
overlap_comm,
prefetch_bucket_size,
max_reuse_distance,
@@ -499,6 +528,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
return DeepSpeedZeRoOffload(module=module,
timers=timers,
ds_config=ds_config,
zenflow=zenflow,
overlap_comm=overlap_comm,
prefetch_bucket_size=prefetch_bucket_size,
max_reuse_distance=max_reuse_distance,
@@ -738,6 +768,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.fp16_groups.append(sub_group)
self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group])
if self.zenflow:
for param in sub_group:
param.group_id = param_group_idx
# record sub group -> group mapping
self.sub_group_to_group_id[sub_group_idx] = param_group_idx
@@ -997,7 +1031,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.torch_autocast_gradscaler.step(optimizer)
self.torch_autocast_gradscaler.update()
else:
optimizer.step()
if not self.zenflow:
optimizer.step()
else:
self.zenflow_cpu_optimizer_step()
if self.offload_optimizer:
cur_device = self.subgroup_to_device[sub_group_id]
@@ -1267,8 +1304,14 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
# move the gradient to a contiguous buffer
with get_accelerator().stream(self.reduce_and_partition_stream):
# move the parameter's gradient to the contiguous flat buffer
new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
new_grad_tensor.copy_(param.grad, non_blocking=True)
if self.zenflow and len(param.ds_shape) != 1:
transposed_shape = param.grad.t().shape
new_grad_tensor = bucket.buffer.narrow(0, bucket.elements,
param.grad.numel()).view(transposed_shape)
new_grad_tensor.copy_(param.grad.t().contiguous(), non_blocking=True)
else:
new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
new_grad_tensor.copy_(param.grad, non_blocking=True)
if not get_accelerator().is_synchronized_device():
param.grad.record_stream(get_accelerator().current_stream())
param.grad.data = new_grad_tensor
@@ -1308,6 +1351,12 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
params_in_bucket.sort(key=lambda p: p.ds_id)
grad_partitions = self.__avg_scatter_grads(params_in_bucket, communication_data_type)
if self.is_zenflow_select_boundary():
self.update_selected_channels(params_in_bucket, grad_partitions)
if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
self._process_selected_fp32_groups_grad(params_in_bucket, grad_partitions)
self.partition_grads(params_in_bucket, grad_partitions)
params_in_bucket.clear()
@@ -2272,6 +2321,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if self.swap_optimizer:
self.optimizer_swapper.pre_backward()
if self.zenflow:
self.zenflow_backward_prologue()
see_memory_usage("Before backward", force=False)
if self.custom_loss_scaler:
@ -2282,6 +2334,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
else:
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
if self.zenflow:
self.zenflow_backward_epilogue()
if self.swap_optimizer:
self.optimizer_swapper.post_backward()
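The prologue/epilogue calls bracket the whole backward pass. A minimal sketch of that control flow with stand-in hook bodies (the real hooks perform ZenFlow bookkeeping and trigger the selective update):

import torch

class TinyEngine:

    def __init__(self, zenflow: bool):
        self.zenflow = zenflow
        self.events = []

    def zenflow_backward_prologue(self):
        self.events.append("prologue")  # stand-in for pre-backward bookkeeping

    def zenflow_backward_epilogue(self):
        self.events.append("epilogue")  # stand-in for post-backward work

    def backward(self, loss, retain_graph=False):
        if self.zenflow:
            self.zenflow_backward_prologue()
        loss.backward(retain_graph=retain_graph)
        if self.zenflow:
            self.zenflow_backward_epilogue()

engine = TinyEngine(zenflow=True)
loss = (torch.randn(3, requires_grad=True) ** 2).sum()
engine.backward(loss)
assert engine.events == ["prologue", "epilogue"]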

View File

@ -3,40 +3,64 @@
# DeepSpeed Team
import pytest
import torch
import numpy as np
from torch.nn import Parameter
from deepspeed.ops.adam import ZenFlowSelectiveAdamW
from deepspeed.ops.adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3
def make_param(shape, selected_indices=None):
def make_param(Opt, shape, selected_indices=None):
param = Parameter(torch.randn(*shape))
if Opt is ZenFlowSelectiveAdamW_stage3:
    if param.dim() == 2:
        # emulate a Stage 3 partition: the weight is stored transposed and
        # flattened, so ds_shape is (in_features, out_features)
        param.ds_shape = (param.shape[1], param.shape[0])
        param.ds_tensor = param.clone().T.contiguous().view(-1)
    else:
        param.ds_shape = tuple(param.shape)
        param.ds_tensor = param.clone()
    param.complete_column_offset = 0
    param.complete_numel = param.numel()
    param.group_id = 0
if selected_indices is not None:
param.selected_indices = selected_indices
param.selected_grad = torch.randn(param.shape[0], len(selected_indices))
param.temp_selected_param = param.data[:, selected_indices].clone()
if param.dim() == 2:
    if Opt is ZenFlowSelectiveAdamW_stage3:
        # Stage 3 selects rows of the transposed partition layout
        param.selected_grad = torch.randn(len(selected_indices), param.ds_shape[1])
        param.temp_selected_param = param.ds_tensor.view(param.ds_shape)[selected_indices, :].clone()
    else:
        param.selected_grad = torch.randn(param.shape[0], len(selected_indices))
        param.temp_selected_param = param.data[:, selected_indices].clone()
else:
param.selected_grad = torch.randn_like(param.data)
param.temp_selected_param = param.data.clone()
return param
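A quick sanity check of what the helper produces in the Stage 3 case (hypothetical usage, not part of the test suite):

p = make_param(ZenFlowSelectiveAdamW_stage3, (4, 6), torch.tensor([1, 3]))
assert p.ds_shape == (6, 4)             # transposed shape
assert p.ds_tensor.numel() == 24        # flattened partition
assert p.selected_grad.shape == (2, 4)  # rows = selected channels
assert p.temp_selected_param.shape == (2, 4)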
def test_init_methods():
opt1 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_init_methods(Opt):
opt1 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False)
assert opt1.step == opt1._step_without_offload
assert opt1.group_step == opt1._group_step_without_offload
opt2 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True)
opt2 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True)
assert opt2.step == opt2._step_with_offload
assert opt2.group_step == opt2._group_step_with_offload
def test_step_without_offload():
param = make_param((4, 6), torch.tensor([1, 3, 4]))
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_step_without_offload(Opt):
param = make_param(Opt, (4, 6), torch.tensor([1, 3, 4]))
param.requires_grad_(True)
opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
old_selected = param.data[:, param.selected_indices].clone()
opt = Opt([param], lr=1e-3, offload=False)
if Opt is ZenFlowSelectiveAdamW_stage3:
    old_selected = param.ds_tensor.view(param.ds_shape)[param.selected_indices, :].clone()
else:
    old_selected = param.data[:, param.selected_indices].clone()
opt.step()
new_selected = param.data[:, param.selected_indices]
if Opt is ZenFlowSelectiveAdamW_stage3:
    new_selected = param.ds_tensor.view(param.ds_shape)[param.selected_indices, :]
else:
    new_selected = param.data[:, param.selected_indices]
diff_norm = (old_selected - new_selected).abs().sum().item()
assert diff_norm > 1e-5, "param was not updated"
@ -44,9 +68,10 @@ def test_step_without_offload():
assert param.selected_grad is None
def test_step_with_offload_bucket_flush():
param1 = make_param((2, 4), torch.tensor([1, 2]))
param2 = make_param((2, 4), torch.tensor([0, 3]))
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_step_with_offload_bucket_flush(Opt):
param1 = make_param(Opt, (2, 4), torch.tensor([1, 2]))
param2 = make_param(Opt, (2, 4), torch.tensor([0, 3]))
param1.exp_avg = torch.zeros_like(param1.temp_selected_param)
param1.exp_avg_sq = torch.zeros_like(param1.temp_selected_param)
@ -58,15 +83,16 @@ def test_step_with_offload_bucket_flush():
param2.exp_avg_cpu_data = param2.exp_avg.clone().cpu()
param2.exp_avg_sq_cpu_data = param2.exp_avg_sq.clone().cpu()
opt = ZenFlowSelectiveAdamW([param1, param2], lr=1e-3, offload=True, bucket_size=1)
opt = Opt([param1, param2], lr=1e-3, offload=True, bucket_size=1)
opt.step()
assert param1.temp_selected_param is None
assert param2.temp_selected_param is None
def test_clear_selected_mv():
param = make_param((2, 4), torch.tensor([0, 2]))
opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_clear_selected_mv(Opt):
param = make_param(Opt, (2, 4), torch.tensor([0, 2]))
opt = Opt([param], lr=1e-3, offload=False)
opt.step()
state = opt.state[param]
assert "exp_avg" in state
@ -74,17 +100,19 @@ def test_clear_selected_mv():
assert state["exp_avg"].abs().sum() == 0
def test_group_step_without_offload():
param = make_param((2, 6), torch.tensor([0, 1, 3]))
opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
group_to_paramlist = {0: [param]}
opt._group_step_without_offload(group_to_paramlist)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_group_step_without_offload(Opt):
param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3]))
opt = Opt([param], lr=1e-3, offload=False)
group_to_paramlist = {0: [param]} if Opt is not ZenFlowSelectiveAdamW_stage3 else [param]
opt.group_step(group_to_paramlist)
assert param.selected_grad is None
def test_group_step_with_offload():
param = make_param((2, 6), torch.tensor([0, 1, 3]))
opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=True)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_group_step_with_offload(Opt):
param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3]))
opt = Opt([param], lr=1e-3, offload=True)
state = opt.state.setdefault(param, {})
state["step"] = torch.zeros((), dtype=param.dtype, device=param.device)
@ -93,33 +121,30 @@ def test_group_step_with_offload():
param.exp_avg_cpu_data = param.exp_avg.clone().cpu()
param.exp_avg_sq_cpu_data = param.exp_avg_sq.clone().cpu()
group_to_paramlist = {0: [param]}
opt._group_step_with_offload(group_to_paramlist)
group_to_paramlist = {0: [param]} if Opt is not ZenFlowSelectiveAdamW_stage3 else [param]
opt.group_step(group_to_paramlist)
assert param.selected_grad is None
def test_1d_param_support():
param = Parameter(torch.randn(10))
param.selected_grad = torch.randn(10)
param.temp_selected_param = param.data.clone()
opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_1d_param_support(Opt):
param = make_param(Opt, (10, ), torch.arange(10))
opt = Opt([param], lr=1e-3, offload=False)
opt.step()
assert param.temp_selected_param is None
assert param.selected_grad is None
def test_state_increment():
param = torch.nn.Parameter(torch.randn(2, 4))
param.selected_indices = torch.arange(4)
param.selected_grad = torch.randn(2, 4)
param.temp_selected_param = param.data.clone()
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_state_increment(Opt):
param = make_param(Opt, (2, 4), torch.arange(4))
opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
opt = Opt([param], lr=1e-3, offload=False)
opt.step()
step1 = opt.state[param]['step'].item()
param.selected_grad = torch.randn(2, 4)
param.temp_selected_param = param.data.clone()
param.selected_grad = torch.randn(2, 4) if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2)
param.temp_selected_param = param.data.clone() if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2)
param.selected_indices = torch.arange(4)
opt.step()
@ -134,22 +159,29 @@ def _compare_with_torch_adamw(param, zenflow_opt, atol=1e-4):
for _ in range(10):
grad = torch.randn_like(param)
param.selected_indices = torch.arange(param.shape[1])
param.selected_grad = grad
param.temp_selected_param = param.data.clone()
is_stage3 = isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3)
param.selected_grad = grad.T if is_stage3 else grad
param.temp_selected_param = param.ds_tensor.view(param.ds_shape).clone() if is_stage3 else param.data.clone()
torch_param.grad = grad.clone()
zenflow_opt.step()
torch_opt.step()
np.testing.assert_allclose(torch_param.data.cpu().numpy(),
param.data.cpu().numpy(),
atol=atol,
err_msg="Mismatch with torch.AdamW")
if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3):
    np.testing.assert_allclose(torch_param.data.cpu().numpy(),
                               param.data.cpu().numpy(),
                               atol=atol,
                               err_msg="Mismatch with torch.AdamW")
else:
    # Stage 3 updates the transposed flat partition; transpose back before
    # comparing against the reference torch.AdamW parameter
    np.testing.assert_allclose(torch_param.data.cpu().numpy(),
                               param.ds_tensor.view(param.ds_shape).T.cpu().numpy(),
                               atol=atol,
                               err_msg="Mismatch with torch.AdamW")
def test_against_torch_adamw():
param = torch.nn.Parameter(torch.randn(2, 4))
param.selected_indices = torch.arange(4)
opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_against_torch_adamw(Opt):
param = make_param(Opt, (2, 4), torch.arange(4))
opt = Opt([param], lr=1e-3, offload=False)
_compare_with_torch_adamw(param, opt)

View File

@ -74,7 +74,7 @@ class BaseZenFlowTest:
model.destroy()
@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("stage", [1, 2, 3])
@pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
@pytest.mark.parametrize("offload_selective_optimizer", [True, False])
@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [
@ -93,7 +93,7 @@ class TestZenFlowSingleGPU(DistributedTest, BaseZenFlowTest):
tester.run_training_distributed(config_dict)
@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("stage", [1, 2, 3])
@pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
@pytest.mark.parametrize("offload_selective_optimizer", [True, False])
@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [