Add ZenFlow code for Stage 3 (#7516)

This PR completes the ZenFlow integration for DeepSpeed ZeRO Stage 3. Highlights:

- ZenFlowSelectiveAdamW_stage3: optimizer with importance-aware selective parameter updates for ZeRO Stage 3.
- ZenFlowZeroOptimizer_Stage3: full Stage 3 optimizer integration with partitioned parameters and CPU offload.
- Configurable via ZenFlowConfig, fully integrated with DeepSpeedZeroConfig for Stage 3.
- Unit tests for Stage 3 cases ensuring correctness and compatibility.

Note: Integration with ZeRO Stage 1 & 2 was introduced in #7391.

Signed-off-by: Yusen Wu <xrn4ub@virginia.edu>
Co-authored-by: Ma, Guokai <guokai.ma@intel.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Tingfeng Lan <erc8gx@virginia.edu>
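For orientation, a minimal sketch of what enabling ZenFlow under ZeRO Stage 3 might look like in a DeepSpeed config. The field names are the `ZenFlowConfig` attributes that appear in this diff; the enclosing key layout and the sample values are assumptions, not something this PR pins down:

```python
# Hypothetical sketch only: field names mirror ZenFlowConfig attributes
# referenced in this diff; the "zenflow" key placement is assumed.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "zenflow": {
            "topk_ratio": 0.1,            # fraction of gradient channels treated as important
            "select_strategy": "auto",    # "auto" resolves to "epoch"
            "select_interval": "auto",
            "update_interval": 4,         # or "auto" to enable adaptive full updates
            "full_warm_up_rounds": 0,     # dense update steps before selection kicks in
            "offload": True,              # offload the selective optimizer state to CPU
            "overlap_step": True,         # run the CPU optimizer step in a side process
            "pt_reserved_cores_perc": 0.25,
        },
    },
}
```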
@@ -6,4 +6,4 @@
 from .cpu_adam import DeepSpeedCPUAdam
 from .fused_adam import FusedAdam
 from .zenflow_cpu_adam import ZenFlowCPUAdam
-from .zenflow_torch_adam import ZenFlowSelectiveAdamW
+from .zenflow_torch_adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3
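The new export can also be constructed directly; a minimal sketch, assuming a plain parameter list (in practice the Stage 3 engine builds it from its own param groups, as the epilogue code further below shows, and actually stepping it requires the ZenFlow attributes the engine attaches to each param):

```python
import torch
from deepspeed.ops.adam import ZenFlowSelectiveAdamW_stage3

# ZenFlowSelectiveAdamW_stage3 subclasses torch.optim.AdamW, so it accepts the
# usual AdamW arguments plus the offload/bucket_size keywords from this diff.
params = [torch.nn.Parameter(torch.randn(64, 64))]
opt = ZenFlowSelectiveAdamW_stage3(params, lr=1e-3, offload=True, bucket_size=5e8)
```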
@@ -53,30 +53,20 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
 
         if offload:
             self.step = self._step_with_offload
-            self.temp_copy_param = self._temp_copy_param_with_offload
             self.group_step = self._group_step_with_offload
             self.bucket_size = bucket_size
         else:
             self.step = self._step_without_offload
-            self.temp_copy_param = self._temp_copy_param_without_offload
             self.group_step = self._group_step_without_offload
 
     @torch.no_grad()
-    def _temp_copy_param_with_offload(self, group_to_paramlist):
+    def temp_copy_param(self, group_to_paramlist):
         for group_id, params in group_to_paramlist.items():
             for param in params:
                 if hasattr(param, "selected_grad"):
                     temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len(
                         param.shape) != 1 else param.data.clone().detach()
-                    param.temp_selected_param = temp_selected_param.cpu()
-
-    @torch.no_grad()
-    def _temp_copy_param_without_offload(self, group_to_paramlist):
-        for group_id, params in group_to_paramlist.items():
-            for param in params:
-                if hasattr(param, "selected_grad"):
-                    param.temp_selected_param = param.data[:, param.selected_indices].clone().detach() if len(
-                        param.shape) != 1 else param.data.clone().detach()
+                    if self.offload:
+                        param.temp_selected_param = temp_selected_param.cpu()
+                    else:
+                        param.temp_selected_param = temp_selected_param
 
     def copy_mv_from_cpu(self, params):
         for param in params:
@@ -167,6 +157,13 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
 
     @torch.no_grad()
     def _step_with_offload(self):
+        """
+        Performs parameter updates in offload mode.
+
+        In this mode, group_step() calls adamw() on each pre-partitioned param bucket,
+        so memory can be released after each bucket update to reduce GPU overhead.
+        Without offload, adamw() is called directly for speed.
+        """
         for group_id, group in enumerate(self.param_groups):
             params = group["params"]
 
@@ -197,10 +194,127 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
         flush_bucket()
 
     @torch.no_grad()
-    def _group_step_without_offload(self, group_to_paramlist):
+    def group_step(self, group_to_paramlist):
         for group_id, params in group_to_paramlist.items():
             group = self.param_groups[group_id]
 
+            if self.offload:
+                self.copy_mv_from_cpu(params)
+
             params_with_grad: List[Tensor] = []
             grads: List[Tensor] = []
             exp_avgs: List[Tensor] = []
+            exp_avg_sqs: List[Tensor] = []
+            max_exp_avg_sqs: List[Tensor] = []
+            state_steps: List[Tensor] = []
+
+            amsgrad: bool = group["amsgrad"]
+            beta1, beta2 = cast(Tuple[float, float], group["betas"])
+
+            for param in params:
+                if hasattr(param, "selected_grad"):
+                    is_2d = (len(param.shape) != 1)
+                    selected_param = param.data[:, param.selected_indices] if is_2d else param.data
+
+                    state = self.state.setdefault(param, {})
+                    if len(state) == 0:
+                        state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device)
+                        if amsgrad:
+                            state["max_exp_avg_sq"] = torch.zeros_like(selected_param)
+                        if not self.offload:
+                            state["exp_avg"] = torch.zeros_like(selected_param)
+                            state["exp_avg_sq"] = torch.zeros_like(selected_param)
+
+                    if self.offload:
+                        exp_avg_t = param.exp_avg.view_as(selected_param)
+                        exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param)
+                    else:
+                        exp_avg_t = state["exp_avg"]
+                        exp_avg_sq_t = state["exp_avg_sq"]
+
+                    params_with_grad.append(selected_param)
+                    grads.append(param.selected_grad)
+                    exp_avgs.append(exp_avg_t)
+                    exp_avg_sqs.append(exp_avg_sq_t)
+                    if amsgrad:
+                        max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+                    state_steps.append(state["step"])
+
+            adamw(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                exp_avg_sqs,
+                max_exp_avg_sqs,
+                state_steps,
+                amsgrad=amsgrad,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group["lr"],
+                weight_decay=group["weight_decay"],
+                eps=group["eps"],
+                maximize=False,
+            )
+
+            for i, param in enumerate(params):
+                if hasattr(param, "selected_grad") and len(param.shape) != 1:
+                    param.data[:, param.selected_indices] = params_with_grad[i]
+
+            if self.offload:
+                self.copy_mv_to_cpu(params)
+
+            for param in params:
+                param.selected_grad = None
+
+
+class ZenFlowSelectiveAdamW_stage3(torch.optim.AdamW):
+
+    def __init__(self, *args, offload=False, bucket_size=5e8, **kwargs):
+        super(ZenFlowSelectiveAdamW_stage3, self).__init__(*args, **kwargs)
+        self.offload = offload
+
+        if offload:
+            self.step = self._step_with_offload
+            self.bucket_size = bucket_size
+        else:
+            self.step = self._step_without_offload
+
+    @torch.no_grad()
+    def temp_copy_param(self, paramlist):
+        for param in paramlist:
+            if hasattr(param, "selected_grad"):
+                num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
+
+                if num_row != 1:
+                    param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset, param.complete_numel).view(
+                        param.complete_numel // num_row, num_row)
+                    temp_selected_param = param_2d[param.selected_indices, :].clone().detach()
+                else:
+                    temp_selected_param = param.ds_tensor.data.clone().detach()
+
+                if self.offload:
+                    param.temp_selected_param = temp_selected_param.cpu()
+                else:
+                    param.temp_selected_param = temp_selected_param
+
+    def clear_selected_mv(self):
+        print("Zenflow: clearing selective optimizer states...")
+        for group in self.param_groups:
+            for param in group['params']:
+                state = self.state.setdefault(param, {})
+                if len(state) == 0:
+                    continue
+                if self.offload:
+                    param.exp_avg_cpu_data.zero_()
+                    param.exp_avg_sq_cpu_data.zero_()
+                else:
+                    state["exp_avg"].zero_()
+                    state["exp_avg_sq"].zero_()
+
+    @torch.no_grad()
+    def _step_without_offload(self):
+        for group in self.param_groups:
+
+            params_with_grad: List[Tensor] = []
+            grads: List[Tensor] = []
+            exp_avgs: List[Tensor] = []
@@ -209,10 +323,18 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
             state_steps: List[Tensor] = []
             amsgrad: bool = group["amsgrad"]
             beta1, beta2 = cast(Tuple[float, float], group["betas"])
 
-            for param in params:
+            for param in group["params"]:
                 if hasattr(param, "selected_grad"):
-                    selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data
+                    num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
+                    if num_row != 1:
+                        param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
+                                                               param.complete_numel).view(
+                                                                   param.complete_numel // num_row, num_row)
+                        selected_param = param_2d[param.selected_indices, :]
+                    else:
+                        selected_param = param.ds_tensor.data
+                    if hasattr(param, 'temp_selected_param') and param.temp_selected_param is not None:
+                        selected_param.copy_(param.temp_selected_param)
 
                     state = self.state.setdefault(param, {})
                     if len(state) == 0:
@@ -229,7 +351,6 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
                         if amsgrad:
                             max_exp_avg_sqs.append(state["max_exp_avg_sq"])
                         state_steps.append(state["step"])
 
             adamw(
                 params_with_grad,
                 grads,
@@ -245,44 +366,91 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
                 eps=group["eps"],
                 maximize=False,
             )
 
-            for i, param in enumerate(params):
+            for i, param in enumerate(group["params"]):
                 if hasattr(param, "selected_grad"):
-                    if len(param.shape) != 1:
-                        param.data[:, param.selected_indices] = params_with_grad[i]
+                    num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
+                    if num_row != 1:
+                        param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
+                                                               param.complete_numel).view(
+                                                                   param.complete_numel // num_row, num_row)
+                        param_2d[param.selected_indices, :] = params_with_grad[i]
 
-            for param in params:
-                param.selected_grad = None
+            for param in group["params"]:
+                if hasattr(param, "temp_selected_param"):
+                    param.temp_selected_param = None
+                param.selected_grad = None
 
     def copy_mv_from_cpu(self, params):
         for param in params:
             param.exp_avg = param.exp_avg_cpu_data.to(param.device, non_blocking=True)
             param.exp_avg_sq = param.exp_avg_sq_cpu_data.to(param.device, non_blocking=True)
 
     def copy_mv_to_cpu(self, params):
         for param in params:
             param.exp_avg_cpu_data.copy_(param.exp_avg.data, non_blocking=True)
             param.exp_avg_sq_cpu_data.copy_(param.exp_avg_sq.data, non_blocking=True)
             param.exp_avg = None
             param.exp_avg_sq = None
 
     @torch.no_grad()
-    def _group_step_with_offload(self, group_to_paramlist):
-        for group_id, params in group_to_paramlist.items():
+    def group_step(self, paramlist):
+
+        group_to_paramlist = {}
+        for param in paramlist:
+            group_id = param.group_id
+            if group_id not in group_to_paramlist:
+                group_to_paramlist[group_id] = []
+            group_to_paramlist[group_id].append(param)
+
+        for group_id in sorted(group_to_paramlist.keys()):
+            params = group_to_paramlist[group_id]
             group = self.param_groups[group_id]
 
-            self.copy_mv_from_cpu(params)
+            if self.offload:
+                self.copy_mv_from_cpu(params)
 
             params_with_grad: List[Tensor] = []
             grads: List[Tensor] = []
             exp_avgs: List[Tensor] = []
             exp_avg_sqs: List[Tensor] = []
             max_exp_avg_sqs: List[Tensor] = []
             state_steps: List[Tensor] = []
 
             amsgrad: bool = group["amsgrad"]
             beta1, beta2 = cast(Tuple[float, float], group["betas"])
 
             for param in params:
                 if hasattr(param, "selected_grad"):
-                    selected_param = param.data[:, param.selected_indices] if len(param.shape) != 1 else param.data
+                    num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
+
+                    if num_row != 1:
+                        param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
+                                                               param.complete_numel).view(
+                                                                   param.complete_numel // num_row, num_row)
+                        selected_param = param_2d[param.selected_indices, :]
+                    else:
+                        selected_param = param.ds_tensor.data
 
                     state = self.state.setdefault(param, {})
                     if len(state) == 0:
                         state["step"] = torch.zeros((), dtype=param.dtype, device=selected_param.device)
                         if amsgrad:
                             state["max_exp_avg_sq"] = torch.zeros_like(selected_param)
+                        if not self.offload:
+                            state["exp_avg"] = torch.zeros_like(selected_param)
+                            state["exp_avg_sq"] = torch.zeros_like(selected_param)
 
+                    if self.offload:
+                        exp_avg_t = param.exp_avg.view_as(selected_param)
+                        exp_avg_sq_t = param.exp_avg_sq.view_as(selected_param)
+                    else:
+                        exp_avg_t = state["exp_avg"]
+                        exp_avg_sq_t = state["exp_avg_sq"]
 
                     params_with_grad.append(selected_param)
                     grads.append(param.selected_grad)
-                    exp_avgs.append(param.exp_avg.view_as(selected_param))
-                    exp_avg_sqs.append(param.exp_avg_sq.view_as(selected_param))
+                    exp_avgs.append(exp_avg_t)
+                    exp_avg_sqs.append(exp_avg_sq_t)
                     if amsgrad:
                         max_exp_avg_sqs.append(state["max_exp_avg_sq"])
                     state_steps.append(state["step"])
@@ -305,14 +473,64 @@ class ZenFlowSelectiveAdamW(torch.optim.AdamW):
 
             for i, param in enumerate(params):
                 if hasattr(param, "selected_grad"):
-                    if len(param.shape) != 1:
-                        param.data[:, param.selected_indices] = params_with_grad[i]
+                    num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
+                    if num_row != 1:
+                        param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
+                                                               param.complete_numel).view(
+                                                                   param.complete_numel // num_row, num_row)
+                        param_2d[param.selected_indices, :] = params_with_grad[i]
 
-            self.copy_mv_to_cpu(params)
+            if self.offload:
+                self.copy_mv_to_cpu(params)
 
             for param in params:
                 param.selected_grad = None
 
     @torch.no_grad()
     def _step_with_offload(self):
+        """
+        Performs parameter updates in offload mode.
+
+        In this mode, group_step() calls adamw() on each pre-partitioned param bucket,
+        so memory can be released after each bucket update to reduce GPU overhead.
+        Without offload, adamw() is called directly for speed.
+        """
+
+        for group_id, group in enumerate(self.param_groups):
+            params = group["params"]
+
+            bucket = []
+            bucket_numel = 0
+
+            def flush_bucket():
+                if not bucket:
+                    return
+                for param in bucket:
+                    if hasattr(param, "temp_selected_param") and param.temp_selected_param is not None:
+                        temp_selected_param = param.temp_selected_param.to(param.device, non_blocking=True)
+                        num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
+                        if num_row != 1:
+                            param_2d = param.ds_tensor.data.narrow(0, param.complete_column_offset,
+                                                                   param.complete_numel).view(
+                                                                       param.complete_numel // num_row, num_row)
+                            param_2d[param.selected_indices, :] = temp_selected_param
+                        else:
+                            param.ds_tensor.data.copy_(temp_selected_param)
+                        param.temp_selected_param = None
+
+                self.group_step(bucket)
+                bucket.clear()
+
+            for param in params:
+                if hasattr(param, "selected_grad"):
+                    bucket.append(param)
+                    bucket_numel += param.numel()
+                    if bucket_numel >= self.bucket_size:
+                        flush_bucket()
+                        bucket_numel = 0
+
+            flush_bucket()
 
 
 def _single_tensor_adamw(
     params: List[Tensor],
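The `_step_with_offload` docstring above describes the bucketed update scheme; the same pattern in isolation, with illustrative names rather than the DeepSpeed API:

```python
# Standalone sketch of the bucketed-update pattern: collect params until the
# bucket reaches bucket_size elements, update, then release the temporaries.
def step_in_buckets(params, bucket_size, apply_update):
    bucket, numel = [], 0
    for p in params:
        bucket.append(p)
        numel += p.numel()
        if numel >= bucket_size:
            apply_update(bucket)   # per-bucket adamw() call in the real code
            bucket.clear()
            numel = 0
    if bucket:
        apply_update(bucket)       # flush the remainder
```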
@@ -1868,6 +1868,7 @@ class DeepSpeedEngine(Module):
                 overlap_comm=self.zero_overlap_comm(),
                 offload_optimizer_config=self.zero_offload_optimizer(),
                 offload_param_config=self.zero_offload_param(),
+                zenflow_config=self.zenflow_config(),
                 sub_group_size=self.zero_sub_group_size(),
                 offload_ratio=self.zero_partial_offload(),
                 mpu=self.mpu,
deepspeed/runtime/zenflow/engine_stage3.py (new file, 641 lines)
@@ -0,0 +1,641 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime.zero.partition_parameters import *

import torch
import math
from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import ZenFlowSelectiveAdamW_stage3
from deepspeed.runtime.utils import see_memory_usage
from typing import List
from deepspeed.accelerator import get_accelerator
from typing import TYPE_CHECKING
from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process

if TYPE_CHECKING:
    from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

OPTIMIZER_SWAP_IN_STATE_TIMER = 'optimizer_swap_in_state'
INIT_OPTIMIZER_TIMER = 'init_optimizer_state'
OPTIMIZER_SWAP_OUT_STATE_TIMER = 'optimizer_swap_out_state'
OPTIMIZER_STEP_TIMER = 'optimizer_step'


def configure_zenflow(optimizer_z3, zenflow_config):

    optimizer_z3.select_strategy = zenflow_config.select_strategy
    if optimizer_z3.select_strategy == 'auto':
        optimizer_z3.select_strategy = "epoch"
        if isinstance(zenflow_config.select_interval, int):
            raise Warning(
                "If use auto select strategy, select_interval will be set to 1 and select_strategy will be set to epoch, thus select_interval would be overwritten."
            )
        optimizer_z3.select_interval = 1
    else:
        if isinstance(zenflow_config.select_interval, str):
            raise ValueError("If don't use auto select strategy, select_interval must be a number.")
        optimizer_z3.select_interval = int(zenflow_config.select_interval)

    if isinstance(zenflow_config.update_interval, str):
        optimizer_z3.auto_update = True
        optimizer_z3.update_interval = 0
    else:
        optimizer_z3.auto_update = False
        optimizer_z3.update_interval = int(zenflow_config.update_interval)

    if optimizer_z3.select_strategy == 'epoch':
        if zenflow_config.steps_per_epoch is not None:
            optimizer_z3.select_interval = optimizer_z3.select_interval * zenflow_config.steps_per_epoch
        else:
            optimizer_z3.select_interval = 0

    if not optimizer_z3.auto_update and optimizer_z3.select_interval != 0 and optimizer_z3.select_interval < optimizer_z3.update_interval:
        raise ValueError("Select interval must be greater or equal to update interval")

    optimizer_z3.topk_ratio = zenflow_config.topk_ratio

    optimizer_z3.param_id_grad_sum_buffer_offset = {}

    optimizer_z3.zf_stage3 = True

    if optimizer_z3.auto_update:
        optimizer_z3.param_id_sum_buffer_offset = {}
        optimizer_z3.auto_ratio = zenflow_config.auto_ratio
        optimizer_z3.zenflow_need_update = [False, False]
        optimizer_z3.zenflow_state = 0
        optimizer_z3.num_need_update = 0
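A worked example of how `configure_zenflow` resolves the intervals (values assumed for illustration):

```python
# select_strategy="epoch", select_interval=2, steps_per_epoch=100:
#   select_interval becomes 2 * 100 = 200 micro-steps between re-selections.
# select_strategy="auto": strategy is forced to "epoch" with select_interval=1.
# update_interval="auto": auto_update=True and update_interval starts at 0,
#   growing by one per micro-step until the adaptive trigger fires.
select_interval = 2 * 100  # 200
```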
def _initialize_zenflow_stage3_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3",
                                        module,
                                        zenflow_config: dict = None):

    optimizer_z3.zenflow = True if zenflow_config is not None else False

    if not optimizer_z3.zenflow:
        return

    optimizer_z3.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc

    for p in module.parameters():
        p.data = p.data.t().contiguous() if len(p.shape) != 1 else p.data


def _initialize_zenflow_stage3_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3",
                                        zenflow_config: dict = None,
                                        overlap_comm: bool = False):

    if not optimizer_z3.zenflow:
        return

    optimizer_z3.micro_step = -1
    optimizer_z3.full_warm_up_rounds = zenflow_config.full_warm_up_rounds
    optimizer_z3.offload_selective_optimizer = zenflow_config.offload
    optimizer_z3.zenflow_overlap_step = zenflow_config.overlap_step

    if optimizer_z3.offload_selective_optimizer:
        assert overlap_comm, "offload selective optimizer should be used with overlap_comm"

    if optimizer_z3.zenflow_overlap_step:
        optimizer_z3.process_optimizer_established = False
        optimizer_z3.first_update_round_after_warmup = True
        optimizer_z3.initialize_optimizer_states = lambda: initialize_optimizer_states(optimizer_z3)
        optimizer_z3.step = lambda closure=None: step(optimizer_z3, closure)
        optimizer_z3.zenflow_cpu_optimizer_overlap_step = lambda now_state, scaled_global_grad_norm: zenflow_cpu_optimizer_overlap_step(
            optimizer_z3, now_state, scaled_global_grad_norm)
        optimizer_z3.wait_last_update_and_copy = lambda timer_names: wait_last_update_and_copy(
            optimizer_z3, timer_names)
        optimizer_z3.partition_grads = lambda params_to_release, grad_partitions: partition_grads(
            optimizer_z3, params_to_release, grad_partitions)
        optimizer_z3.get_overlap_step_state = lambda: get_overlap_step_state(optimizer_z3)
        optimizer_z3.start_optimizer_process = lambda: start_optimizer_process(optimizer_z3)
        optimizer_z3.unscale_and_clip_grads = lambda sub_group_id, total_norm, now_state: unscale_and_clip_grads(
            optimizer_z3, sub_group_id, total_norm, now_state)

    configure_zenflow(optimizer_z3, zenflow_config)
    optimizer_z3.selective_optimizer = ZenFlowSelectiveAdamW_stage3([{
        k: v
        for k, v in group.items() if k != "params"
    } | {
        "params": group["params"]
    } for group in optimizer_z3.optimizer.param_groups],
                                                                    offload=optimizer_z3.offload_selective_optimizer)
    optimizer_z3.num_total_param = sum(
        sum(1 for param in group["params"] if len(param.ds_shape) != 1)
        for group in optimizer_z3.optimizer.param_groups)


def zenflow_cpu_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
    return optimizer_z3.optimizer.step(step_id=optimizer_z3.micro_step + 1)


def _sync_selective_optimizer_lr(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
    for group_selected, group in zip(optimizer_z3.selective_optimizer.param_groups,
                                     optimizer_z3.optimizer.param_groups):
        group_selected["lr"] = group["lr"]


def selective_optimizer_step(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
    optimizer_z3.selective_optimizer.step()


def is_zenflow_select_boundary(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> bool:
    return optimizer_z3.zenflow and (optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) >= 0 and (
        (optimizer_z3.micro_step - optimizer_z3.full_warm_up_rounds) == 0 or
        (optimizer_z3.select_interval != 0 and optimizer_z3.micro_step % optimizer_z3.select_interval == 0))


def update_selected_channels(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3", params_to_update, grad_partitions):
    src_rk = dist.get_rank(optimizer_z3.dp_process_group)
    total_rk = dist.get_world_size(optimizer_z3.dp_process_group)

    total_chunk_size = 0
    param_local_offset = [0 for _ in range(total_rk)]

    for param, grad_partition in zip(params_to_update, grad_partitions):
        param_max_chunk_size = 0
        param_rk_offset = 0
        for rk in range(total_rk):
            contains_real_data = param.partition_numel() * rk < param.ds_numel
            if not contains_real_data:
                param.grad = None
                continue

            num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)
            if num_row == 1:
                continue

            partition_size = param.partition_numel()
            start = partition_size * rk
            end = min(start + partition_size, param.ds_numel)

            start_idx = math.ceil(start / num_row)
            end_idx = end // num_row
            num_cols = end_idx - start_idx

            if param.ds_id not in optimizer_z3.param_id_grad_sum_buffer_offset:
                optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = []

            optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id].append(
                (param_local_offset[rk], num_cols, param_rk_offset))

            param_max_chunk_size = max(param_max_chunk_size, num_cols)
            param_rk_offset += num_cols
            param_local_offset[rk] += num_cols

        total_chunk_size += param_max_chunk_size

    optimizer_z3.grad_sum_buffer = torch.zeros(total_chunk_size, dtype=optimizer_z3.dtype, device='cuda')

    for param, grad_partition in zip(params_to_update, grad_partitions):
        contains_real_data = param.partition_numel() * src_rk < param.ds_numel
        if not contains_real_data:
            # this grad partition is empty - don't need to do anything
            param.grad = None
            continue

        # ds_shape is the transposed shape, it should not be same as param.shape
        num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)

        if num_row == 1:
            continue

        partition_size = param.partition_numel()
        start = partition_size * src_rk
        end = min(start + partition_size, param.ds_numel)

        start_idx = math.ceil(start / num_row)
        end_idx = end // num_row

        num_elements = (end_idx - start_idx) * num_row

        param.complete_column_offset = start_idx * num_row - start
        param.complete_numel = (end_idx - start_idx) * num_row

        sum_per_column = grad_partition.narrow(0, param.complete_column_offset, num_elements)
        sum_per_column = sum_per_column.view(end_idx - start_idx, num_row)
        sum_array = sum_per_column.abs().sum(dim=1)

        offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk]
        optimizer_z3.grad_sum_buffer.narrow(0, offset, length).copy_(sum_array)

    gathered_chunks = [torch.zeros_like(optimizer_z3.grad_sum_buffer) for _ in range(total_rk)]
    dist.all_gather(gathered_chunks, optimizer_z3.grad_sum_buffer, group=optimizer_z3.dp_process_group)

    for param in params_to_update:

        num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)

        if num_row == 1:
            continue

        param_column_sum = []
        for rk in range(total_rk):
            offset, length, _ = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][rk]
            param_column_sum.append(gathered_chunks[rk].narrow(0, offset, length))
        global_param_column_sum = torch.cat(param_column_sum, dim=0)

        num_select = max(1, int(global_param_column_sum.numel() * optimizer_z3.topk_ratio))
        _, global_topk_indices = torch.topk(global_param_column_sum, num_select, largest=True)

        _, length, rk_offset = optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id][src_rk]
        local_indices = [(idx.item() - rk_offset) for idx in global_topk_indices
                         if rk_offset <= idx < rk_offset + length]
        param.selected_indices = torch.tensor(local_indices, device='cuda')
        optimizer_z3.param_id_grad_sum_buffer_offset[param.ds_id] = []

    optimizer_z3.grad_sum_buffer = None
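`update_selected_channels` ranks channels by the all-gathered per-column absolute gradient sums and keeps the global top `topk_ratio` fraction. A single-rank sketch of the same selection rule, without the partitioning and all-gather bookkeeping:

```python
import torch

def select_channels(grad_2d: torch.Tensor, topk_ratio: float) -> torch.Tensor:
    # one importance score per channel: sum of |grad| along the row dimension
    scores = grad_2d.abs().sum(dim=1)
    k = max(1, int(scores.numel() * topk_ratio))
    _, indices = torch.topk(scores, k, largest=True)
    return indices  # plays the role of param.selected_indices above
```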
def _process_selected_fp32_groups_grad(optimizer_z3, params_to_update, grad_partitions):

    if optimizer_z3.auto_update:
        optimizer_z3.sum_buffer = torch.zeros(optimizer_z3.num_total_param, dtype=optimizer_z3.dtype, device='cuda')
        optimizer_z3.critic_sum_buffer = torch.zeros(optimizer_z3.num_total_param,
                                                     dtype=optimizer_z3.dtype,
                                                     device='cuda')
        curr_buffer_idx = 0

    for param, grad_partition in zip(params_to_update, grad_partitions):

        rk = dist.get_rank(optimizer_z3.dp_process_group)

        contains_real_data = param.partition_numel() * rk < param.ds_numel
        if not contains_real_data:
            # this grad partition is empty - don't need to do anything
            param.grad = None
            continue

        # ds_shape is the transposed shape, it should not be same as param.shape
        num_column, num_row = param.ds_shape if len(param.ds_shape) != 1 else (param.ds_shape[0], 1)

        if num_row == 1:
            param.selected_grad = grad_partition.clone().detach()
        else:
            grad_2d = grad_partition.narrow(0, param.complete_column_offset,
                                            param.complete_numel).view(param.complete_numel // num_row, num_row)
            param.selected_grad = grad_2d[param.selected_indices, :].clone().detach()

        if optimizer_z3.auto_update:
            optimizer_z3.sum_buffer[curr_buffer_idx] = grad_partition.abs().sum()
            optimizer_z3.critic_sum_buffer[curr_buffer_idx] = param.selected_grad.abs().sum()
            curr_buffer_idx += 1

        if optimizer_z3.offload_selective_optimizer and not hasattr(param, 'exp_avg_cpu_data'):
            buffer = torch.zeros(param.selected_grad.numel(), dtype=param.dtype, device=optimizer_z3.device)
            param.exp_avg_cpu_data = get_accelerator().pin_memory(
                buffer) if optimizer_z3.offload_optimizer_pin_memory else buffer
            param.exp_avg_sq_cpu_data = get_accelerator().pin_memory(
                buffer.clone()) if optimizer_z3.offload_optimizer_pin_memory else buffer.clone()

    if optimizer_z3.auto_update:
        total_rk = dist.get_world_size(optimizer_z3.dp_process_group)
        sum_gather_list = [torch.zeros_like(optimizer_z3.sum_buffer) for _ in range(total_rk)]
        critic_gather_list = [torch.zeros_like(optimizer_z3.critic_sum_buffer) for _ in range(total_rk)]
        curr_buffer_idx = 0

        dist.all_gather(sum_gather_list, optimizer_z3.sum_buffer, group=optimizer_z3.dp_process_group)
        dist.all_gather(critic_gather_list, optimizer_z3.critic_sum_buffer, group=optimizer_z3.dp_process_group)

        for param in params_to_update:
            if len(param.ds_shape) == 1:
                continue

            if not hasattr(param, 'non_critic_sum'):
                param.non_critic_sum = 0
            if not hasattr(param, 'avg_critic_sum'):
                param.avg_critic_sum = 0

            grad_total_sum = sum(sum_gather_list[rk][curr_buffer_idx] for rk in range(total_rk))
            grad_critic_sum = sum(critic_gather_list[rk][curr_buffer_idx] for rk in range(total_rk))

            param.avg_critic_sum = (param.avg_critic_sum * (optimizer_z3.update_interval - 1) +
                                    grad_critic_sum) / optimizer_z3.update_interval / (optimizer_z3.topk_ratio * 10)
            param.non_critic_sum += (grad_total_sum - grad_critic_sum) / ((1 - optimizer_z3.topk_ratio) * 10)
            if param.non_critic_sum >= param.avg_critic_sum:
                optimizer_z3.num_need_update += 1
                if optimizer_z3.num_need_update >= int(optimizer_z3.auto_ratio * optimizer_z3.num_total_param):
                    optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = True

            curr_buffer_idx += 1

    if not optimizer_z3.is_gradient_accumulation_boundary:
        optimizer_z3.selective_optimizer.group_step(params_to_update)
    else:
        optimizer_z3.selective_optimizer.temp_copy_param(params_to_update)

    if optimizer_z3.auto_update:
        optimizer_z3.sum_buffer = None
        optimizer_z3.critic_sum_buffer = None
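The auto-update branch above compares the accumulated non-critical gradient mass against a running average of the critical mass; once enough parameters cross the threshold, a full update is scheduled. A hedged numeric sketch of that test for one parameter (numbers invented, and update_interval simplified to a local counter rather than the optimizer-level one incremented in the backward prologue):

```python
topk_ratio = 0.1
update_interval = 1
avg_critic_sum, non_critic_sum = 0.0, 0.0
for grad_total, grad_critic in [(10.0, 6.0), (9.0, 5.5), (11.0, 6.0)]:
    avg_critic_sum = (avg_critic_sum * (update_interval - 1) +
                      grad_critic) / update_interval / (topk_ratio * 10)
    non_critic_sum += (grad_total - grad_critic) / ((1 - topk_ratio) * 10)
    if non_critic_sum >= avg_critic_sum:
        print("this parameter votes for a full update")
    update_interval += 1
```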
def sync_fp32_param_from_gpu(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):

    if optimizer_z3.micro_step == 0:
        return

    for fp16_partitions, fp32_partition in zip(optimizer_z3.fp16_partitioned_groups_flat,
                                               optimizer_z3.fp32_partitioned_groups_flat):
        fp32_partition.data.copy_(fp16_partitions.data)


def zenflow_backward_prologue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
    optimizer_z3.micro_step += 1
    if optimizer_z3.auto_update:
        optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state] = False
        optimizer_z3.num_need_update = 0
        if optimizer_z3.zenflow_need_update[optimizer_z3.zenflow_state ^ 1]:
            optimizer_z3.update_interval = 0
            for group in optimizer_z3.fp16_groups:
                for p in group:
                    p.non_critic_sum = 0
        optimizer_z3.update_interval += 1
    if optimizer_z3.is_zenflow_select_boundary():
        sync_fp32_param_from_gpu(optimizer_z3)
        optimizer_z3.selective_optimizer.clear_selected_mv()


def zenflow_backward_epilogue(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
    optimizer_z3._partition_all_parameters()


def log_selective_optimizer_timers(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
    pass


def initialize_optimizer_states(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3"):
    num_subgroups = len(optimizer_z3.fp16_groups)

    largest_numel = max([sum([p.ds_numel for p in psg]) for psg in optimizer_z3.fp16_partitioned_groups])
    gradient_dtype = optimizer_z3.fp32_partitioned_groups_flat[0].dtype
    gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=optimizer_z3.device)

    timer_names = set()

    # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
    # which do lazy initialization of the state at the first call to step.
    is_adagrad = isinstance(optimizer_z3.optimizer, torch.optim.Adagrad)

    if optimizer_z3.swap_optimizer:
        optimizer_z3.optimizer_swapper.init_timers()

    timer_names.add(INIT_OPTIMIZER_TIMER)
    optimizer_z3.timers(INIT_OPTIMIZER_TIMER).start()

    for i, group in enumerate(optimizer_z3.fp16_groups):
        swappable_optimizer_subgroup = optimizer_z3._swappable_optimizer_subgroup(i)
        swappable_param_subgroup = optimizer_z3.fp16_partitioned_groups_flat[i] is None

        num_elements = int(optimizer_z3.fp16_partitioned_groups_flat_numel[i])

        see_memory_usage(
            f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
            force=False)

        if swappable_optimizer_subgroup:
            optimizer_z3._optimizer_states_and_gradient_swap_in(i, timer_names)

        if optimizer_z3.offload_optimizer and not swappable_optimizer_subgroup:
            subgroup_gradient_buffer = torch.zeros(num_elements, dtype=gradient_dtype, device=optimizer_z3.device)
            if optimizer_z3.offload_optimizer_pin_memory:
                subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer)

            optimizer_z3.fp32_partitioned_groups_flat[i].grad = None
            optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad = [
                subgroup_gradient_buffer.to(optimizer_z3.subgroup_to_device[i]),
                subgroup_gradient_buffer.clone().to(optimizer_z3.subgroup_to_device[i])
            ]
        else:
            optimizer_z3.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements)

        if swappable_param_subgroup:
            optimizer_z3._partitioned_params_swap_out(i)

        if swappable_optimizer_subgroup:
            optimizer_z3._optimizer_states_and_gradient_swap_out(i, timer_names)

        see_memory_usage(
            f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}',
            force=False)

    # Initialize the optimizer states with the flattened fp32 partition.
    if is_adagrad:
        optimizer_z3.optimizer = torch.optim.Adagrad(optimizer_z3.fp32_partitioned_groups_flat,
                                                     **optimizer_z3.optimizer.defaults)

    optimizer_z3.timers(INIT_OPTIMIZER_TIMER).stop()
    optimizer_z3.timers.log(timer_names)

    if optimizer_z3.swap_optimizer:
        optimizer_z3.optimizer_swapper.log_timers()

    if not optimizer_z3.offload_optimizer:
        for group in optimizer_z3.fp32_partitioned_groups_flat:
            group.grad = None

    # Reset steps
    return


def get_overlap_step_state(optimizer_z3: "DeepSpeedZeroOptimizer_Stage3") -> int:
    if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds:
        return optimizer_z3.micro_step & 1
    else:
        if not optimizer_z3.auto_update:
            return (optimizer_z3.micro_step // optimizer_z3.update_interval) & 1
        else:
            return optimizer_z3.zenflow_state
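`get_overlap_step_state` picks which of the two gradient buffers (`overlap_grad[0]` / `overlap_grad[1]`) the current window writes into, so the side-process optimizer can still read the other. A worked example with assumed settings:

```python
# update_interval = 4, full_warm_up_rounds = 2, auto_update off:
#   micro_step 0, 1   (warm-up) -> state = micro_step & 1       -> 0, 1
#   micro_step 2, 3             -> (micro_step // 4) & 1        -> 0
#   micro_step 4..7             -> 1
#   micro_step 8..11            -> 0
# While buffer `state` is being filled, the CPU optimizer consumes
# overlap_grad[state ^ 1] from the previous window.
def overlap_state(micro_step, update_interval=4, warmup=2):
    if micro_step < warmup:
        return micro_step & 1
    return (micro_step // update_interval) & 1
```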
@instrument_w_nvtx
def partition_grads(optimizer_z3, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
    offload_fp32_gradients = {}
    offload_fp32_offsets = {}
    buffers = []
    for param, grad_partition in zip(params_to_release, grad_partitions):

        contains_real_data = param.partition_numel() * dist.get_rank(optimizer_z3.dp_process_group) < param.ds_numel
        if not contains_real_data:
            # this grad partition is empty - don't need to do anything
            param.grad = None
            continue

        # move or accumulate gradient partition to target buffer
        param_id_to_grad_partition = getattr(optimizer_z3,
                                             f"_{optimizer_z3.__class__.__name__}__param_id_to_grad_partition")
        grad_buffer = param_id_to_grad_partition[param.ds_id].narrow(0, 0, grad_partition.numel())
        buffers.append(grad_buffer)
        if optimizer_z3.micro_step_id == 0:  # don't accumulate
            grad_buffer.copy_(grad_partition, non_blocking=True)
            # ensure grad buffer is a CUDA buffer to speed up the next few
            # operations and so it can be used asynchronously
            grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True)
        elif get_accelerator().on_accelerator(grad_buffer):
            grad_buffer.add_(grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(grad_buffer.shape))
        else:
            # if dst is CPU, copy first to src device, do the addition
            # there, then move back to dst. adding directly to cpu is very slow
            cuda_grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True)
            cuda_grad_buffer.add_(
                grad_partition.to(optimizer_z3.gradient_accumulation_dtype).view(cuda_grad_buffer.shape))
            grad_buffer.copy_(cuda_grad_buffer, non_blocking=True)
            # ensure grad buffer is a CUDA buffer to speed up the next few
            # operations and so it can be used asynchronously
            grad_buffer = cuda_grad_buffer

        # offload the gradient partition if applicable
        if optimizer_z3.offload_optimizer:
            i, dest_offset, _ = optimizer_z3.grad_position[optimizer_z3.get_param_id(param)]
            now_state = optimizer_z3.get_overlap_step_state()

            if optimizer_z3.is_gradient_accumulation_boundary:
                optimizer_z3.norm_for_param_grads[optimizer_z3.get_param_id(
                    param)] = optimizer_z3._constant_buffered_norm2(grad_buffer)

                if optimizer_z3._swappable_optimizer_subgroup(i):
                    if not i in offload_fp32_gradients.keys():
                        offload_fp32_gradients[i] = []
                        offload_fp32_offsets[i] = []

                    offload_fp32_gradients[i].append(grad_buffer.float())
                    offload_fp32_offsets[i].append(dest_offset)
                else:
                    fp32_grad_tensor = optimizer_z3.fp32_partitioned_groups_flat[i].overlap_grad[now_state].narrow(
                        0, dest_offset, grad_buffer.numel())
                    fp32_grad_tensor.copy_(grad_buffer.float())

        # free the gradient
        if not get_accelerator().is_synchronized_device():
            if param.grad is not None:
                param.grad.record_stream(get_accelerator().current_stream())
        param.grad = None

    if optimizer_z3.offload_optimizer and optimizer_z3.swap_optimizer:
        for i in offload_fp32_gradients.keys():
            optimizer_z3.optimizer_swapper.swap_out_gradients(parameter=optimizer_z3.fp32_partitioned_groups_flat[i],
                                                              gradient_offsets=offload_fp32_offsets[i],
                                                              gradient_tensors=offload_fp32_gradients[i])
    return buffers


@instrument_w_nvtx
def unscale_and_clip_grads(self, sub_group_id, total_norm, now_state):
    # compute combined scale factor for this group
    combined_scale = self.loss_scale
    if self.clip_grad > 0.:
        # norm is in fact norm*scale
        clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
        clip = torch.clamp(clip, min=1.0)
        combined_scale = clip * self.loss_scale

    self.fp32_partitioned_groups_flat[sub_group_id].overlap_grad[now_state].mul_(1. / combined_scale)
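The combined scale above folds loss-scale division and global-norm clipping into a single multiply. Worked numbers (assumed):

```python
loss_scale, clip_grad = 1024.0, 1.0
total_norm = 2048.0                   # this is the *scaled* norm (norm * loss_scale)
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad   # ~2.0 -> clipping needed
clip = max(clip, 1.0)
combined_scale = clip * loss_scale    # ~2048
# grads *= 1 / combined_scale: unscales by 1024 and clips the norm to ~clip_grad
```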
def zenflow_cpu_optimizer_overlap_step(optimizer_z3, now_state, scaled_global_grad_norm):

    if not optimizer_z3.process_optimizer_established:
        optimizer_z3.start_optimizer_process()

    group_infos = []
    for group_no, group in enumerate(optimizer_z3.fp16_groups):
        optimizer_z3.unscale_and_clip_grads(group_no, scaled_global_grad_norm, now_state)
        param_group_id = optimizer_z3.sub_group_to_group_id[group_no]

        group_info = {
            "lr": optimizer_z3.optimizer.param_groups[param_group_id]["lr"],
            "betas": optimizer_z3.optimizer.param_groups[param_group_id]["betas"],
            "eps": optimizer_z3.optimizer.param_groups[param_group_id]["eps"],
            "weight_decay": optimizer_z3.optimizer.param_groups[param_group_id]["weight_decay"],
            "bias_correction": optimizer_z3.optimizer.param_groups[param_group_id]["bias_correction"],
        }

        group_infos.append(group_info)

    optimizer_z3.parent_conn.send({
        "type": "step",
        "now_state": now_state,
        "micro_step": optimizer_z3.micro_step,
        "group_infos": group_infos
    })


def wait_last_update_and_copy(optimizer_z3, timer_names):

    if not hasattr(optimizer_z3, 'parent_conn'):
        return

    if optimizer_z3.micro_step + 1 > optimizer_z3.full_warm_up_rounds and optimizer_z3.first_update_round_after_warmup:
        optimizer_z3.first_update_round_after_warmup = False
        return

    msg = optimizer_z3.parent_conn.recv()
    assert msg["type"] == "done", "Optimizer process did not finish stepping correctly."

    for sub_group_id, group in enumerate(optimizer_z3.fp16_groups):
        if optimizer_z3.fp16_partitioned_groups_flat[sub_group_id] is not None:
            optimizer_z3.fp16_partitioned_groups_flat[sub_group_id].data.copy_(
                optimizer_z3.fp32_partitioned_groups_flat[sub_group_id].stale_param.data)

            # unflatten fp16 parameter subgroup
            optimizer_z3._unflatten_partitioned_parameters(sub_group_id)
        else:
            optimizer_z3._partitioned_params_swap_out(sub_group_id)

    optimizer_z3._post_step(timer_names)

    # warn user about caching allocator flushes
    memory_stats = get_accelerator().memory_stats()
    alloc_retries = memory_stats.get("num_alloc_retries")
    if alloc_retries is None:
        alloc_retries = 0
    if alloc_retries > optimizer_z3.n_caching_allocator_flushes:
        if dist.get_rank() == 0:
            logger.warning(
                "%d pytorch allocator cache flushes since last step. this happens "
                "when there is high memory pressure and is detrimental to "
                "performance. if this is happening frequently consider adjusting "
                "settings to reduce memory consumption. If you are unable to "
                "make the cache flushes go away consider adding "
                "get_accelerator().empty_cache() calls in your training loop to ensure "
                "that all ranks flush their caches at the same time",
                alloc_retries - optimizer_z3.n_caching_allocator_flushes)
        optimizer_z3.n_caching_allocator_flushes = alloc_retries


@instrument_w_nvtx
def step(optimizer_z3, closure=None):
    """
    Not supporting closure.
    """
    optimizer_z3._pre_step()
    optimizer_z3._partition_all_parameters()

    # checks for overflow, adjust the loss scale accordingly
    if optimizer_z3._overflow_check_and_loss_scale_update():
        if optimizer_z3.swap_optimizer:
            optimizer_z3.optimizer_swapper.log_timers()
        return

    norm_groups = optimizer_z3._get_norm_groups()
    scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))

    # Stash unscaled gradient norm
    optimizer_z3._global_grad_norm = scaled_global_grad_norm / optimizer_z3.loss_scale

    if optimizer_z3.micro_step < optimizer_z3.full_warm_up_rounds:
        optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm)

    timer_names = set()

    timer_names.add(OPTIMIZER_STEP_TIMER)

    optimizer_z3.wait_last_update_and_copy(timer_names)

    if optimizer_z3.micro_step >= optimizer_z3.full_warm_up_rounds:
        optimizer_z3.zenflow_cpu_optimizer_overlap_step(optimizer_z3.get_overlap_step_state(), scaled_global_grad_norm)

    return
@@ -3,14 +3,11 @@
 
 # DeepSpeed Team
 
-import os
-import math
-import psutil
 import torch
 from deepspeed import comm as dist
 import torch.multiprocessing as mp
 
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+from deepspeed.runtime.zenflow.zenflow_utils import start_optimizer_process
 from deepspeed.runtime.utils import (see_memory_usage)
 from deepspeed.ops.adam import ZenFlowSelectiveAdamW
 
@@ -97,6 +94,8 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds
         self.offload_selective_optimizer = zenflow_config.offload
         self.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc
+        self.start_optimizer_process = lambda: start_optimizer_process(self)
+        self.zf_stage3 = False
 
         if self.offload_selective_optimizer:
             assert overlap_comm, "offload selective optimizer should be used with overlap_comm"
@@ -636,64 +635,10 @@ class ZenFlowZeroOptimizerSequential(ZenFlowZeroOptimizer):
         self.optimizer.step(step_id=self.micro_step + 1)
 
 
-def disable_accelerator():
-    accelerator = get_accelerator()
-    accelerator.is_available = lambda: False
-    accelerator.device_count = lambda: 0
-    accelerator.current_device = lambda: -1
-    # Optionally mark it as initialized if needed
-    if hasattr(accelerator, "_initialized"):
-        accelerator._initialized = True
-
-
-def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map,
-                              shared_stale_param_map, zf_affinity):
-    disable_accelerator()
-
-    current_process = psutil.Process()
-    current_process.cpu_affinity(zf_affinity)
-    os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))
-
-    from deepspeed.ops.adam import ZenFlowCPUAdam
-    optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
-
-    pipe.send({"type": "ready"})
-
-    # TODO: replace this with rpc
-
-    while True:
-        cmd = pipe.recv()
-        if cmd["type"] == "step":
-            now_state = cmd["now_state"]
-            micro_step = cmd["micro_step"]
-            group_infos = cmd["group_infos"]
-
-            for group_no, group_info in enumerate(group_infos):
-                original_param_groups = optimizer.param_groups
-                optimizer.param_groups = [original_param_groups[group_no]]
-                group = optimizer.param_groups[0]
-
-                for param_idx, param in enumerate(group["params"]):
-                    key = (group_no, param_idx)
-                    if key in shared_overlap_grad_map:
-                        param.overlap_grad = shared_overlap_grad_map[key]
-                    if key in shared_stale_param_map:
-                        param.stale_param = shared_stale_param_map[key]
-
-                optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
-
-                optimizer.param_groups = original_param_groups
-
-            pipe.send({"type": "done"})
-        elif cmd["type"] == "exit":
-            break
-
-
 class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
 
     def __init__(self, *args, **kwargs):
         super(ZenFlowZeroOptimizerParallel, self).__init__(*args, **kwargs)
         self.process_pool = mp.Pool(1)
         self.process_optimizer_established = False
         self.first_update_round_after_warmup = True
@@ -759,85 +704,6 @@ class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
             dest_tensor.copy_(src_tensor, non_blocking=True)
             param.grad = None  #offload only
 
-    # check if all tensors in the list are equal to each other
-    def all_tensors_equal(self, tensor_list):
-        first_tensor = tensor_list[0]
-        for tensor in tensor_list[1:]:
-            if not torch.equal(first_tensor, tensor):
-                return False
-        return True
-
-    def start_optimizer_process(self):
-        from multiprocessing import Pipe, get_context, Manager
-
-        ctx = get_context("spawn")
-        self.parent_conn, self.child_conn = Pipe()
-
-        manager = Manager()
-        self.shared_overlap_grad_map = manager.dict()
-        self.shared_stale_param_map = manager.dict()
-
-        for group_no, group in enumerate(self.optimizer.param_groups):
-            for param_idx, param in enumerate(group['params']):
-                param.data.share_memory_()
-                if not hasattr(param, 'stale_param'):
-                    param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device)
-                    param.stale_param.data.share_memory_()
-                key = (group_no, param_idx)
-                self.shared_stale_param_map[key] = param.stale_param
-                if param.overlap_grad is not None:
-                    param.overlap_grad[0].data.share_memory_()
-                    param.overlap_grad[1].data.share_memory_()
-                    key = (group_no, param_idx)
-                    self.shared_overlap_grad_map[key] = param.overlap_grad
-
-        param_groups_data = self.optimizer.param_groups
-        curr_rank = dist.get_rank()
-        total_rank = dist.get_world_size()
-
-        current_process = psutil.Process()
-        current_affinity = current_process.cpu_affinity()
-        all_affinities = [
-            torch.zeros(len(current_affinity),
-                        dtype=type(current_affinity[0]),
-                        device=get_accelerator().current_device_name()) for _ in range(total_rank)
-        ]
-        dist.all_gather(
-            all_affinities,
-            torch.tensor(current_affinity,
-                         dtype=type(current_affinity[0]),
-                         device=get_accelerator().current_device_name()))
-        # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here
-        if self.all_tensors_equal(all_affinities):
-            num_phy_cores = psutil.cpu_count(logical=False)
-            available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
-            num_available_phy_cores = len(available_phy_cores)
-            my_rank = curr_rank
-            my_size = total_rank
-            cores_per_rank = num_available_phy_cores // my_size
-            current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
-        pt_num_cores = math.ceil(self.pt_reserved_cores_perc * len(current_affinity))
-        if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
-            zf_affinity = current_affinity[pt_num_cores:]
-            pt_affinity = current_affinity[:pt_num_cores]
-        else:
-            zf_affinity = current_affinity
-            pt_affinity = current_affinity
-        self.process = ctx.Process(
-            target=zenflow_optimizer_process,
-            args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map,
-                  self.shared_stale_param_map, zf_affinity),
-        )
-        self.process.daemon = True
-        self.process.start()
-        current_process.cpu_affinity(pt_affinity)
-        os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))
-
-        msg = self.parent_conn.recv()
-        assert msg["type"] == "ready", "Optimizer process did not initialize correctly."
-
-        self.process_optimizer_established = True
-
     def wait_last_update_and_copy(self):
 
         if not hasattr(self, 'parent_conn'):
|
@ -3,7 +3,12 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import psutil
|
||||
from deepspeed import comm as dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
|
||||
def _flatten_dense_tensors(tensors):
|
||||
@ -40,3 +45,147 @@ def _unflatten_dense_tensors(flat, tensors):
|
||||
transposed_tensors = [t.transpose(0, 1) if t.dim() == 2 else t for t in tensors]
|
||||
unflat = torch._C._nn.unflatten_dense_tensors(flat, transposed_tensors)
|
||||
return [t.transpose(0, 1) if t.dim() == 2 else t for t in unflat]
|
||||
|
||||
|
||||
def disable_accelerator():
|
||||
accelerator = get_accelerator()
|
||||
accelerator.is_available = lambda: False
|
||||
accelerator.device_count = lambda: 0
|
||||
accelerator.current_device = lambda: -1
|
||||
# Optionally mark it as initialized if needed
|
||||
if hasattr(accelerator, "_initialized"):
|
||||
accelerator._initialized = True
|
||||
|
||||
|
||||
def zenflow_optimizer_process(pipe, param_groups, shared_overlap_grad_map, shared_stale_param_map, zf_affinity):
|
||||
disable_accelerator()
|
||||
|
||||
current_process = psutil.Process()
|
||||
current_process.cpu_affinity(zf_affinity)
|
||||
os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))
|
||||
|
||||
from deepspeed.ops.adam import ZenFlowCPUAdam
|
||||
optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
|
||||
|
||||
pipe.send({"type": "ready"})
|
||||
|
||||
# TODO: replace this with rpc
|
||||
|
||||
while True:
|
||||
cmd = pipe.recv()
|
||||
if cmd["type"] == "step":
|
||||
now_state = cmd["now_state"]
|
||||
micro_step = cmd["micro_step"]
|
||||
group_infos = cmd["group_infos"]
|
||||
|
||||
for group_no, group_info in enumerate(group_infos):
|
||||
original_param_groups = optimizer.param_groups
|
||||
optimizer.param_groups = [original_param_groups[group_no]]
|
||||
group = optimizer.param_groups[0]
|
||||
|
||||
for param_idx, param in enumerate(group["params"]):
|
||||
key = (group_no, param_idx)
|
||||
if key in shared_overlap_grad_map:
|
||||
param.overlap_grad = shared_overlap_grad_map[key]
|
||||
if key in shared_stale_param_map:
|
||||
param.stale_param = shared_stale_param_map[key]
|
||||
|
||||
optimizer.step(step_id=micro_step + 1, now_state=now_state, group_info=group_info)
|
||||
|
||||
optimizer.param_groups = original_param_groups
|
||||
|
||||
pipe.send({"type": "done"})
|
||||
elif cmd["type"] == "exit":
|
||||
break
|
||||
|
||||
|
||||
def all_tensors_equal(tensor_list):
|
||||
first_tensor = tensor_list[0]
|
||||
for tensor in tensor_list[1:]:
|
||||
if not torch.equal(first_tensor, tensor):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def start_optimizer_process(zf_optimizer):
|
||||
from multiprocessing import Pipe, get_context, Manager
|
||||
|
||||
ctx = get_context("spawn")
|
||||
zf_optimizer.parent_conn, zf_optimizer.child_conn = Pipe()
|
||||
|
||||
manager = Manager()
|
||||
zf_optimizer.shared_overlap_grad_map = manager.dict()
|
||||
zf_optimizer.shared_stale_param_map = manager.dict()
|
||||
|
||||
if zf_optimizer.zf_stage3:
|
||||
params_iter = [((group_no, 0), param)
|
||||
for group_no, param in enumerate(zf_optimizer.fp32_partitioned_groups_flat)]
|
||||
else:
|
||||
params_iter = [((group_no, param_idx), param)
|
||||
for group_no, group in enumerate(zf_optimizer.optimizer.param_groups)
|
||||
for param_idx, param in enumerate(group["params"])]
|
||||
|
||||
for key, param in params_iter:
|
||||
param.data.share_memory_()
|
||||
|
||||
if not hasattr(param, "stale_param"):
|
||||
param.stale_param = torch.zeros_like(param.data, dtype=param.dtype, device=param.device)
|
||||
param.stale_param.data.share_memory_()
|
||||
zf_optimizer.shared_stale_param_map[key] = param.stale_param
|
||||
|
||||
if getattr(param, "overlap_grad", None) is not None:
|
||||
param.overlap_grad[0].data.share_memory_()
|
||||
param.overlap_grad[1].data.share_memory_()
|
||||
zf_optimizer.shared_overlap_grad_map[key] = param.overlap_grad
|
||||
|
||||
param_groups_data = ([{
|
||||
"params": [param]
|
||||
} for param in zf_optimizer.fp32_partitioned_groups_flat]
|
||||
if zf_optimizer.zf_stage3 else zf_optimizer.optimizer.param_groups)
|
||||
|
||||
    curr_rank = dist.get_rank()
    total_rank = dist.get_world_size()

    current_process = psutil.Process()
    current_affinity = current_process.cpu_affinity()
    all_affinities = [
        torch.zeros(len(current_affinity),
                    dtype=type(current_affinity[0]),
                    device=get_accelerator().current_device_name()) for _ in range(total_rank)
    ]
    dist.all_gather(
        all_affinities,
        torch.tensor(current_affinity, dtype=type(current_affinity[0]),
                     device=get_accelerator().current_device_name()))
    # When the affinity is identical across all ranks, the launcher did not bind
    # the workers. Do a soft bind here: split the physical cores evenly by rank.
    if all_tensors_equal(all_affinities):
        num_phy_cores = psutil.cpu_count(logical=False)
        available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
        num_available_phy_cores = len(available_phy_cores)
        my_rank = curr_rank
        my_size = total_rank
        cores_per_rank = num_available_phy_cores // my_size
        current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
    # Reserve a fraction of this rank's cores for PyTorch; the rest go to the
    # ZenFlow optimizer process.
    pt_num_cores = math.ceil(zf_optimizer.pt_reserved_cores_perc * len(current_affinity))
    if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
        zf_affinity = current_affinity[pt_num_cores:]
        pt_affinity = current_affinity[:pt_num_cores]
    else:
        zf_affinity = current_affinity
        pt_affinity = current_affinity

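    # Worked example with illustrative numbers (not taken from the patch):
    # 32 physical cores, world size 4, pt_reserved_cores_perc = 0.25. Each rank
    # is soft-bound to 32 // 4 = 8 cores, so rank 1 gets cores 8..15;
    # pt_num_cores = ceil(0.25 * 8) = 2, so PyTorch keeps cores 8..9 and the
    # ZenFlow worker runs on cores 10..15.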
    zf_optimizer.process = ctx.Process(
        target=zenflow_optimizer_process,
        args=(zf_optimizer.child_conn, param_groups_data, zf_optimizer.shared_overlap_grad_map,
              zf_optimizer.shared_stale_param_map, zf_affinity),
    )
    zf_optimizer.process.daemon = True
    zf_optimizer.process.start()

    current_process.cpu_affinity(pt_affinity)
    os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))

    msg = zf_optimizer.parent_conn.recv()
    assert msg["type"] == "ready", "Optimizer process did not initialize correctly."

    zf_optimizer.process_optimizer_established = True

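# Teardown is the caller's responsibility; a plausible sketch (assumed usage,
# not part of the patch) mirroring the handshake above:
#
#   zf_optimizer.parent_conn.send({"type": "exit"})
#   zf_optimizer.process.join()
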
@ -93,6 +93,7 @@ class DeepSpeedZeRoOffload(object):
                 module,
                 timers,
                 ds_config,
                 zenflow=False,
                 overlap_comm=True,
                 prefetch_bucket_size=50000000,
                 max_reuse_distance=1000000000,
@ -115,6 +116,7 @@ class DeepSpeedZeRoOffload(object):

        self.module = module
        self.timers = timers
        self.zenflow = zenflow
        self.dtype = list(module.parameters())[0].dtype
        self.dp_process_group = dp_process_group
        self.offload_device = None
@ -472,6 +474,11 @@ class DeepSpeedZeRoOffload(object):
        param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=True)

        if self.zenflow:
            # ZenFlow keeps 2-D params in a transposed layout; flip them to the
            # compute layout after fetch (and back again before release).
            params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
            for param in params_to_fetch:
                param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

    @torch.no_grad()
@ -480,6 +487,11 @@ class DeepSpeedZeRoOffload(object):
            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
            force=False)

        if self.zenflow:
            params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
            for param in params_to_fetch:
                param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

        param_coordinator = self.get_param_coordinator()
        param_coordinator.release_sub_module(sub_module, forward=True)

@ -496,6 +508,11 @@ class DeepSpeedZeRoOffload(object):
        param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=False)

        if self.zenflow:
            params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
            for param in params_to_fetch:
                param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

    @torch.no_grad()
    def post_sub_module_backward_function(self, sub_module):
        # assert sub_module.training, "backward pass is invalid for module in evaluation mode"
@ -503,6 +520,11 @@ class DeepSpeedZeRoOffload(object):
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
            force=False)

        if self.zenflow:
            params_to_fetch = set(iter_params(sub_module, recurse=z3_leaf_module(sub_module)))
            for param in params_to_fetch:
                param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

        self.get_param_coordinator().release_sub_module(sub_module, forward=False)

        see_memory_usage(

@ -26,6 +26,7 @@ from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer
from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
from deepspeed.ops.adam import DeepSpeedCPUAdam
@ -160,6 +161,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                 overlap_comm=False,
                 offload_optimizer_config=None,
                 offload_param_config=None,
                 zenflow_config=None,
                 sub_group_size=1000000000000,
                 offload_ratio=0.0,
                 mpu=None,
@ -226,6 +228,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
        self.partial_offload = offload_ratio
        self.enable_sanity_checks = enable_sanity_checks

        self.create_zenflow_hooks()
        self._initialize_zenflow_stage3_prologue(module, zenflow_config)

        #num of ranks in a ZeRO param partitioning group
        self.zero_hpz_partition_size = zero_hpz_partition_size

@ -241,6 +246,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
            module=module,
            timers=timers,
            ds_config=ds_config,
            zenflow=self.zenflow,
            overlap_comm=overlap_comm,
            prefetch_bucket_size=prefetch_bucket_size,
            max_reuse_distance=max_reuse_distance,
@ -276,6 +282,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
            for i in range(1, len(self.optimizer.param_groups)):
                self.backup_optimizer.add_param_group(self.optimizer.param_groups[i])

        self._initialize_zenflow_stage3_epilogue(zenflow_config, overlap_comm)

        self.module = module
        self.elastic_checkpoint = elastic_checkpoint

@ -476,11 +484,32 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
        print_rank_0("Removed grad acc hooks", force=False)
        self.ipg_buckets.clear()

    def create_zenflow_hooks(self):
        from functools import partial
        hook_names = [
            "_initialize_zenflow_stage3_prologue",
            "_initialize_zenflow_stage3_epilogue",
            "zenflow_cpu_optimizer_step",
            "_sync_selective_optimizer_lr",
            "selective_optimizer_step",
            "is_zenflow_select_boundary",
            "update_selected_channels",
            "_process_selected_fp32_groups_grad",
            "zenflow_backward_prologue",
            "zenflow_backward_epilogue",
            "log_selective_optimizer_timers",
        ]

        # Bind each ZenFlow engine function to this optimizer instance so it
        # can be called like a method.
        for name in hook_names:
            fn = getattr(zf_engine_stage3, name)
            setattr(self, name, partial(fn, self))

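    # A minimal standalone sketch of the binding pattern used above
    # (illustrative names, not part of the patch):
    #
    #   from functools import partial
    #
    #   def step_hook(self, scale):
    #       return self.value * scale
    #
    #   class Engine:
    #       value = 2
    #
    #   eng = Engine()
    #   eng.step_hook = partial(step_hook, eng)
    #   assert eng.step_hook(3) == 6
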
    def initialize_ds_offload(
        self,
        module,
        timers,
        ds_config,
        zenflow,
        overlap_comm,
        prefetch_bucket_size,
        max_reuse_distance,
@ -499,6 +528,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
        return DeepSpeedZeRoOffload(module=module,
                                    timers=timers,
                                    ds_config=ds_config,
                                    zenflow=zenflow,
                                    overlap_comm=overlap_comm,
                                    prefetch_bucket_size=prefetch_bucket_size,
                                    max_reuse_distance=max_reuse_distance,
@ -738,6 +768,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
            self.fp16_groups.append(sub_group)
            self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group])

            if self.zenflow:
                for param in sub_group:
                    param.group_id = param_group_idx

            # record sub group -> group mapping
            self.sub_group_to_group_id[sub_group_idx] = param_group_idx

@ -997,7 +1031,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
            self.torch_autocast_gradscaler.step(optimizer)
            self.torch_autocast_gradscaler.update()
        else:
            optimizer.step()
            if not self.zenflow:
                optimizer.step()
            else:
                self.zenflow_cpu_optimizer_step()

        if self.offload_optimizer:
            cur_device = self.subgroup_to_device[sub_group_id]
@ -1267,8 +1304,14 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
        # move the gradient to a contiguous buffer
        with get_accelerator().stream(self.reduce_and_partition_stream):
            # move the parameter's gradient to the contiguous flat buffer
            new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
            new_grad_tensor.copy_(param.grad, non_blocking=True)
            if self.zenflow and len(param.ds_shape) != 1:
                # Store 2-D grads transposed to match ZenFlow's param layout.
                transposed_shape = param.grad.t().shape
                new_grad_tensor = bucket.buffer.narrow(0, bucket.elements,
                                                       param.grad.numel()).view(transposed_shape)
                new_grad_tensor.copy_(param.grad.t().contiguous(), non_blocking=True)
            else:
                new_grad_tensor = bucket.buffer.narrow(0, bucket.elements, param.grad.numel()).view_as(param.grad)
                new_grad_tensor.copy_(param.grad, non_blocking=True)
            if not get_accelerator().is_synchronized_device():
                param.grad.record_stream(get_accelerator().current_stream())
            param.grad.data = new_grad_tensor
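            # Layout sketch (illustrative): for g = torch.randn(4, 6),
            # g.t().contiguous().view(-1) walks g one column at a time, so each
            # column of the original grad occupies a contiguous run of the flat
            # buffer -- presumably what the selective per-channel update
            # relies on.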
@ -1308,6 +1351,12 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
            params_in_bucket.sort(key=lambda p: p.ds_id)
            grad_partitions = self.__avg_scatter_grads(params_in_bucket, communication_data_type)

            # At a selection boundary, re-pick the important channels; once past
            # warm-up, route the selected channels' grads to the fp32 groups.
            if self.is_zenflow_select_boundary():
                self.update_selected_channels(params_in_bucket, grad_partitions)

            if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
                self._process_selected_fp32_groups_grad(params_in_bucket, grad_partitions)

            self.partition_grads(params_in_bucket, grad_partitions)

            params_in_bucket.clear()
@ -2272,6 +2321,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
        if self.swap_optimizer:
            self.optimizer_swapper.pre_backward()

        if self.zenflow:
            self.zenflow_backward_prologue()

        see_memory_usage("Before backward", force=False)

        if self.custom_loss_scaler:
@ -2282,6 +2334,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
        else:
            self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

        if self.zenflow:
            self.zenflow_backward_epilogue()

        if self.swap_optimizer:
            self.optimizer_swapper.post_backward()

@ -3,40 +3,64 @@

# DeepSpeed Team

import pytest
import torch
import numpy as np
from torch.nn import Parameter
from deepspeed.ops.adam import ZenFlowSelectiveAdamW
from deepspeed.ops.adam import ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3


def make_param(shape, selected_indices=None):
def make_param(Opt, shape, selected_indices=None):
    param = Parameter(torch.randn(*shape))

    if Opt is ZenFlowSelectiveAdamW_stage3:
        # Stage 3 stores 2-D params transposed and flattened, mimicking the
        # partitioned ds_tensor layout.
        if param.dim() == 2:
            param.ds_shape = (param.shape[1], param.shape[0])
            param.ds_tensor = param.clone().T.contiguous().view(-1)
        else:
            param.ds_shape = tuple(param.shape)
            param.ds_tensor = param.clone()

        param.complete_column_offset = 0
        param.complete_numel = param.numel()
        param.group_id = 0

    if selected_indices is not None:
        param.selected_indices = selected_indices
        param.selected_grad = torch.randn(param.shape[0], len(selected_indices))
        param.temp_selected_param = param.data[:, selected_indices].clone()
        if param.dim() == 2:
            param.selected_grad = torch.randn(param.shape[0], len(selected_indices)) \
                if Opt is not ZenFlowSelectiveAdamW_stage3 \
                else torch.randn(len(selected_indices), param.ds_shape[1])
            param.temp_selected_param = param.data[:, selected_indices].clone() \
                if Opt is not ZenFlowSelectiveAdamW_stage3 \
                else param.ds_tensor.view(param.ds_shape)[selected_indices, :].clone()
        else:
            param.selected_grad = torch.randn_like(param.data)
            param.temp_selected_param = param.data.clone()
    return param

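# For reference, an illustrative call mirroring test_step_without_offload:
# make_param(ZenFlowSelectiveAdamW_stage3, (4, 6), torch.tensor([1, 3, 4]))
# yields param.ds_shape == (6, 4), a 24-element flat ds_tensor, a (3, 4)
# selected_grad, and temp_selected_param holding rows 1, 3 and 4 of the
# transposed view.
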
def test_init_methods():
    opt1 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_init_methods(Opt):
    opt1 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=False)
    assert opt1.step == opt1._step_without_offload
    assert opt1.group_step == opt1._group_step_without_offload
    opt2 = ZenFlowSelectiveAdamW([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True)
    opt2 = Opt([torch.nn.Parameter(torch.randn(2, 4))], lr=1e-3, offload=True)
    assert opt2.step == opt2._step_with_offload
    assert opt2.group_step == opt2._group_step_with_offload


def test_step_without_offload():
    param = make_param((4, 6), torch.tensor([1, 3, 4]))
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_step_without_offload(Opt):
    param = make_param(Opt, (4, 6), torch.tensor([1, 3, 4]))
    param.requires_grad_(True)
    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)

    old_selected = param.data[:, param.selected_indices].clone()
    opt = Opt([param], lr=1e-3, offload=False)

    old_selected = param.data[:, param.selected_indices].clone() \
        if Opt is not ZenFlowSelectiveAdamW_stage3 \
        else param.ds_tensor.view(param.ds_shape)[param.selected_indices, :].clone()
    opt.step()

    new_selected = param.data[:, param.selected_indices]
    new_selected = param.data[:, param.selected_indices] \
        if Opt is not ZenFlowSelectiveAdamW_stage3 \
        else param.ds_tensor.view(param.ds_shape)[param.selected_indices, :]
    diff_norm = (old_selected - new_selected).abs().sum().item()

    assert diff_norm > 1e-5, "param was not updated"
@ -44,9 +68,10 @@ def test_step_without_offload():
    assert param.selected_grad is None


def test_step_with_offload_bucket_flush():
    param1 = make_param((2, 4), torch.tensor([1, 2]))
    param2 = make_param((2, 4), torch.tensor([0, 3]))
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_step_with_offload_bucket_flush(Opt):
    param1 = make_param(Opt, (2, 4), torch.tensor([1, 2]))
    param2 = make_param(Opt, (2, 4), torch.tensor([0, 3]))

    param1.exp_avg = torch.zeros_like(param1.temp_selected_param)
    param1.exp_avg_sq = torch.zeros_like(param1.temp_selected_param)
@ -58,15 +83,16 @@ def test_step_with_offload_bucket_flush():
    param2.exp_avg_cpu_data = param2.exp_avg.clone().cpu()
    param2.exp_avg_sq_cpu_data = param2.exp_avg_sq.clone().cpu()

    opt = ZenFlowSelectiveAdamW([param1, param2], lr=1e-3, offload=True, bucket_size=1)
    opt = Opt([param1, param2], lr=1e-3, offload=True, bucket_size=1)
    opt.step()
    assert param1.temp_selected_param is None
    assert param2.temp_selected_param is None


def test_clear_selected_mv():
    param = make_param((2, 4), torch.tensor([0, 2]))
    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_clear_selected_mv(Opt):
    param = make_param(Opt, (2, 4), torch.tensor([0, 2]))
    opt = Opt([param], lr=1e-3, offload=False)
    opt.step()
    state = opt.state[param]
    assert "exp_avg" in state
@ -74,17 +100,19 @@ def test_clear_selected_mv():
    assert state["exp_avg"].abs().sum() == 0


def test_group_step_without_offload():
    param = make_param((2, 6), torch.tensor([0, 1, 3]))
    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
    group_to_paramlist = {0: [param]}
    opt._group_step_without_offload(group_to_paramlist)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_group_step_without_offload(Opt):
    param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3]))
    opt = Opt([param], lr=1e-3, offload=False)
    group_to_paramlist = {0: [param]} if Opt is not ZenFlowSelectiveAdamW_stage3 else [param]
    opt.group_step(group_to_paramlist)
    assert param.selected_grad is None


def test_group_step_with_offload():
    param = make_param((2, 6), torch.tensor([0, 1, 3]))
    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=True)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_group_step_with_offload(Opt):
    param = make_param(Opt, (2, 6), torch.tensor([0, 1, 3]))
    opt = Opt([param], lr=1e-3, offload=True)

    state = opt.state.setdefault(param, {})
    state["step"] = torch.zeros((), dtype=param.dtype, device=param.device)
@ -93,33 +121,30 @@ def test_group_step_with_offload():
    param.exp_avg_cpu_data = param.exp_avg.clone().cpu()
    param.exp_avg_sq_cpu_data = param.exp_avg_sq.clone().cpu()

    group_to_paramlist = {0: [param]}
    opt._group_step_with_offload(group_to_paramlist)
    group_to_paramlist = {0: [param]} if Opt is not ZenFlowSelectiveAdamW_stage3 else [param]
    opt.group_step(group_to_paramlist)
    assert param.selected_grad is None


def test_1d_param_support():
    param = Parameter(torch.randn(10))
    param.selected_grad = torch.randn(10)
    param.temp_selected_param = param.data.clone()
    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_1d_param_support(Opt):
    param = make_param(Opt, (10, ), torch.arange(10))
    opt = Opt([param], lr=1e-3, offload=False)
    opt.step()
    assert param.temp_selected_param is None
    assert param.selected_grad is None


def test_state_increment():
    param = torch.nn.Parameter(torch.randn(2, 4))
    param.selected_indices = torch.arange(4)
    param.selected_grad = torch.randn(2, 4)
    param.temp_selected_param = param.data.clone()
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_state_increment(Opt):
    param = make_param(Opt, (2, 4), torch.arange(4))

    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
    opt = Opt([param], lr=1e-3, offload=False)
    opt.step()
    step1 = opt.state[param]['step'].item()

    param.selected_grad = torch.randn(2, 4)
    param.temp_selected_param = param.data.clone()
    param.selected_grad = torch.randn(2, 4) if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2)
    param.temp_selected_param = param.data.clone() if Opt is not ZenFlowSelectiveAdamW_stage3 else torch.randn(4, 2)
    param.selected_indices = torch.arange(4)

    opt.step()
@ -134,22 +159,29 @@ def _compare_with_torch_adamw(param, zenflow_opt, atol=1e-4):
    for _ in range(10):
        grad = torch.randn_like(param)
        param.selected_indices = torch.arange(param.shape[1])
        param.selected_grad = grad
        param.temp_selected_param = param.data.clone()
        param.selected_grad = grad if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3) else grad.T
        param.temp_selected_param = param.data.clone() \
            if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3) \
            else param.ds_tensor.view(param.ds_shape).clone()

        torch_param.grad = grad.clone()

        zenflow_opt.step()
        torch_opt.step()

        np.testing.assert_allclose(torch_param.data.cpu().numpy(),
                                   param.data.cpu().numpy(),
                                   atol=atol,
                                   err_msg="Mismatch with torch.AdamW")
        if not isinstance(zenflow_opt, ZenFlowSelectiveAdamW_stage3):
            np.testing.assert_allclose(torch_param.data.cpu().numpy(),
                                       param.data.cpu().numpy(),
                                       atol=atol,
                                       err_msg="Mismatch with torch.AdamW")
        else:
            np.testing.assert_allclose(torch_param.data.cpu().numpy(),
                                       param.ds_tensor.view(param.ds_shape).T.clone().data.cpu().numpy(),
                                       atol=atol,
                                       err_msg="Mismatch with torch.AdamW")


def test_against_torch_adamw():
    param = torch.nn.Parameter(torch.randn(2, 4))
    param.selected_indices = torch.arange(4)
    opt = ZenFlowSelectiveAdamW([param], lr=1e-3, offload=False)
@pytest.mark.parametrize("Opt", [ZenFlowSelectiveAdamW, ZenFlowSelectiveAdamW_stage3])
def test_against_torch_adamw(Opt):
    param = make_param(Opt, (2, 4), torch.arange(4))
    opt = Opt([param], lr=1e-3, offload=False)
    _compare_with_torch_adamw(param, opt)

@ -74,7 +74,7 @@ class BaseZenFlowTest:
        model.destroy()


@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("stage", [1, 2, 3])
@pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
@pytest.mark.parametrize("offload_selective_optimizer", [True, False])
@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [
@ -93,7 +93,7 @@ class TestZenFlowSingleGPU(DistributedTest, BaseZenFlowTest):
    tester.run_training_distributed(config_dict)


@pytest.mark.parametrize("stage", [1, 2])
@pytest.mark.parametrize("stage", [1, 2, 3])
@pytest.mark.parametrize("full_warm_up_rounds", [0, 3])
@pytest.mark.parametrize("offload_selective_optimizer", [True, False])
@pytest.mark.parametrize("select_strategy,select_interval,update_interval", [