Autotune ZenFlow affinity (#7506)

This PR addresses the following ZenFlow optimizer core-binding issue:
https://github.com/deepspeedai/DeepSpeed/issues/7478

With this PR, the ZenFlow optimizer worker derives its core binding from
DeepSpeed's core binding mechanism. The algorithm is as follows (see the
usage sketch after this list):
1. Each DeepSpeed rank gets its core binding via the DeepSpeed command-line
flag `--bind_cores_to_rank`, which assigns distinct physical CPU cores to
each rank.
2. When spawning the ZenFlow optimizer worker, DeepSpeed splits the current
CPU affinity list into two sublists: `pt_affinity` and `zf_affinity`.
3. `zf_affinity` is used to set the affinity of the ZenFlow optimizer
worker, while `pt_affinity` is applied to the current PyTorch process.
4. By default, half of each rank's cores (rounded up) are reserved for the
PyTorch process and the rest are used by the ZenFlow optimizer worker. The
fraction of cores reserved for the PyTorch process can be changed via the
ZenFlow config variable `pt_reserved_cores_perc`.
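
For context, a minimal usage sketch. The script name `train.py` is
illustrative, and the placement of the `zenflow` block under
`zero_optimization` follows the existing ZenFlow config as I understand it;
only `pt_reserved_cores_perc` is new in this PR:

```python
# Launch with DeepSpeed's per-rank core binding, which this PR builds on:
#   deepspeed --bind_cores_to_rank train.py --deepspeed_config ds_config.json
#
# ds_config.json, sketched as a Python dict:
ds_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
        "zenflow": {
            "pt_reserved_cores_perc": 0.5,  # fraction of each rank's cores kept for PyTorch
        },
    },
}
```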

---------

Signed-off-by: Guokai Ma <guokai.ma@gmail.com>
Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
Signed-off-by: aeeeeeep <aeeeeeep@proton.me>
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: aeeeeeep <aeeeeeep@proton.me>
Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
Co-authored-by: Zhipeng Wang <zwanga@wustl.edu>
Co-authored-by: Peng Du <pedu@linkedin.com>
Co-authored-by: pengdurice <pengduhit@gmail.com>
Co-authored-by: Zhipeng Wang <zhipengbayern@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Commit 43537d0a60 (parent 66bf2a642d)
Author: Ma, Guokai
Date: 2025-09-04 19:10:39 +08:00 (committed by GitHub)
2 changed files with 53 additions and 28 deletions


@@ -39,6 +39,10 @@ class ZenFlowConfig(DeepSpeedConfigModel):
     full_warm_up_rounds: int = 0
     """Number of initial rounds during which all gradients are fully updated (no selection)."""

+    pt_reserved_cores_perc: float = Field(0.5, ge=0.0, le=1.0)
+    """Fraction of cores reserved for PyTorch threads;
+    the remaining cores will be used by ZenFlow optimizer workers."""
+
     steps_per_epoch: Optional[int] = Field(
         default=None,
         description=
@@ -59,4 +63,7 @@ class ZenFlowConfig(DeepSpeedConfigModel):
         if not isinstance(self.full_warm_up_rounds, int):
             raise ValueError('full_warm_up_rounds must be an integer')
+        if not isinstance(self.pt_reserved_cores_perc, float):
+            raise ValueError('pt_reserved_cores_perc must be a float')
         return self
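
A short sketch of how this field is consumed (mirroring the split logic
added later in this PR; the helper name `split_affinity` is mine, not part
of the patch):

```python
import math

def split_affinity(affinity, pt_reserved_cores_perc=0.5):
    """Split a rank's core list into (pt_affinity, zf_affinity)."""
    pt_num_cores = math.ceil(pt_reserved_cores_perc * len(affinity))
    if 0 < pt_num_cores < len(affinity):
        return affinity[:pt_num_cores], affinity[pt_num_cores:]
    # Degenerate split (everything or nothing reserved): share all cores.
    return affinity, affinity

# Example: a rank bound to 8 cores with the default 0.5 split.
pt, zf = split_affinity(list(range(8)))  # pt -> [0..3], zf -> [4..7]
```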


@@ -4,6 +4,8 @@
 # DeepSpeed Team

 import os
+import math
+import psutil
 import torch
 from deepspeed import comm as dist
 import torch.multiprocessing as mp
@@ -94,6 +96,7 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         self.micro_step = -1
         self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds
         self.offload_selective_optimizer = zenflow_config.offload
+        self.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc

         if self.offload_selective_optimizer:
             assert overlap_comm, "offload selective optimizer should be used with overlap_comm"
@@ -192,7 +195,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         self.index_buffer = torch.empty(total_size, dtype=torch.int32, device=get_accelerator().current_device_name())

-        # count = 0
         bucket = self.ipg_buckets[communication_data_type]
         for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
@@ -309,7 +311,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         group_to_paramlist = {}

-        # count = 0
         for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
@@ -478,7 +479,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         curr_selected_reduce_size = 0
         process_group = self.dp_process_group

-        # count = 0
         bucket = self.ipg_buckets[communication_data_type]
         for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
@@ -507,10 +507,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         for idx in range(len(partition_ids_w_offsets)):
             partition_id, offset = partition_ids_w_offsets[idx]

-            # if dist.get_rank() == 0 and count < 100:
-            #     print(f"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}")
-            #     count += 1
-
             # Calculate numel for grad slice depending on partition location
             if idx == len(partition_ids_w_offsets) - 1:
                 # Last partition_id uses its own offset
@@ -563,7 +559,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         if self.is_zenflow_select_boundary():
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
-            # print("update selected")
             self.update_selected_channels(tensor, curr_column_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()
         elif self.zenflow:
@@ -652,27 +647,12 @@ def disable_accelerator():

 def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map,
-                              shared_stale_param_map):
+                              shared_stale_param_map, zf_affinity):
     disable_accelerator()

-    TOTAL_CORES = os.cpu_count()
-    CPUADAM_CORE_START = 0
-    CPUADAM_CORE_END = TOTAL_CORES
-
-    TOTAL_CORES = CPUADAM_CORE_END - CPUADAM_CORE_START
-    cores_per_rank = TOTAL_CORES // total_rank
-    extra = TOTAL_CORES % total_rank
-    start_offset = curr_rank * cores_per_rank + min(curr_rank, extra)
-    end_offset = start_offset + cores_per_rank + (1 if curr_rank < extra else 0)
-    assigned_cores = set(range(CPUADAM_CORE_START + start_offset, CPUADAM_CORE_START + end_offset))
-
-    try:
-        os.sched_setaffinity(0, assigned_cores)
-        print(f"[Optimizer Thread] Rank {curr_rank} bound to CPU cores: {os.sched_getaffinity(0)}", flush=True)
-    except AttributeError:
-        print("[Optimizer Thread] sched_setaffinity not supported on this system.")
-    except Exception as e:
-        print(f"[Optimizer Thread] Failed to set affinity: {e}")
+    current_process = psutil.Process()
+    current_process.cpu_affinity(zf_affinity)
+    os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))

     from deepspeed.ops.adam import ZenFlowCPUAdam
     optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
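
The worker-side pattern above, as a self-contained sketch (assumes Linux or
Windows, where `psutil.Process.cpu_affinity()` can both read and set
affinity; the core IDs are examples and must exist on the machine):

```python
import os
import multiprocessing as mp

import psutil

def worker(zf_affinity):
    p = psutil.Process()
    p.cpu_affinity(zf_affinity)  # pin this process to the cores it was handed
    os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))  # size the OMP pool to match
    print("worker bound to cores:", p.cpu_affinity())

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # spawn a fresh child process
    proc = ctx.Process(target=worker, args=([2, 3],), daemon=True)  # cores 2-3 as an example
    proc.start()
    proc.join()
```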
@@ -779,6 +759,14 @@ class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
                 dest_tensor.copy_(src_tensor, non_blocking=True)
                 param.grad = None  #offload only

+    # check if all tensors in the list are equal to each other
+    def all_tensors_equal(self, tensor_list):
+        first_tensor = tensor_list[0]
+        for tensor in tensor_list[1:]:
+            if not torch.equal(first_tensor, tensor):
+                return False
+        return True
+
     def start_optimizer_process(self):
         from multiprocessing import Pipe, get_context, Manager
@@ -807,13 +795,43 @@ class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
         curr_rank = dist.get_rank()
         total_rank = dist.get_world_size()

+        current_process = psutil.Process()
+        current_affinity = current_process.cpu_affinity()
+        all_affinities = [
+            torch.zeros(len(current_affinity),
+                        dtype=type(current_affinity[0]),
+                        device=get_accelerator().current_device_name()) for _ in range(total_rank)
+        ]
+        dist.all_gather(
+            all_affinities,
+            torch.tensor(current_affinity,
+                         dtype=type(current_affinity[0]),
+                         device=get_accelerator().current_device_name()))
+
+        # If the affinity is identical across all ranks, the workers were not bound; do a soft bind here.
+        if self.all_tensors_equal(all_affinities):
+            num_phy_cores = psutil.cpu_count(logical=False)
+            available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
+            num_available_phy_cores = len(available_phy_cores)
+            my_rank = curr_rank
+            my_size = total_rank
+            cores_per_rank = num_available_phy_cores // my_size
+            current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
+
+        pt_num_cores = math.ceil(self.pt_reserved_cores_perc * len(current_affinity))
+        if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
+            zf_affinity = current_affinity[pt_num_cores:]
+            pt_affinity = current_affinity[:pt_num_cores]
+        else:
+            zf_affinity = current_affinity
+            pt_affinity = current_affinity
+
         self.process = ctx.Process(
             target=zenflow_optimizer_process,
-            args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map,
-                  self.shared_stale_param_map),
+            args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map,
+                  self.shared_stale_param_map, zf_affinity),
         )
         self.process.daemon = True
         self.process.start()

+        current_process.cpu_affinity(pt_affinity)
+        os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))
+
         msg = self.parent_conn.recv()
         assert msg["type"] == "ready", "Optimizer process did not initialize correctly."
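
The soft-bind fallback above, extracted as a pure function for clarity (the
name `soft_bind` is illustrative, not part of the patch): each rank filters
its affinity down to physical cores and carves out an equal contiguous
slice for itself.

```python
def soft_bind(current_affinity, num_phy_cores, rank, world_size):
    """Give each rank an equal slice of the shared physical cores."""
    available = [c for c in current_affinity if c < num_phy_cores]
    per_rank = len(available) // world_size
    return available[rank * per_rank:(rank + 1) * per_rank]

# Example: 2 ranks sharing cores 0-7 on a machine with 8 physical cores.
assert soft_bind(list(range(8)), 8, 0, 2) == [0, 1, 2, 3]
assert soft_bind(list(range(8)), 8, 1, 2) == [4, 5, 6, 7]
```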