Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
Add Compressedbackend for Onebit optimizers (#5473)
While adding one-bit optimizer support for XPU devices, we noticed that across accelerators the main difference in the `compressed_allreduce` implementation lies in `packbits` and `unpackbits`: CUDA uses cupy and NPU uses torch_npu. Instead of replacing these with XPU-only functions, we provide a CompressedBackend that performs the `compressed_allreduce` work and lets users plug in their own packbits/unpackbits kernels, giving a general path for all kinds of accelerators. In this PR, we:
1. Add CompressedBackend for OnebitAdam, OnebitLamb and ZeroOneAdam
2. Add an XPU implementation of packbits/unpackbits in SYCL, built via PackbitsBuilder
3. Add tests for one-bit optimizers with CompressedBackend
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
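To make the extension point concrete, here is a minimal, framework-only sketch of the packbits/unpackbits contract that CompressedBackend expects from an accelerator kernel: eight sign bits per output byte, and each bit unpacked back to ±1.0. This is not part of the PR, and `ReferencePacker` is a hypothetical name used only for illustration.

```python
import torch


class ReferencePacker:
    """Hypothetical pure-PyTorch stand-in for an accelerator's pack/unpack op."""

    def packbits(self, buffer: torch.Tensor, numel: int, rank: int) -> torch.Tensor:
        # numel and rank belong to the real op's signature; unused in this host-side sketch.
        # One '1' bit per non-negative float; eight consecutive floats become one byte
        # (CompressedBackend pads buffers so numel is divisible by 8).
        bits = (buffer >= 0).to(torch.uint8).reshape(-1, 8)
        weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8, device=buffer.device)
        return (bits * weights).sum(dim=1, dtype=torch.uint8)

    def unpackbits(self, packed: torch.Tensor, numel: int, rank: int) -> torch.Tensor:
        # Each bit becomes +1.0 or -1.0, most significant bit first.
        shifts = torch.arange(7, -1, -1, device=packed.device)
        bits = (packed.int().unsqueeze(1) >> shifts) & 1
        return (bits.flatten().float() - 0.5) * 2.0
```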
```diff
@@ -267,9 +267,9 @@ class XPU_Accelerator(DeepSpeedAccelerator):
             # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
             # if successful this also means we're doing a local install and not JIT compile path
             from op_builder import __deepspeed__  # noqa: F401 # type: ignore
-            from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
+            from op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
         except ImportError:
-            from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder
+            from deepspeed.ops.op_builder.xpu import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, AsyncIOBuilder, PackbitsBuilder
 
         if class_name == "AsyncIOBuilder":
             return AsyncIOBuilder
@@ -279,6 +279,8 @@ class XPU_Accelerator(DeepSpeedAccelerator):
             return CPUAdamBuilder
         elif class_name == "FusedAdamBuilder":
             return FusedAdamBuilder
+        elif class_name == "PackbitsBuilder":
+            return PackbitsBuilder
         else:
             return None
```
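With the branch above in place, the builder can be looked up by name through the accelerator abstraction. A hedged illustration (the exact call site varies by accelerator):

```python
from deepspeed.accelerator import get_accelerator

# On an XPU system this returns the PackbitsBuilder class added above;
# loading it JIT-compiles the SYCL sources the first time it is used.
builder_cls = get_accelerator().get_op_builder("PackbitsBuilder")
packbits_op = builder_cls().load()
```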
csrc/xpu/packbits/packing.cpp (new file, 100 lines)
```cpp
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ipex.h>
#include <torch/extension.h>
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace xpu;

void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1)
{
    // get the sign bit of each float and pack eight of them into one byte
    int i = item_ct1;
    for (int j = 0; j < 8; ++j) {
        int k = i * 8 + j;
        int bit = k < input_size && (!sycl::signbit(input[k]));
        output[i] |= bit << (7 - j);
    }
}

void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
{
    // use the bit value to set the float: bit 0 -> -1.0, bit 1 -> 1.0
    int i = item_ct1;
    output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2;
}

sycl::queue get_current_queue(at::Device device)
{
    c10::impl::VirtualGuardImpl impl(device.type());
    c10::Stream _stream = impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
    sycl::queue queue = xpu::get_queue_from_stream(_stream);
    return queue;
}

/*
pack a float tensor into a uint8 tensor; every eight float elements are packed into one uint8.
A float x >= 0 is packed as a '1' bit, otherwise as a '0' bit.
Arguments:
    tensor: the float tensor to be packed.
    input_size: numel of the input tensor
    rank: device id, used to get the corresponding stream
*/
at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
{
    at::Device device = "xpu:" + std::to_string(rank);
    sycl::queue q = get_current_queue(device);

    int packed_size = (input_size + 7) / 8;
    auto uint8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
    at::Tensor packed = torch::zeros({packed_size}, uint8_options);

    float* input = (float*)tensor.data_ptr();
    uint8_t* output = (uint8_t*)packed.data_ptr();

    auto event = q.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) {
            packbitskernel(input, output, input_size, item_ct1);
        });
    });

    return packed;
}

/*
unpack a uint8 tensor into a float tensor; every uint8 element is unpacked into eight floats.
A '1' bit is converted to float(1), a '0' bit is converted to float(-1).
Arguments:
    tensor: the uint8 tensor to be unpacked.
    input_size: numel of the input tensor
    rank: device id, used to get the corresponding stream
*/
at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
{
    at::Device device = "xpu:" + std::to_string(rank);
    sycl::queue q = get_current_queue(device);

    auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);
    at::Tensor unpacked = torch::empty({input_size * 8}, float_options);

    uint8_t* input = (uint8_t*)tensor.data_ptr();
    float* output = (float*)unpacked.data_ptr();

    auto event = q.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<>(range(input_size * 8),
                           [=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); });
    });

    return unpacked;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("packbits", &packbits, "DeepSpeed XPU packbits (C++)");
    m.def("unpackbits", &unpackbits, "DeepSpeed XPU unpackbits (C++)");
}
```
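As a sanity check, the bit layout used by `packbitskernel` (`bit << (7 - j)`, most significant bit first) matches numpy's default big-endian `packbits`. A small host-side comparison, offered only as an assumed-equivalent reference, not as part of this PR:

```python
import numpy as np

x = np.array([0.5, -1.0, 2.0, -3.0, 0.0, -0.0, 7.0, -8.0], dtype=np.float32)

# packbitskernel: a '1' bit for every value whose sign bit is clear (x >= +0.0)
packed = np.packbits(~np.signbit(x))                              # -> array([170], dtype=uint8)

# unpackbitskernel: bit 1 -> +1.0, bit 0 -> -1.0
unpacked = np.unpackbits(packed).astype(np.float32) * 2.0 - 1.0
print(packed, unpacked)                                           # [170] [ 1. -1.  1. -1.  1. -1.  1. -1.]
```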
deepspeed/runtime/comm/compressed.py (new file, 137 lines)
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numpy as np
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder


class CompressedBackend(object):

    def __init__(self, mpu=None):
        if mpu is None:
            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        else:
            self.mpu = mpu
            self.world_group = self.mpu.get_data_parallel_group()
        self.size = dist.get_world_size(group=self.world_group)
        self.rank = dist.get_rank(group=self.world_group)
        self.packer = PackbitsBuilder().load()

    def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
                else:
                    recvbuf[rank] = sendbuf
        else:
            req.append(dist.isend(sendbuf, group=group, dst=root))
        return req

    def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    dist.recv(recvbuf[idx], src=idx, group=group)
                else:
                    recvbuf[rank] = sendbuf
        else:
            dist.send(sendbuf, group=group, dst=root)

    def pack(self, buffer, size):
        # pack float tensor into uint8 tensor
        packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
        return packed.reshape(size, -1)

    def unpack(self, buffer, size, dtype):
        # unpack uint8 to float tensor
        unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
        return unpacked.reshape(size, -1).to(dtype)

    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)

        # align size of original_buffer and error
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        buffer_m.add_(worker_error)
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))

        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8)

        recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
                                   dtype=sign_list_packed_tmp[0].dtype,
                                   device=sign_list_packed_tmp.device)

        sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]

        recvbuf_scale = [
            torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name())
            for _ in range(self.size)
        ]

        # communication phase 1
        # all to all for sign
        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
        # all gather for scale
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten()
        compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \
            .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)

        compensated_server_m.add_(server_error)

        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8)

        # recvbuf_sign_server
        recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])],
                                              dtype=recvbuf_sign.dtype,
                                              device=server_sign_packed.device)

        recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)]

        # recvbuf_scale_server
        recvbuf_scale_server_tmp = torch.zeros([self.size, 1],
                                               dtype=worker_scale.dtype,
                                               device=server_sign_packed.device)

        recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)]

        # communication phase 2
        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

        recvbuf_sign_server = torch.stack(recvbuf_sign_server)

        flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()

        buffer_m.data.copy_(
            self.unpack(flattened_recvbuf_sign_server, self.size,
                        torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)

        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m
```
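One sizing detail worth calling out: `compressed_allreduce` pads `buffer_m` up to `worker_error.numel()`, and the tests below size that error buffer to a multiple of `8 * world_size` so each rank's chunk packs into whole bytes. A hypothetical helper (not in this PR) capturing that sizing rule:

```python
def aligned_numel(numel: int, world_size: int) -> int:
    """Round numel up to a multiple of 8 * world_size, as the one-bit tests size worker_error."""
    multiple = 8 * world_size
    remainder = numel % multiple
    return numel if remainder == 0 else numel + (multiple - remainder)
```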
```diff
@@ -101,6 +101,10 @@ class OnebitAdam(torch.optim.Optimizer):
                 from deepspeed.runtime.comm.hccl import HcclBackend
                 self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
                 self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
+            elif self.comm_backend_name == 'compressed':
+                from deepspeed.runtime.comm.compressed import CompressedBackend
+                self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
+                self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu)
             self.size = self.comm_backend_handle.size
 
             self.divider = int(self.size * 8 / np.gcd(self.size, 8))
@@ -123,6 +123,10 @@ class OnebitLamb(torch.optim.Optimizer):
                 from deepspeed.runtime.comm.hccl import HcclBackend
                 self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
                 self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
+            elif self.comm_backend_name == 'compressed':
+                from deepspeed.runtime.comm.compressed import CompressedBackend
+                self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
+                self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu)
 
             self.size = self.comm_backend_handle.size
 
@@ -114,6 +114,10 @@ class ZeroOneAdam(torch.optim.Optimizer):
                 from deepspeed.runtime.comm.hccl import HcclBackend
                 self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
                 self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)
+            elif self.comm_backend_name == 'compressed':
+                from deepspeed.runtime.comm.compressed import CompressedBackend
+                self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
+                self.comm_backend_handle = CompressedBackend(self.deepspeed.mpu)
             self.size = self.comm_backend_handle.size
 
             self.divider = int(self.size * 8 / np.gcd(self.size, 8))
```
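With these branches, the new backend is selected through the optimizers' existing `comm_backend_name` field. A hedged config sketch (values illustrative, schema assumed to follow the current OneBitAdam options):

```python
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 1e-3,
            "freeze_step": 400,
            "comm_backend_name": "compressed",  # previously one of 'nccl', 'mpi', 'hccl'
        },
    },
}
```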
```diff
@@ -7,3 +7,4 @@ from .cpu_adam import CPUAdamBuilder
 from .cpu_adagrad import CPUAdagradBuilder
 from .fused_adam import FusedAdamBuilder
 from .async_io import AsyncIOBuilder
+from .packbits import PackbitsBuilder
```
op_builder/xpu/packbits.py (new file, 26 lines)
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .builder import SYCLOpBuilder


class PackbitsBuilder(SYCLOpBuilder):
    BUILD_VAR = "DS_BUILD_PACK_BITS"
    NAME = "pack_bits"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.{self.NAME}_op'

    def sources(self):
        return ['csrc/xpu/packbits/packing.cpp']

    def include_paths(self):
        return ['csrc/xpu/includes']

    def cxx_args(self):
        args = super().cxx_args()
        return args + self.version_dependent_macros()
```
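A brief usage sketch of the builder, mirroring how `CompressedBackend` consumes it (device string and rank are illustrative assumptions):

```python
import torch
from deepspeed.ops.op_builder import PackbitsBuilder

packer = PackbitsBuilder().load()                        # JIT-builds csrc/xpu/packbits/packing.cpp on first use
t = torch.randn(1024, device='xpu:0')                    # assumes an XPU device, rank 0
packed = packer.packbits(t.float(), t.numel(), 0)        # uint8 tensor, numel/8 bytes
signs = packer.unpackbits(packed, packed.numel(), 0)     # +/-1.0 floats, one per original element
```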
tests/onebit/README.md (new file, 31 lines)

# One-Bit tests

In this folder, you can test the functionality and performance of the different backends for compressed allreduce, the core algorithm of one-bit optimizers such as [One-Bit Adam](https://www.deepspeed.ai/tutorials/onebit-adam/), [One-Bit Lamb](https://www.deepspeed.ai/tutorials/onebit-lamb/) and [Zero-One Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/).

## How to run

### NCCL and MPI backend

These tests require the corresponding communication backend to be installed in your environment: the NCCL backend of PyTorch distributed, or a Message Passing Interface (MPI) implementation such as MVAPICH2-GDR or OpenMPI. See the [detailed prerequisites](https://www.deepspeed.ai/tutorials/zero-one-adam/#12-pre-requisites-for-01-adam).

To test the accuracy and performance of the NCCL backend:
```bash
python test_nccl_backend.py
python test_nccl_perf.py
```
Similarly, for the MPI backend:
```bash
python test_mpi_backend.py
python test_mpi_perf.py
```

### Compressed backend

This backend abstracts the generic part of the one-bit optimizers and implements the accelerator-dependent part with a DeepSpeed custom op builder. To use and test `CompressedBackend`, make sure that your current accelerator supports `PackbitsBuilder`, so it can be loaded to perform high-performance packing and unpacking between float and byte data types.
An example can be found in `Deepspeed/op_builder/xpu/packbits.py`.

The tests are used in the same way as the others:
```bash
python test_compressed_backend.py
python test_compressed_perf.py
```
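Before running the compressed-backend tests, you can check that the active accelerator actually exposes a `PackbitsBuilder`. A hedged snippet (the lookup mirrors how the XPU accelerator resolves builders by name; behavior on other accelerators may differ):

```python
from deepspeed.accelerator import get_accelerator

builder_cls = get_accelerator().get_op_builder("PackbitsBuilder")
if builder_cls is None:
    raise RuntimeError("This accelerator does not provide PackbitsBuilder; "
                       "CompressedBackend cannot load its pack/unpack kernels.")
```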
tests/onebit/test_compressed_backend.py (new file, 96 lines)
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.accelerator import get_accelerator

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
args.local_rank = int(os.environ['LOCAL_RANK'])

get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = CompressedBackend()
local_rank = args.local_rank


# A simulated compression function using deepspeed.comm
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    get_accelerator().synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error


tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()

a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

print(a_torch.cpu())
print(a_after.cpu())

threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

test_correctness = True

# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
    else:
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
```
tests/onebit/test_compressed_perf.py (new file, 97 lines)
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.accelerator import get_accelerator
from statistics import mean

timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
args.local_rank = int(os.environ['LOCAL_RANK'])

get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = CompressedBackend()
local_rank = args.local_rank

# Setting tensor_size (BERT-Large)
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

warmup = 10
iters = 10

# Warmup
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

time_list = []

a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign

print("Shape of the compressed buffer:", a_compressed.shape) if rank == 0 else None

for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    #deepspeed.comm.all_reduce(a_compressed)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

#timer_names = ['compressed_allreduce']
#timers.log(names=timer_names, normalizer=1, memory_breakdown=None)

places = 2
convert = 1e3
float_size = 4

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = ", lat * convert)

minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat)) if rank == 0 else None
#print("tensor shape", a.shape)
duration = meanlat / 1e3
tput = ((tensor_size * 4) / duration)
print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9)) if rank == 0 else None
size = tensor_size * 4
n = dist.get_world_size()
busbw = (size / duration) * (2 * (n - 1) / n)
print("busbw: %f GB/s" % (busbw / 1e9)) if rank == 0 else None
```