DeepSpeed/csrc/quantization/pt_binding.cpp
xyxie 1b58ba5ec0 Merge LoCo with Zero++ (#6730)
### Integration of the LoCo Method into ZeRO++

#### Overview
This PR integrates the **LoCo** method, described in
[this paper](https://arxiv.org/abs/2407.04480), into the ZeRO++
framework of DeepSpeed. The key enhancement is applying error-feedback
compensation to the 4-bit gradients before communication. This
approach ***improves pre-training loss without additional time
overhead***, though it requires extra GPU memory; the size of this
memory increase depends on model size and training configuration.
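
For intuition, here is a minimal sketch of the general error-feedback idea (illustrative only, not the exact update rule from the LoCo paper): a persistent buffer $e$ accumulates the compression residual and is folded back into the next gradient before quantization,

$$
\tilde{g}_t = g_t + e_{t-1}, \qquad \hat{g}_t = Q_{4\text{-bit}}(\tilde{g}_t), \qquad e_t = \tilde{g}_t - \hat{g}_t,
$$

where $\hat{g}_t$ is the quantized gradient that is actually communicated. LoCo additionally controls how this residual is carried across steps through a coefficient, exposed as the `err_beta` argument of the new kernels below.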

#### Experimental Results
We conducted pre-training experiments using the Llama2 architecture,
adjusting the number of layers and hidden size. The experiments
included:
- **A smaller-scale model with 0.8B parameters trained on 30B tokens**.
- **A larger-scale model with 8B parameters trained on 5B tokens**.

The training data was sampled from **Redpajama-V2**.
<p align="center">
<img
src="https://github.com/user-attachments/assets/e7db9487-728c-4a17-9806-c15afa12f62e"
width="49%" />
<img
src="https://github.com/user-attachments/assets/3efec895-b71d-43ab-b5ce-65468ba8b9f1"
width="49%" />
</p>

**Findings**:
- **Smaller Models (0.8B parameters)**: Significant gains were observed
when applying the LoCo method.
- **Larger Models (8B parameters)**: The gains were present but less
pronounced. This could be due to:
  1. The relatively small volume of training data (5B tokens).
  2. The lower pre-training loss of larger models, which makes sizeable
     improvements harder to achieve.

However, even a smaller pre-training loss gap in larger models can
translate to meaningful gains in downstream tasks.

#### Example Script
For reference, the
[run.sh](https://github.com/user-attachments/files/17679552/zeroplus-7b3.zip)
script used for the 8B-parameter, 5B-token experiment is attached. The
experiment was conducted using the **DeepSpeed-Megatron** platform.



#### Acknowledgments
Special thanks to @GuanhuaWang for ongoing communication and guidance
throughout this work.

---

We appreciate your consideration of this PR and welcome any feedback or
questions!

---------

Co-authored-by: ChuanxinTang <tangchuanxin.chn@gmail.com>
Co-authored-by: root <pan.jiachun@outlook.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
2024-12-10 10:31:11 -08:00


// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "quantization.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
        launch_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_fake_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
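// Quantizes fp16 `input_vals` into a packed int8 tensor plus per-group float
// parameters (one scale, plus an offset when the quantization type requires it).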
std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
                                        int groups,
                                        int numBits,
                                        quantize::Type quantType)
{
    auto dtype = at::kFloat;
    auto params_options = at::TensorOptions()
                              .dtype(dtype)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int param_elems = (quantize::requires_offset(quantType)) ? 2 : 1;
    auto params = torch::empty({groups, param_elems}, params_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = input_vals.sizes().vec();
    output_sizes[output_sizes.size() - 1] /= numBits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int elems_per_group = at::numel(input_vals) / groups;

    launch_quant((int8_t*)output.data_ptr(),
                 (float*)params.data_ptr(),
                 (__half*)input_vals.data_ptr(),
                 groups,
                 elems_per_group,
                 numBits,
                 quantType,
                 at::cuda::getCurrentCUDAStream());

    return {output, params};
}
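// Dequantizes packed int8 data back to T (fp32 or fp16) using the per-group
// parameters produced by the quantization kernels above.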
template <typename T>
at::Tensor dequantize(at::Tensor& quantized_data,
                      at::Tensor& params,
                      int groups,
                      int num_bits,
                      quantize::Type quant_type)
{
    auto dtype = (std::is_same<T, float>::value) ? torch::kFloat32 : torch::kFloat16;
    auto output_options = at::TensorOptions()
                              .dtype(dtype)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    auto output_sizes = quantized_data.sizes().vec();
    output_sizes[output_sizes.size() - 1] *= num_bits == 8 ? 1 : 2;
    auto output = torch::empty(output_sizes, output_options);

    const int total_elems = at::numel(output);
    const int elems_per_group = total_elems / groups;

    launch_dequantize_kernel((T*)output.data_ptr(),
                             (const int8_t*)quantized_data.data_ptr(),
                             (const float*)params.data_ptr(),
                             quant_type,
                             num_bits,
                             elems_per_group,
                             total_elems,
                             at::cuda::getCurrentCUDAStream());

    return output;
}
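// Experimental fused dequantization paths from int4/int8 to half, driven by
// separate per-group scale and minimum-value buffers.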
at::Tensor dequantize_int4_to_half_experimental(at::Tensor& data_in,
                                                at::Tensor& scale_buffer,
                                                at::Tensor& min_val_buffer,
                                                int num_group,
                                                int group_size)
{
    auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    auto output = torch::empty({num_group, group_size}, output_options);

    launch_dequantize_int4_to_half_experimental((uint8_t*)data_in.data_ptr(),
                                                (half*)output.data_ptr(),
                                                (half*)scale_buffer.data_ptr(),
                                                (half*)min_val_buffer.data_ptr(),
                                                num_group,
                                                group_size,
                                                at::cuda::getCurrentCUDAStream());
    return output;
}

at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
                                                at::Tensor& scale_buffer,
                                                at::Tensor& min_val_buffer,
                                                int num_group,
                                                int group_size)
{
    auto output_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    auto output = torch::empty({num_group, group_size}, output_options);

    launch_dequantize_int8_to_half_experimental((uint8_t*)data_in.data_ptr(),
                                                (half*)output.data_ptr(),
                                                (half*)scale_buffer.data_ptr(),
                                                (half*)min_val_buffer.data_ptr(),
                                                num_group,
                                                group_size,
                                                at::cuda::getCurrentCUDAStream());
    return output;
}
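// LoCo variant of the swizzled quantization below: quantizes and swizzles the input
// for ZeRO++ communication while also updating the persistent error-feedback buffer;
// `err_beta` is passed through to the kernel to weight the error-feedback update.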
std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
                                              at::Tensor& error_feedback,
                                              float err_beta,
                                              int groups,
                                              int num_bits,
                                              quantize::Type quant_type,
                                              int pipeline_size,
                                              int nodes,
                                              int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int quantization_scalar = 8 / num_bits;
    const int compressed_vals = at::numel(input_vals) / quantization_scalar;

    auto output = torch::empty({compressed_vals}, output_options);
    const int elems_per_group = at::numel(input_vals) / groups;

    launch_loco_swizzled_quant(reinterpret_cast<int8_t*>(output.data_ptr()),
                               reinterpret_cast<float*>(scales.data_ptr()),
                               reinterpret_cast<const __half*>(input_vals.data_ptr()),
                               reinterpret_cast<__half*>(error_feedback.data_ptr()),
                               err_beta,
                               num_bits,
                               quant_type,
                               groups,
                               elems_per_group,
                               pipeline_size,
                               nodes,
                               devices_per_node,
                               at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
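// Standard swizzled quantization used by ZeRO++: quantizes the input and reorders
// the output layout according to `pipeline_size`, `nodes`, and `devices_per_node`.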
std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
                                         int groups,
                                         int num_bits,
                                         quantize::Type quant_type,
                                         int pipeline_size,
                                         int nodes,
                                         int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int quantization_scalar = 8 / num_bits;
    const int compressed_vals = at::numel(input_vals) / quantization_scalar;

    auto output = torch::empty({compressed_vals}, output_options);
    const int elems_per_group = at::numel(input_vals) / groups;

    launch_swizzled_quant((int8_t*)output.data_ptr(),
                          (float*)scales.data_ptr(),
                          (__half*)input_vals.data_ptr(),
                          num_bits,
                          quant_type,
                          groups,
                          elems_per_group,
                          pipeline_size,
                          nodes,
                          devices_per_node,
                          at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
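// Dequantizes the partitions received from `devices_per_node` ranks, reduces them,
// and re-quantizes the result into `out_groups` groups with fresh scales.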
std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
                                            at::Tensor& input_scales,
                                            int in_groups,
                                            int out_groups,
                                            int num_bits,
                                            quantize::Type quant_type,
                                            int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({out_groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
    sz[sz.size() - 1] = sz.back() / devices_per_node;  // number of GPUs per node

    const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;

    auto output = torch::empty(sz, output_options);

    const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
    const int elems_per_out_group = elems_per_in_tensor / out_groups;

    launch_dequant_reduce((int8_t*)output.data_ptr(),
                          (float*)scales.data_ptr(),
                          (const int8_t*)input_vals.data_ptr(),
                          (const float*)input_scales.data_ptr(),
                          devices_per_node,
                          num_bits,
                          quant_type,
                          out_groups,
                          elems_per_out_group,
                          elems_per_in_tensor,
                          in_groups / devices_per_node,
                          elems_per_in_group,
                          at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
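// LoCo variant of the quantized reduction above: same dequantize-reduce-requantize
// path, but also updates the error-feedback buffer using `err_beta`.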
std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
                                                 at::Tensor& input_scales,
                                                 at::Tensor& error_feedback,
                                                 float err_beta,
                                                 int in_groups,
                                                 int out_groups,
                                                 int num_bits,
                                                 quantize::Type quant_type,
                                                 int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({out_groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
    sz[sz.size() - 1] = sz.back() / devices_per_node;

    const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;

    auto output = torch::empty(sz, output_options);

    const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
    const int elems_per_out_group = elems_per_in_tensor / out_groups;

    launch_loco_dequant_reduce((int8_t*)output.data_ptr(),
                               (float*)scales.data_ptr(),
                               (const int8_t*)input_vals.data_ptr(),
                               (const float*)input_scales.data_ptr(),
                               devices_per_node,
                               num_bits,
                               quant_type,
                               out_groups,
                               elems_per_out_group,
                               elems_per_in_tensor,
                               in_groups / devices_per_node,
                               elems_per_in_group,
                               (__half2*)error_feedback.data_ptr(),
                               err_beta,
                               at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
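// Python bindings: expose the quantization, dequantization, swizzle, and
// (LoCo-)reduction entry points to DeepSpeed.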
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def(
        "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_asym_fp32",
          &ds_sr_quantize_asym<float>,
          "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16",
          &ds_sr_quantize_asym<__half>,
          "DeepSpeed Quantize with fp16 (CUDA)");
    pybind11::enum_<quantize::Type>(m, "QuantizationType")
        .value("Symmetric", quantize::Type::Symmetric)
        .value("Asymmetric", quantize::Type::Asymmetric)
        .export_values();
    m.def("quantize", &quantize_kernel);
    m.def("dequantize", &dequantize<__half>);
    m.def("dequantize_fp32", &dequantize<float>);
    m.def("dequantize_int4_to_half_experimental",
          &dequantize_int4_to_half_experimental,
          "Dequantize int4 to half (experimental)");
    m.def("dequantize_int8_to_half_experimental",
          &dequantize_int8_to_half_experimental,
          "Dequantize int8 to half (experimental)");
    m.def("swizzle_quant", &ds_swizzle_quant);
    m.def("quantized_reduction", &quantized_reduction);
    m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel");
    m.def("loco_quantized_reduction",
          &loco_quantized_reduction,
          "LoCo Quantization and Reduction Kernel");
}