Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-21 16:48:52 +08:00)
### Integration of LoCo Method into ZeRO++

#### Overview

This PR integrates the **LoCo** method, as outlined in [this paper](https://arxiv.org/abs/2407.04480), into the ZeRO++ framework of DeepSpeed. The key enhancement is applying error feedback compensation to the 4-bit gradients before communication. This ***improves pre-training loss without additional time overhead***, though it requires extra GPU memory; the size of that increase depends on the model size and training configuration.

#### Experimental Results

We ran pre-training experiments using the Llama2 architecture, varying the number of layers and the hidden size. The experiments included:

- **A smaller-scale model with 0.8B parameters trained on 30B tokens.**
- **A larger-scale model with 8B parameters trained on 5B tokens.**

The training data was sampled from **RedPajama-V2**.

<p align="center">
  <img src="https://github.com/user-attachments/assets/e7db9487-728c-4a17-9806-c15afa12f62e" width="49%" />
  <img src="https://github.com/user-attachments/assets/3efec895-b71d-43ab-b5ce-65468ba8b9f1" width="49%" />
</p>

**Findings**:

- **Smaller model (0.8B parameters)**: significant gains were observed when applying the LoCo method.
- **Larger model (8B parameters)**: the gains were present but less pronounced. This could be due to:
  1. The relatively smaller data volume.
  2. The lower pre-training loss of larger models, which makes sizable improvements harder to achieve.

However, even a small pre-training loss gap on larger models can translate to meaningful gains in downstream tasks.

#### Example Script

For reference, the [run.sh](https://github.com/user-attachments/files/17679552/zeroplus-7b3.zip) script used for the 8B-parameter, 5B-token experiment is attached. The experiment was conducted on the **DeepSpeed-Megatron** platform.

#### Acknowledgments

Special thanks to @GuanhuaWang for ongoing communication and guidance throughout this work.

---

We appreciate your consideration of this PR and welcome any feedback or questions!

---------

Co-authored-by: ChuanxinTang <tangchuanxin.chn@gmail.com>
Co-authored-by: root <pan.jiachun@outlook.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
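For readers unfamiliar with error feedback compensation, the sketch below illustrates the idea on a single gradient group in plain C++. It is not the DeepSpeed implementation: the actual ZeRO++/LoCo path runs fused CUDA kernels on `__half` data (bound in the extension file below), and the function name, the use of `std::vector<float>`, and the residual-decay rule shown here are simplifications assumed for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative sketch only: symmetric k-bit quantization of one gradient group with
// error feedback. The exact update rule in the LoCo paper and in the fused kernels
// differs; this just shows where the residual enters and how it is refreshed.
void quantize_group_with_error_feedback(const std::vector<float>& grad,
                                        std::vector<float>& error_feedback,  // persistent residual
                                        std::vector<int8_t>& q_out,
                                        float& scale,
                                        int num_bits,
                                        float err_beta)
{
    const float q_max = static_cast<float>((1 << (num_bits - 1)) - 1);  // e.g. 7 for 4-bit

    // 1) Compensate: fold the residual from previous steps into the gradient.
    std::vector<float> compensated(grad.size());
    float abs_max = 0.f;
    for (size_t i = 0; i < grad.size(); ++i) {
        compensated[i] = grad[i] + error_feedback[i];
        abs_max = std::max(abs_max, std::fabs(compensated[i]));
    }
    scale = (abs_max > 0.f) ? abs_max / q_max : 1.f;

    // 2) Quantize the compensated gradient to num_bits levels.
    q_out.resize(grad.size());
    for (size_t i = 0; i < grad.size(); ++i) {
        q_out[i] = static_cast<int8_t>(std::lround(compensated[i] / scale));
    }

    // 3) Refresh the residual with the new quantization error, decayed by err_beta
    //    (assumed simplification of the LoCo update rule).
    for (size_t i = 0; i < grad.size(); ++i) {
        const float dequant = scale * static_cast<float>(q_out[i]);
        error_feedback[i] = err_beta * error_feedback[i] + (1.f - err_beta) * (compensated[i] - dequant);
    }
}
```

The `err_beta` parameter corresponds to the argument of the same name exposed by the `loco_swizzle_quant` and `loco_quantized_reduction` bindings below.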
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "quantization.h"

template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
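// Note: the ds_*quantize bindings above perform in-place fake quantization (the values are
// quantized and dequantized in the same buffer); the kernels below emit packed int8 data
// plus per-group quantization parameters.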
std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
                                        int groups,
                                        int numBits,
                                        quantize::Type quantType)
{
    auto dtype = at::kFloat;
    auto params_options = at::TensorOptions()
                              .dtype(dtype)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int param_elems = (quantize::requires_offset(quantType)) ? 2 : 1;
    auto params = torch::empty({groups, param_elems}, params_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = input_vals.sizes().vec();
    output_sizes[output_sizes.size() - 1] /= numBits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int elems_per_group = at::numel(input_vals) / groups;

    launch_quant((int8_t*)output.data_ptr(),
                 (float*)params.data_ptr(),
                 (__half*)input_vals.data_ptr(),
                 groups,
                 elems_per_group,
                 numBits,
                 quantType,
                 at::cuda::getCurrentCUDAStream());

    return {output, params};
}

template <typename T>
at::Tensor dequantize(at::Tensor& quantized_data,
                      at::Tensor& params,
                      int groups,
                      int num_bits,
                      quantize::Type quant_type)
{
    auto dtype = (std::is_same<T, float>::value) ? torch::kFloat32 : torch::kFloat16;
    auto output_options = at::TensorOptions()
                              .dtype(dtype)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = quantized_data.sizes().vec();
    output_sizes[output_sizes.size() - 1] *= num_bits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int total_elems = at::numel(output);
    const int elems_per_group = total_elems / groups;

    launch_dequantize_kernel((T*)output.data_ptr(),
                             (const int8_t*)quantized_data.data_ptr(),
                             (const float*)params.data_ptr(),
                             quant_type,
                             num_bits,
                             elems_per_group,
                             total_elems,
                             at::cuda::getCurrentCUDAStream());

    return output;
}

at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
                                                at::Tensor& scale_buffer,
                                                at::Tensor& min_val_buffer,
                                                int num_group,
                                                int group_size)
{
    auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    auto output = torch::empty({num_group, group_size}, output_options);

    launch_dequantize_int4_to_half_experimental((uint8_t*)data_in.data_ptr(),
                                                (half*)output.data_ptr(),
                                                (half*)scale_buffer.data_ptr(),
                                                (half*)min_val_buffer.data_ptr(),
                                                num_group,
                                                group_size,
                                                at::cuda::getCurrentCUDAStream());

    return output;
}

at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
                                                at::Tensor& scale_buffer,
                                                at::Tensor& min_val_buffer,
                                                int num_group,
                                                int group_size)
{
    auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    auto output = torch::empty({num_group, group_size}, output_options);

    launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
                                                (half*)output.data_ptr(),
                                                (half*)scale_buffer.data_ptr(),
                                                (half*)min_val_buffer.data_ptr(),
                                                num_group,
                                                group_size,
                                                at::cuda::getCurrentCUDAStream());

    return output;
}
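// LoCo variant of swizzled quantization: before the gradients are quantized and swizzled for
// communication, the kernel folds the residual stored in `error_feedback` into the values and
// then refreshes that buffer with the new quantization error (weighted by `err_beta`).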
std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
                                              at::Tensor& error_feedback,
                                              float err_beta,
                                              int groups,
                                              int num_bits,
                                              quantize::Type quant_type,
                                              int pipeline_size,
                                              int nodes,
                                              int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int quantization_scalar = 8 / num_bits;
    const int compressed_vals = at::numel(input_vals) / quantization_scalar;

    auto output = torch::empty({compressed_vals}, output_options);
    const int elems_per_group = at::numel(input_vals) / groups;

    launch_loco_swizzled_quant(reinterpret_cast<int8_t*>(output.data_ptr()),
                               reinterpret_cast<float*>(scales.data_ptr()),
                               reinterpret_cast<const __half*>(input_vals.data_ptr()),
                               reinterpret_cast<__half*>(error_feedback.data_ptr()),
                               err_beta,
                               num_bits,
                               quant_type,
                               groups,
                               elems_per_group,
                               pipeline_size,
                               nodes,
                               devices_per_node,
                               at::cuda::getCurrentCUDAStream());

    return {output, scales};
}

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
                                         int groups,
                                         int num_bits,
                                         quantize::Type quant_type,
                                         int pipeline_size,
                                         int nodes,
                                         int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int quantization_scalar = 8 / num_bits;
    const int compressed_vals = at::numel(input_vals) / quantization_scalar;

    auto output = torch::empty({compressed_vals}, output_options);
    const int elems_per_group = at::numel(input_vals) / groups;

    launch_swizzled_quant((int8_t*)output.data_ptr(),
                          (float*)scales.data_ptr(),
                          (__half*)input_vals.data_ptr(),
                          num_bits,
                          quant_type,
                          groups,
                          elems_per_group,
                          pipeline_size,
                          nodes,
                          devices_per_node,
                          at::cuda::getCurrentCUDAStream());

    return {output, scales};
}

std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
                                            at::Tensor& input_scales,
                                            int in_groups,
                                            int out_groups,
                                            int num_bits,
                                            quantize::Type quant_type,
                                            int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({out_groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
    sz[sz.size() - 1] = sz.back() / devices_per_node;  // devices_per_node = number of GPUs per node
    const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;
    auto output = torch::empty(sz, output_options);

    const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
    const int elems_per_out_group = elems_per_in_tensor / out_groups;

    launch_dequant_reduce((int8_t*)output.data_ptr(),
                          (float*)scales.data_ptr(),
                          (const int8_t*)input_vals.data_ptr(),
                          (const float*)input_scales.data_ptr(),
                          devices_per_node,
                          num_bits,
                          quant_type,
                          out_groups,
                          elems_per_out_group,
                          elems_per_in_tensor,
                          in_groups / devices_per_node,
                          elems_per_in_group,
                          at::cuda::getCurrentCUDAStream());
    return {output, scales};
}
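// LoCo variant of the quantized reduction: dequantizes the chunks received from the other
// ranks, reduces them, and requantizes the result, compensating the requantization through
// the `error_feedback` buffer, which is updated in place and weighted by `err_beta`.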
std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
                                                 at::Tensor& input_scales,
                                                 at::Tensor& error_feedback,
                                                 float err_beta,
                                                 int in_groups,
                                                 int out_groups,
                                                 int num_bits,
                                                 quantize::Type quant_type,
                                                 int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;

    auto scales = torch::empty({out_groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
    sz[sz.size() - 1] = sz.back() / devices_per_node;

    const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;

    auto output = torch::empty(sz, output_options);

    const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
    const int elems_per_out_group = elems_per_in_tensor / out_groups;

    launch_loco_dequant_reduce((int8_t*)output.data_ptr(),
                               (float*)scales.data_ptr(),
                               (const int8_t*)input_vals.data_ptr(),
                               (const float*)input_scales.data_ptr(),
                               devices_per_node,
                               num_bits,
                               quant_type,
                               out_groups,
                               elems_per_out_group,
                               elems_per_in_tensor,
                               in_groups / devices_per_node,
                               elems_per_in_group,
                               (__half2*)error_feedback.data_ptr(),
                               err_beta,
                               at::cuda::getCurrentCUDAStream());

    return {output, scales};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def(
        "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_asym_fp32",
          &ds_sr_quantize_asym<float>,
          "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16",
          &ds_sr_quantize_asym<__half>,
          "DeepSpeed Quantize with fp16 (CUDA)");
    pybind11::enum_<quantize::Type>(m, "QuantizationType")
        .value("Symmetric", quantize::Type::Symmetric)
        .value("Asymmetric", quantize::Type::Asymmetric)
        .export_values();
    m.def("quantize", &quantize_kernel);
    m.def("dequantize", &dequantize<__half>);
    m.def("dequantize_fp32", &dequantize<float>);
    m.def("dequantize_int4_to_half_experimental",
          &dequantize_int4_to_half_experimental,
          "Dequantize int4 to half (experimental)");
    m.def("dequantize_int8_to_half_experimental",
          &dequantize_int8_to_half_experimental,
          "Dequantize int8 to half (experimental)");
    m.def("swizzle_quant", &ds_swizzle_quant);
    m.def("quantized_reduction", &quantized_reduction);
    m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel");
    m.def("loco_quantized_reduction",
          &loco_quantized_reduction,
          "LoCo Quantization and Reduction Kernel");
}