mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 23:53:48 +08:00
In the process of adding onebit optimizer support for XPU devices, we noticed that, across accelerators, the main difference in the implementation of `compressed_allreduce` lies in `packbits` and `unpackbits`: CUDA uses cupy and NPU uses torch_npu. Instead of replacing these with XPU-only functions, we provide a CompressedBackend that does the `compressed_allreduce` work and lets users add their own packbits/unpackbits kernels, which is a general path for all kinds of accelerators. In this PR, we: 1. Add CompressedBackend for onebitAdam, onebitLamb and zerooneAdam 2. Add an XPU implementation of packbits/unpackbits with SYCL, built in PackbitsBuilder 3. Add tests for onebit with CompressedBackend --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
97 lines
3.5 KiB
Python
97 lines
3.5 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
import torch
|
|
import deepspeed.comm as dist
|
|
import numpy as np
|
|
import argparse
|
|
import deepspeed
|
|
import os
|
|
|
|
from deepspeed.runtime.comm.compressed import CompressedBackend
|
|
from deepspeed.accelerator import get_accelerator
|
|
|
|
# Command-line interface. The only option is --local_rank, used as a fallback
# device index when the LOCAL_RANK environment variable is not available.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

# Bring up deepspeed.comm using the communication backend name reported by the
# active accelerator.
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())

# Resolve the local rank AFTER init_distributed(), exactly where the original
# script read it (NOTE(review): initialization may itself export LOCAL_RANK for
# some launchers -- confirm before moving this earlier).
# The environment value wins; --local_rank is honoured when the variable is
# absent instead of being silently discarded, and a missing rank is reported
# explicitly rather than surfacing as a bare KeyError.
env_local_rank = os.environ.get('LOCAL_RANK')
if env_local_rank is not None:
    args.local_rank = int(env_local_rank)
elif args.local_rank < 0:
    raise RuntimeError('Local rank is unknown: set the LOCAL_RANK environment variable or pass --local_rank')

# Bind this process to its device and build the matching torch.device handle.
get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

# World size and global rank as reported by deepspeed.comm.
size = dist.get_world_size()
rank = dist.get_rank()

# Backend under test, plus the local rank that is later passed to
# backend.compressed_allreduce().
backend = CompressedBackend()
local_rank = args.local_rank
|
|
|
|
|
|
# A simulated compression function using deepspeed.comm
|
|
def _to_plus_minus_one(t):
    """Map ``t`` to a float tensor of -1.0 (negative entries) and +1.0 (zero or positive entries)."""
    # sign() gives {-1, 0, 1}; +1 then bool() collapses that to {False, True, True};
    # float() -> {0, 1, 1}; (x - 0.5) * 2 -> {-1, 1, 1}.
    return t.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)


def torch_sim(a):
    """Simulate the two-stage compressed all-reduce with plain deepspeed.comm calls.

    Worker stage: ``a`` is replaced by ``scale * sign`` where ``scale`` is
    ``a.norm() / sqrt(a.numel())``; the residual is the worker error.
    The compressed tensor is then all-reduced and multiplied by
    ``1 / world_size`` (an average, assuming the default reduction is a sum).
    Server stage: the averaged tensor is split into ``world_size`` chunks and
    each chunk is compressed the same way with its own scale; the residual of
    the chunk indexed by this process's rank is the server error.

    Args:
        a: tensor to be communicated; it is not modified by this function.

    Returns:
        Tuple ``(a_server_compressed, worker_error, server_error)``.
    """
    world_size = dist.get_world_size()

    # Worker-side compression. The intermediate sign tensor is a temporary, so
    # it is released as soon as the product has been formed.
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * _to_plus_minus_one(a)
    # Must be taken before the in-place all_reduce overwrites a_compressed.
    worker_error = a - a_compressed

    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / world_size)

    # Server-side compression, one scale per chunk.
    a_server_sign = _to_plus_minus_one(a_compressed)
    a_list = torch.chunk(a_compressed, chunks=world_size)
    a_sign_list = torch.chunk(a_server_sign, world_size)
    server_scale = []
    for piece in a_list:
        server_scale.append(piece.norm() / np.sqrt(piece.numel()))

    compressed_pieces = []
    for idx in range(world_size):
        compressed_pieces.append(server_scale[idx] * a_sign_list[idx])
    a_server_compressed = torch.cat(compressed_pieces)

    # Residual of the chunk selected by this process's rank.
    my_rank = dist.get_rank()
    server_error = a_list[my_rank] - server_scale[my_rank] * a_sign_list[my_rank]

    # Finish device work and line up all ranks before handing results back.
    get_accelerator().synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
|
|
|
|
|
|
# Number of elements in the tensor each rank communicates.
tensor_size = 300 * 2**20
# Unpadded per-rank share; not used below, kept for reference.
server_size = tensor_size // size

# Round the length up to the next multiple of 8 * size; the error buffers are
# allocated at this padded length (NOTE(review): presumably so every rank's
# chunk packs into whole bytes in packbits -- confirm against CompressedBackend).
pad_unit = 8 * size
remainder = tensor_size % pad_unit
right_tensor_size = tensor_size if remainder == 0 else tensor_size + (pad_unit - remainder)
right_server_size = right_tensor_size // size

# Add a rank-dependent bias to the simulated gradient so that fewer elements
# sit extremely close to zero, where the sign is unreliable.
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Error-feedback buffers handed to the backend, zero-initialised.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

# Reference result first, then the backend under test on the same input.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()

a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

print(a_torch.cpu())
print(a_after.cpu())

threshold = 1e-6
magnitude_threshold = 1e-6
# Compare absolute deviations: a signed comparison would only notice positions
# where the backend result is LARGER than the reference and miss the rest.
diff_mask = (a_after - a_torch).abs() > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
# Output chunk plus server error, for the backend and for the reference.
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch  # not used below

test_correctness = True

# If a value in the compensated server tensor is too small in magnitude
# (e.g. 1e-8), calling sign() on it might be problematic, so mismatches at such
# positions are skipped. The filter uses the absolute value; without it every
# negative entry would be treated as "too small" and ignored.
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
    else:
        check_mag_mask = mpi_server[diff_server_mask].abs() > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
|