DeepSpeed/deepspeed/compile/profilers/comm_profile.py
Masahiro Tanaka 227a60c0c4 DeepCompile for enhanced compiler integration (#7154)
This PR introduces *DeepCompile*, a new feature that efficiently
integrates compiler optimizations with other DeepSpeed features.
DeepCompile uses PyTorch's Dynamo to capture the computation graph and
modifies it to incorporate DeepSpeed's optimizations seamlessly.

Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements
such as proactive prefetching and selective unsharding to improve
performance.
(More details will be added later.)
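
As a rough illustration (not part of this PR's diff), enabling DeepCompile is
expected to look like the sketch below. The `compile.deepcompile` config key and
the `engine.compile()` call are assumptions for illustration; treat the exact
key names as illustrative rather than authoritative.

```python
import torch
import deepspeed

# Minimal sketch; the "compile"/"deepcompile" config keys are assumed here.
model = torch.nn.Linear(1024, 1024)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},   # DeepCompile currently targets ZeRO-1 and ZeRO-3
    "compile": {"deepcompile": True},    # assumed flag that enables DeepCompile
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
engine.compile()  # capture the graph with Dynamo and apply DeepSpeed's graph passes
```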

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2025-04-16 04:33:53 +00:00


# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os

import torch

try:
    from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
    # Unsupported torch version
    pass

import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def sync_all():
    get_accelerator().synchronize()
    dist.barrier()


def get_bw(comm_op, size, duration):
    n = dist.get_world_size()
    tput = 0
    busbw = 0
    if duration == 0:
        raise ValueError("Error. Duration is 0.")

    if comm_op == "all_to_all":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt" or comm_op == "broadcast":
        tput = (size / duration)
        busbw = tput
    else:
        raise ValueError("wrong comm_op specified")

    return tput, busbw
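

# Worked example for the formulas above (illustrative numbers, not measurements):
# with n = 8 ranks, size = 1 GiB and duration = 0.02 s,
#   get_bw("all_reduce", 1 << 30, 0.02)
# gives tput = 2 * 1 GiB / 0.02 s = 100 GiB/s and
# busbw = (1 GiB / 0.02 s) * (2 * (8 - 1) / 8) = 87.5 GiB/s.
# The bus-bandwidth scaling factors follow the usual collective-benchmark
# convention (e.g., nccl-tests).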


# Run all_gather and measure the average latency
def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op):
    sync_all()
    # Warmups, establish connections, etc.
    for i in range(warmup):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    sync_all()

    # time the actual comm op trials times and average it
    start_event.record()
    for i in range(trials):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    end_event.record()
    sync_all()
    duration = start_event.elapsed_time(end_event) / 1000  # elapsed_time is in ms; convert to seconds

    # maintain and clean performance data
    avg_duration = duration / trials
    size = input.element_size() * input.nelement() * dist.get_world_size()
    # tput, busbw = get_bw('all_gather', size, avg_duration)

    # Average the measured duration across ranks so every rank returns the same value
    avg_duration_ten = torch.tensor([avg_duration], device=device)
    if dist.get_world_size() > 1:
        dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG)

    return size, avg_duration_ten.item()
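
# Note: the size returned above (and therefore the x-axis of the predictor built
# in create_predictor()) is the total gathered payload in bytes, i.e. the
# per-rank input size multiplied by the world size.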


def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False):
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    start_event = get_accelerator().Event(enable_timing=True)
    end_event = get_accelerator().Event(enable_timing=True)

    # Create list of per-rank message sizes (powers of two up to 2**(maxsize - 1) total elements)
    M_LIST = []
    for x in (2**p for p in range(1, maxsize)):
        m = x // world_size
        if m > 0:
            M_LIST.append(m)

    results = [(0, 0)]
    sync_all()
    # loop over various tensor sizes
    for M in M_LIST:
        global_rank = dist.get_rank()
        try:
            mat = torch.ones(M, dtype=dtype, device=device)
            sync_all()
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            get_accelerator().empty_cache()
            output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Exiting comm op.')
                sync_all()
                break
            else:
                raise e
        sync_all()
        results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op))

    return results
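

# Note: results is seeded with (0, 0) above so the interpolator built in
# create_predictor() passes through the origin (a zero-byte message costs zero time).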


profile_results = None


def create_predictor():
    global profile_results

    # Profile communication once per process and cache the results
    if profile_results is None:
        with unset_fake_temporarily():
            device = get_accelerator().current_device()
            profile_results = run_all_gather(device, torch.bfloat16, 31)
            if dist.get_rank() == 0:
                for size, avg_duration in profile_results:
                    print(f"size: {size}, avg_duration: {avg_duration}")

    # Extract size and avg_duration from results
    sizes = [result[0] for result in profile_results]
    durations = [result[1] for result in profile_results]

    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise RuntimeError("Please install scipy to use communication profiler in DeepCompile")

    # Create an interpolation function mapping message size (bytes) to predicted duration (seconds)
    predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate")

    def f(size):
        if size == 0:
            return 0
        return predictor(size)

    return f
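

# Standalone usage (a sketch): launch with a distributed launcher that sets
# LOCAL_RANK, e.g. on a single 8-GPU node:
#   torchrun --nproc_per_node=8 deepspeed/compile/profilers/comm_profile.py
# The script assumes NVIDIA GPUs, since it initializes the NCCL backend below.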


if __name__ == "__main__":
    local_rank = int(os.environ['LOCAL_RANK'])
    get_accelerator().set_device(local_rank)
    print(f"local_rank={local_rank}")
    deepspeed.init_distributed(dist_backend='nccl')

    # Create predictor function
    predictor = create_predictor()

    # Predict time for a specific data size
    example_size = 1e9
    predicted_time = predictor(example_size)
    print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds")

    dist.destroy_process_group()