Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile utilizes torch's dynamo to capture the computation graph and modifies it to incorporate DeepSpeed's optimizations seamlessly. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
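As a rough, hedged sketch of how DeepCompile is expected to be enabled (the `"compile": {"deepcompile": true}` config key and the `engine.compile()` call are assumptions drawn from the DeepCompile announcement, not from the file below), a training script launched with the `deepspeed` launcher might look like this:

import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)

# Assumed config: "compile.deepcompile" is taken to be the switch that enables
# DeepCompile's graph passes; the other keys are standard ZeRO-3 settings.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3},
    "compile": {"deepcompile": True},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
engine.compile()  # assumed entry point: Dynamo graph capture plus DeepSpeed graph rewriting

x = torch.randn(1, 1024, device=engine.device)
loss = engine(x).sum()
engine.backward(loss)
engine.step()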
172 lines · 5.0 KiB · Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os

import torch

try:
    from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
    # Unsupported torch version
    pass

import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

def sync_all():
    get_accelerator().synchronize()
    dist.barrier()

def get_bw(comm_op, size, duration):
    n = dist.get_world_size()
    tput = 0
    busbw = 0

    if duration == 0:
        raise ValueError("Error. Duration is 0.")

    if comm_op == "all_to_all":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op == "pt2pt" or comm_op == "broadcast":
        tput = (size / duration)
        busbw = tput
    else:
        raise ValueError("wrong comm_op specified")

    return tput, busbw

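# Note: get_bw follows the usual collective-benchmark convention. `tput` is the
# algorithm bandwidth (message bytes / time) and `busbw` rescales it by the
# fraction of data that must traverse the interconnect: (n - 1) / n for
# all_gather and all_to_all, 2 * (n - 1) / n for ring all_reduce, and 1 for
# point-to-point and broadcast. For example, an all_gather whose total gathered
# message is 1 GiB and takes 0.1 s on n = 8 ranks has an algorithm bandwidth of
# 10 GiB/s and a bus bandwidth of 10 * 7 / 8 = 8.75 GiB/s.
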
# Run all_gather and print metrics
def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op):
    sync_all()
    # Warmups, establish connections, etc.
    for i in range(warmup):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    sync_all()

    # time the actual comm op trials times and average it
    start_event.record()
    for i in range(trials):
        dist.all_gather_into_tensor(output, input, async_op=async_op)
    end_event.record()
    sync_all()
    duration = start_event.elapsed_time(end_event) / 1000

    # maintain and clean performance data
    avg_duration = duration / trials
    size = input.element_size() * input.nelement() * dist.get_world_size()
    # tput, busbw = get_bw('all_gather', size, avg_duration)

    avg_duration_ten = torch.tensor([avg_duration], device=device)
    if dist.get_world_size() > 1:
        dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG)

    return size, avg_duration_ten.item()

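# Note: the size returned above is the total gathered message in bytes
# (per-rank payload * world size); the duration is the per-trial latency in
# seconds, averaged over `trials` and then across ranks with an AVG all_reduce
# so every rank reports the same value.
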
def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False):

    # Prepare benchmark header
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    start_event = get_accelerator().Event(enable_timing=True)
    end_event = get_accelerator().Event(enable_timing=True)

    # Create list of message sizes
    M_LIST = []
    for x in (2**p for p in range(1, maxsize)):
        m = x // world_size
        if m > 0:
            M_LIST.append(m)

    results = [(0, 0)]
    sync_all()
    # loop over various tensor sizes
    for M in M_LIST:
        global_rank = dist.get_rank()
        try:
            mat = torch.ones(M, dtype=dtype, device=device)
            sync_all()
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            get_accelerator().empty_cache()
            output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Exiting comm op.')
                sync_all()
                break
            else:
                raise e
        sync_all()
        results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op))

    return results

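# Note: the sweep above uses power-of-two total element counts from 2**1 up to
# 2**(maxsize - 1), divides each by the world size to get the per-rank tensor
# length, and skips counts smaller than the world size. With torch.bfloat16
# (2 bytes per element) and maxsize=31, the largest gathered message is roughly 2 GiB.
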
profile_results = None


def create_predictor():
    global profile_results
    if profile_results is None:
        with unset_fake_temporarily():
            device = get_accelerator().current_device()
            profile_results = run_all_gather(device, torch.bfloat16, 31)
            if dist.get_rank() == 0:
                for size, avg_duration in profile_results:
                    print(f"size: {size}, avg_duration: {avg_duration}")

    # Extract size and avg_duration from results
    sizes = [result[0] for result in profile_results]
    durations = [result[1] for result in profile_results]

    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise RuntimeError("Please install scipy to use communication profiler in DeepCompile")

    # Create an interpolation function
    predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate")

    def f(size):
        if size == 0:
            return 0
        return predictor(size)

    return f

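# Note: the predictor returned by create_predictor maps a message size in bytes
# to an estimated all_gather latency in seconds via linear interpolation over
# the profiled points; fill_value="extrapolate" lets it extrapolate linearly
# beyond the largest profiled size, and size 0 is short-circuited to 0.
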
if __name__ == "__main__":
    local_rank = int(os.environ['LOCAL_RANK'])
    get_accelerator().set_device(local_rank)
    print(f"local_rank={local_rank}")

    deepspeed.init_distributed(dist_backend='nccl')

    # Create predictor function
    predictor = create_predictor()

    # Predict time for a specific data size
    example_size = 1e9
    predicted_time = predictor(example_size)
    print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds")

    dist.destroy_process_group()