Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
Autotune ZenFlow affinity (#7506)
This PR addresses the ZenFlow optimizer core-binding issue reported in https://github.com/deepspeedai/DeepSpeed/issues/7478. With this PR, the ZenFlow optimizer worker derives its core binding from DeepSpeed's core-binding mechanism. The algorithm is as follows:

1. Each DeepSpeed rank gets its core binding via the DeepSpeed command-line flag `--bind_cores_to_rank`, which assigns distinct physical CPU cores to each worker.
2. When spawning the ZenFlow optimizer worker, DeepSpeed splits the current CPU affinity list into two sublists: pt_affinity and zf_affinity.
3. zf_affinity is used to set the affinity of the ZenFlow optimizer worker; pt_affinity is applied to the current PyTorch process.
4. By default, half of each rank's cores are reserved for the PyTorch process and the rest are used by the ZenFlow optimizer worker. The fraction reserved for the PyTorch process can be changed with the ZenFlow config variable `pt_reserved_cores_perc` (see the sketch after the sign-off block below).

---------

Signed-off-by: Guokai Ma <guokai.ma@gmail.com>
Signed-off-by: Ma, Guokai <guokai.ma@intel.com>
Signed-off-by: aeeeeeep <aeeeeeep@proton.me>
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: aeeeeeep <aeeeeeep@proton.me>
Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
Co-authored-by: Zhipeng Wang <zwanga@wustl.edu>
Co-authored-by: Peng Du <pedu@linkedin.com>
Co-authored-by: pengdurice <pengduhit@gmail.com>
Co-authored-by: Zhipeng Wang <zhipengbayern@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
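As referenced in step 4 above, here is a minimal standalone sketch of the pt_affinity/zf_affinity split, mirroring the arithmetic in the diff below; the helper name split_affinity and the example core numbers are illustrative, not part of the PR:

import math

def split_affinity(current_affinity, pt_reserved_cores_perc):
    # Reserve the first ceil(perc * n) cores for the PyTorch process,
    # hand the remainder to the ZenFlow optimizer worker.
    pt_num_cores = math.ceil(pt_reserved_cores_perc * len(current_affinity))
    if 0 < pt_num_cores < len(current_affinity):
        return current_affinity[:pt_num_cores], current_affinity[pt_num_cores:]
    # Degenerate split: both processes keep the full affinity list.
    return current_affinity, current_affinity

# Example: a rank bound to physical cores 0-7 with the default 0.5 split.
pt_affinity, zf_affinity = split_affinity(list(range(8)), 0.5)
print(pt_affinity)  # [0, 1, 2, 3]
print(zf_affinity)  # [4, 5, 6, 7]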
@@ -39,6 +39,10 @@ class ZenFlowConfig(DeepSpeedConfigModel):
     full_warm_up_rounds: int = 0
     """Number of initial rounds during which all gradients are fully updated (no selection)."""
 
+    pt_reserved_cores_perc: float = Field(0.5, ge=0.0, le=1.0)
+    """Number of cores reserved for pytorch threads,
+    the remaining cores will be used by zenflow optimizer workers"""
+
     steps_per_epoch: Optional[int] = Field(
         default=None,
         description=
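For context, a hedged example of setting the new knob from a DeepSpeed config, written as a Python dict; the surrounding keys (including the nesting of "zenflow" under "zero_optimization") and the other values are assumptions for illustration, only pt_reserved_cores_perc and full_warm_up_rounds come from the ZenFlowConfig shown above:

# Hypothetical DeepSpeed config dict illustrating pt_reserved_cores_perc.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
        "zenflow": {
            "full_warm_up_rounds": 0,
            "pt_reserved_cores_perc": 0.75,  # keep 75% of each rank's cores for PyTorch
        },
    },
}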
@@ -59,4 +63,7 @@ class ZenFlowConfig(DeepSpeedConfigModel):
         if not isinstance(self.full_warm_up_rounds, int):
             raise ValueError('full_warm_up_rounds must be an integer')
 
+        if not isinstance(self.pt_reserved_cores_perc, float):
+            raise ValueError('pt_reserved_cores_perc must be a float')
+
         return self
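A brief usage sketch of the bounds on the new field, assuming ZenFlowConfig is importable (the import path below is a guess, not taken from this commit); values outside [0.0, 1.0] are rejected by the Field(ge=0.0, le=1.0) constraint:

# Hypothetical import path; the class is the ZenFlowConfig shown in this diff.
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig

cfg = ZenFlowConfig(pt_reserved_cores_perc=0.75)  # accepted: within [0.0, 1.0]
print(cfg.pt_reserved_cores_perc)                 # 0.75

try:
    ZenFlowConfig(pt_reserved_cores_perc=1.5)     # outside [0.0, 1.0]
except Exception as err:                          # pydantic validation error
    print(f"rejected: {err}")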
@@ -4,6 +4,8 @@
 # DeepSpeed Team
 
 import os
+import math
+import psutil
 import torch
 from deepspeed import comm as dist
 import torch.multiprocessing as mp
@@ -94,6 +96,7 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         self.micro_step = -1
         self.full_warm_up_rounds = zenflow_config.full_warm_up_rounds
         self.offload_selective_optimizer = zenflow_config.offload
+        self.pt_reserved_cores_perc = zenflow_config.pt_reserved_cores_perc
 
         if self.offload_selective_optimizer:
             assert overlap_comm, "offload selective optimizer should be used with overlap_comm"
@@ -192,7 +195,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
 
         self.index_buffer = torch.empty(total_size, dtype=torch.int32, device=get_accelerator().current_device_name())
 
-        # count = 0
         bucket = self.ipg_buckets[communication_data_type]
         for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
@@ -309,7 +311,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
 
         group_to_paramlist = {}
 
-        # count = 0
         for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
 
@@ -478,7 +479,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         curr_selected_reduce_size = 0
 
         process_group = self.dp_process_group
-        # count = 0
         bucket = self.ipg_buckets[communication_data_type]
         for i, param_idx_in_group, param_id in bucket.params:
             param = self.bit16_groups[i][param_idx_in_group]
@@ -507,10 +507,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
         for idx in range(len(partition_ids_w_offsets)):
             partition_id, offset = partition_ids_w_offsets[idx]
 
-            # if dist.get_rank() == 0 and count < 100:
-            #     print(f"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}")
-            #     count += 1
-
             # Calculate numel for grad slice depending on partition location
             if idx == len(partition_ids_w_offsets) - 1:
                 # Last partition_id uses its own offset
@@ -563,7 +559,6 @@ class ZenFlowZeroOptimizer(DeepSpeedZeroOptimizer):
 
         if self.is_zenflow_select_boundary():
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).start()
-            # print("update selected")
             self.update_selected_channels(tensor, curr_column_size, communication_data_type)
             self.timers(SELECTIVE_OPTIMIZER_UPDATE_TIMER).stop()
         elif self.zenflow:
@@ -652,27 +647,12 @@ def disable_accelerator():
 
 
 def zenflow_optimizer_process(pipe, curr_rank, total_rank, param_groups, shared_overlap_grad_map,
-                              shared_stale_param_map):
+                              shared_stale_param_map, zf_affinity):
     disable_accelerator()
 
-    TOTAL_CORES = os.cpu_count()
-    CPUADAM_CORE_START = 0
-    CPUADAM_CORE_END = TOTAL_CORES
-    TOTAL_CORES = CPUADAM_CORE_END - CPUADAM_CORE_START
-
-    cores_per_rank = TOTAL_CORES // total_rank
-    extra = TOTAL_CORES % total_rank
-    start_offset = curr_rank * cores_per_rank + min(curr_rank, extra)
-    end_offset = start_offset + cores_per_rank + (1 if curr_rank < extra else 0)
-    assigned_cores = set(range(CPUADAM_CORE_START + start_offset, CPUADAM_CORE_START + end_offset))
-
-    try:
-        os.sched_setaffinity(0, assigned_cores)
-        print(f"[Optimizer Thread] Rank {curr_rank} bound to CPU cores: {os.sched_getaffinity(0)}", flush=True)
-    except AttributeError:
-        print("[Optimizer Thread] sched_setaffinity not supported on this system.")
-    except Exception as e:
-        print(f"[Optimizer Thread] Failed to set affinity: {e}")
+    current_process = psutil.Process()
+    current_process.cpu_affinity(zf_affinity)
+    os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))
 
     from deepspeed.ops.adam import ZenFlowCPUAdam
     optimizer = ZenFlowCPUAdam(param_groups, overlap_step=True)
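The worker-side binding above relies on psutil's process-affinity API rather than os.sched_setaffinity; a minimal standalone sketch of that pattern follows (the two-core selection is an arbitrary example, and cpu_affinity is not available on every platform, e.g. macOS):

import os
import psutil

proc = psutil.Process()                     # the current process
print(proc.cpu_affinity())                  # query the cores this process may run on
zf_affinity = proc.cpu_affinity()[-2:]      # e.g. keep only the last two cores
proc.cpu_affinity(zf_affinity)              # pin the process to those cores
os.environ['OMP_NUM_THREADS'] = str(len(zf_affinity))  # match OpenMP threads to the core count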
@@ -779,6 +759,14 @@ class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
                 dest_tensor.copy_(src_tensor, non_blocking=True)
                 param.grad = None  #offload only
 
+    # check if all tensors in the list are equal to each other
+    def all_tensors_equal(self, tensor_list):
+        first_tensor = tensor_list[0]
+        for tensor in tensor_list[1:]:
+            if not torch.equal(first_tensor, tensor):
+                return False
+        return True
+
     def start_optimizer_process(self):
         from multiprocessing import Pipe, get_context, Manager
 
@@ -807,13 +795,43 @@ class ZenFlowZeroOptimizerParallel(ZenFlowZeroOptimizer):
         curr_rank = dist.get_rank()
         total_rank = dist.get_world_size()
 
+        current_process = psutil.Process()
+        current_affinity = current_process.cpu_affinity()
+        all_affinities = [
+            torch.zeros(len(current_affinity),
+                        dtype=type(current_affinity[0]),
+                        device=get_accelerator().current_device_name()) for _ in range(total_rank)
+        ]
+        dist.all_gather(
+            all_affinities,
+            torch.tensor(current_affinity,
+                         dtype=type(current_affinity[0]),
+                         device=get_accelerator().current_device_name()))
+        # When affinity across all ranks are the same, the workers are not binded. Do a soft bind here
+        if self.all_tensors_equal(all_affinities):
+            num_phy_cores = psutil.cpu_count(logical=False)
+            available_phy_cores = [i for i in current_affinity if i < num_phy_cores]
+            num_available_phy_cores = len(available_phy_cores)
+            my_rank = curr_rank
+            my_size = total_rank
+            cores_per_rank = num_available_phy_cores // my_size
+            current_affinity = available_phy_cores[my_rank * cores_per_rank:(my_rank + 1) * cores_per_rank]
+        pt_num_cores = math.ceil(self.pt_reserved_cores_perc * len(current_affinity))
+        if pt_num_cores > 0 and pt_num_cores < len(current_affinity):
+            zf_affinity = current_affinity[pt_num_cores:]
+            pt_affinity = current_affinity[:pt_num_cores]
+        else:
+            zf_affinity = current_affinity
+            pt_affinity = current_affinity
         self.process = ctx.Process(
             target=zenflow_optimizer_process,
             args=(self.child_conn, curr_rank, total_rank, param_groups_data, self.shared_overlap_grad_map,
-                  self.shared_stale_param_map),
+                  self.shared_stale_param_map, zf_affinity),
         )
         self.process.daemon = True
         self.process.start()
+        current_process.cpu_affinity(pt_affinity)
+        os.environ['OMP_NUM_THREADS'] = str(len(pt_affinity))
 
         msg = self.parent_conn.recv()
         assert msg["type"] == "ready", "Optimizer process did not initialize correctly."
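To make the soft-bind arithmetic in the hunk above concrete, a small worked example under assumed numbers (32 physical cores, 4 ranks with identical affinity on every rank, default pt_reserved_cores_perc of 0.5); the values are illustrative only:

import math

num_available_phy_cores = 32          # physical cores visible to every rank
my_size, my_rank = 4, 1               # world size and this rank's index
pt_reserved_cores_perc = 0.5          # default from ZenFlowConfig

cores_per_rank = num_available_phy_cores // my_size            # 8
current_affinity = list(range(my_rank * cores_per_rank,
                              (my_rank + 1) * cores_per_rank))  # [8, ..., 15]
pt_num_cores = math.ceil(pt_reserved_cores_perc * len(current_affinity))  # 4
pt_affinity = current_affinity[:pt_num_cores]   # [8, 9, 10, 11]   -> PyTorch process
zf_affinity = current_affinity[pt_num_cores:]   # [12, 13, 14, 15] -> ZenFlow optimizer worker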