CPUAdam fp16 and bf16 support (#5409)
Hi, please review the following changes: I added support for BF16 to CPU Adam. BF16, FP16, and float are supported at compile time; the correct template is selected at runtime according to the input params' dtype. --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
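For illustration only (not part of this diff): the change replaces the old half_precision flag with one templated step implementation per (param dtype, state dtype) pair, registered in a map keyed by the two scalar types and looked up at runtime from the input tensors. A minimal standalone C++ sketch of that dispatch pattern, using a placeholder DType enum instead of c10::ScalarType, is:

// Sketch of runtime dispatch to per-dtype template instantiations (illustrative only).
#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <tuple>
#include <typeinfo>

enum class DType { Float, Half, BFloat16 };  // stand-in for c10::ScalarType

struct Half { unsigned short bits; };        // placeholder 16-bit storage types
struct BFloat16 { unsigned short bits; };

template <typename param_t, typename state_t>
void step(void* params, void* state, size_t n)
{
    // A real implementation would cast the pointers and run the optimizer update;
    // here we only demonstrate which instantiation was selected.
    (void)params;
    (void)state;
    std::cout << "step<" << typeid(param_t).name() << ", " << typeid(state_t).name()
              << "> over " << n << " elements\n";
}

using StepFn = std::function<void(void*, void*, size_t)>;
static std::map<std::tuple<DType, DType>, StepFn> invokers;

template <typename param_t, typename state_t>
void register_invoker(DType p, DType s)
{
    invokers[{p, s}] = step<param_t, state_t>;  // store the instantiation in the table
}

int main()
{
    register_invoker<float, float>(DType::Float, DType::Float);
    register_invoker<Half, float>(DType::Half, DType::Float);
    register_invoker<BFloat16, float>(DType::BFloat16, DType::Float);

    auto it = invokers.find({DType::BFloat16, DType::Float});
    if (it == invokers.end()) throw std::runtime_error("unsupported dtype combination");
    it->second(nullptr, nullptr, 1024);  // calls step<BFloat16, float>
}

In the real files below, create_invoker<...>() fills the map for the Half, BFloat16, and float combinations, and invoke() throws if the requested combination was not registered.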
@@ -5,55 +5,38 @@

#include "cpu_adagrad.h"
#include <torch/extension.h>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_map>
#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

using namespace std::string_literals;
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

void Adagrad_Optimizer::Step_1(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adagrad_Optimizer::Step_1(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<1>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size) {
float step_size = -1 * _alpha;
ds_half_precision_t* grads_cast_h;
ds_half_precision_t* params_cast_h;
if (half_precision) {
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
}
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
float param = half_precision ? (float)params_cast_h[k] : _params[k];
float grad = (float)grads[k];
float param = (float)_params[k];
float momentum = grads[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0) { grad = param * _weight_decay + grad; }
@@ -64,58 +47,30 @@ void Adagrad_Optimizer::Step_1(float* _params,
grad += _eps;
grad = momentum / grad;
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
params_cast_h[k] = (ds_half_precision_t)param;
else
_params[k] = param;
_params[k] = param;
// STORE UPDATE TERM TO GRAD'S MEMORY
grads[k] = grad * step_size;
_exp_avg_sq[k] = variance;
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
}
}

void Adagrad_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adagrad_Optimizer::Step_4(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<4>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size)
Step_1((_params + rounded_size),
(grads + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

int create_adagrad_optimizer(int optimizer_id,
@@ -149,25 +104,77 @@ int create_adagrad_optimizer(int optimizer_id,
return 0;
}

void Adagrad_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adagrad_Optimizer::Step_8(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<8>(
&rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision);
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

template <typename ds_params_percision_t, typename ds_state_precision_t>
void step_invoker(std::shared_ptr<Adagrad_Optimizer> opt,
void* _params,
void* grads,
void* _exp_avg_sq,
size_t _param_size)
{
opt->Step_8((ds_params_percision_t*)(_params),
(ds_params_percision_t*)(grads),
(ds_state_precision_t*)(_exp_avg_sq),
_param_size);
}

std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
std::function<void(std::shared_ptr<Adagrad_Optimizer>, void*, void*, void*, size_t)>>
invokers;

// Fill map with template functions for each type
template <class ds_params_percision_t, class ds_state_precision_t>
void create_invoker()
{
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
c10::CppTypeToScalarType<ds_state_precision_t>())] =
step_invoker<ds_params_percision_t, ds_state_precision_t>;
}
struct InvokerInitializer {
InvokerInitializer()
{
create_invoker<c10::Half, float>();
create_invoker<c10::Half, c10::Half>();
create_invoker<c10::BFloat16, float>();
create_invoker<c10::BFloat16, c10::BFloat16>();
create_invoker<float, float>();
}
} _invoker_initializer;

void invoke(std::shared_ptr<Adagrad_Optimizer> opt,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg_sq,
size_t param_size)
{
c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg_sq.options().dtype());

auto it = invokers.find(std::tuple(params_type, state_type));
if (it == invokers.end()) {
throw std::runtime_error("Adagrad optimizer with param type "s +
c10::toString(params_type) + " and state type "s +
c10::toString(state_type) +
" is not supported on current hardware"s);
}

it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg_sq.data_ptr(), param_size);
}

int ds_adagrad_step(int optimizer_id,
@@ -183,58 +190,13 @@ int ds_adagrad_step(int optimizer_id,
auto grads_c = grads.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

std::shared_ptr<Adagrad_Optimizer> opt =
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.numel());

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
}
invoke(opt, params_c, grads_c, exp_avg_sq_c, params_c.numel());

int ds_adagrad_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float epsilon,
float weight_decay,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

std::shared_ptr<Adagrad_Optimizer> opt =
std::static_pointer_cast<Adagrad_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr,
grads_ptr,
exp_avg_sq_ptr,
params_c.numel(),
gpu_params_ptr,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
#else
assert(false);
#endif
return 0;
}

@@ -248,9 +210,6 @@ int destroy_adagrad_optimizer(int optimizer_id)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)");
m.def("adagrad_update_copy",
&ds_adagrad_step_plus_copy,
"DeepSpeed CPU Adagrad update and param copy (C++)");
m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)");
m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)");
}
@@ -8,9 +8,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}
@@ -5,42 +5,29 @@

#include <torch/extension.h>
#include <cassert>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cpu_adam.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

using namespace std::string_literals;
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

void Adam_Optimizer::Step_1(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adam_Optimizer::Step_1(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<1>(&rounded_size,
_params,
grads,
_exp_avg,
_exp_avg_sq,
_param_size,
dev_params,
half_precision);
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size) {
float betta1_minus1 = 1 - _betta1;
@@ -48,26 +35,15 @@ void Adam_Optimizer::Step_1(float* _params,

float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
ds_half_precision_t* grads_cast_h;
ds_half_precision_t* params_cast_h;
if (half_precision) {
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
}

for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
float param = half_precision ? (float)params_cast_h[k] : _params[k];
float grad = (float)grads[k];
float param = (float)_params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
@@ -83,66 +59,31 @@ void Adam_Optimizer::Step_1(float* _params,
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
params_cast_h[k] = (ds_half_precision_t)param;
else
_params[k] = param;
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
}
}

void Adam_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adam_Optimizer::Step_4(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<4>(&rounded_size,
_params,
grads,
_exp_avg,
_exp_avg_sq,
_param_size,
dev_params,
half_precision);
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size)
Step_1((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

int create_adam_optimizer(int optimizer_id,
@@ -185,33 +126,86 @@ int create_adam_optimizer(int optimizer_id,
return 0;
}

void Adam_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Adam_Optimizer::Step_8(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<8>(&rounded_size,
_params,
grads,
_exp_avg,
_exp_avg_sq,
_param_size,
dev_params,
half_precision);
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size);
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

template <typename ds_params_percision_t, typename ds_state_precision_t>
void step_invoker(std::shared_ptr<Adam_Optimizer> opt,
void* _params,
void* grads,
void* _exp_avg,
void* _exp_avg_sq,
size_t _param_size)
{
opt->Step_8((ds_params_percision_t*)(_params),
(ds_params_percision_t*)(grads),
(ds_state_precision_t*)(_exp_avg),
(ds_state_precision_t*)(_exp_avg_sq),
_param_size);
}

std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
std::function<void(std::shared_ptr<Adam_Optimizer>, void*, void*, void*, void*, size_t)>>
invokers;

// Fill map with template functions for each type
template <class ds_params_percision_t, class ds_state_precision_t>
void create_invoker()
{
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
c10::CppTypeToScalarType<ds_state_precision_t>())] =
step_invoker<ds_params_percision_t, ds_state_precision_t>;
}
struct InvokerInitializer {
InvokerInitializer()
{
create_invoker<c10::Half, float>();
create_invoker<c10::Half, c10::Half>();
create_invoker<c10::BFloat16, float>();
create_invoker<c10::BFloat16, c10::BFloat16>();
create_invoker<float, float>();
}
} _invoker_initializer;

void invoke(std::shared_ptr<Adam_Optimizer> opt,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
size_t param_size)
{
c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype());

auto it = invokers.find(std::tuple(params_type, state_type));
if (it == invokers.end()) {
throw std::runtime_error("Adam optimizer with param type "s + c10::toString(params_type) +
" and state type "s + c10::toString(state_type) +
" is not supported on current hardware"s);
}

it->second(opt,
params.data_ptr(),
grads.data_ptr(),
exp_avg.data_ptr(),
exp_avg_sq.data_ptr(),
param_size);
}

int ds_adam_step(int optimizer_id,
@@ -232,75 +226,13 @@ int ds_adam_step(int optimizer_id,
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();

// assert(params.options().dtype() == grads.options().dtype());

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);

opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
nullptr,
(params.options().dtype() == at::kHalf));
invoke(opt, params_c, grads_c, exp_avg_c, exp_avg_sq_c, params_c.numel());

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
}

int ds_adam_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& device_params)
{
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto device_params_c = device_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
ds_half_precision_t* device_params_ptr = (ds_half_precision_t*)device_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
device_params_ptr,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
#else
assert(false);
#endif
return 0;
}
@@ -1,44 +0,0 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "custom_cuda_layers.h"

__global__ void param_update_kernel(const float* input, __half* output, int size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;

if (id < size) { output[id] = (__half)input[id]; }
}

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 1024;

dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);

param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}

__global__ void param_update_kernel_half(const float* input, __half* output, int size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
__half2* output_cast = reinterpret_cast<__half2*>(output);
if (id < size) {
float input_f = input[id];
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
output_cast[id] = *input_h;
}
}

void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 1024;
size /= 2;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);

param_update_kernel_half<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
@@ -9,84 +9,35 @@
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c

#include <stdio.h>
#include <torch/extension.h>
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;
#endif

#define STEP(SPAN) \
void Step_##SPAN(float* _params, \
float* grads, \
float* _exp_avg_sq, \
size_t _param_size, \
ds_half_precision_t* dev_param = nullptr, \
bool half_precision = false);
#define STEP(SPAN) \
template <typename ds_params_percision_t, typename ds_state_precision_t> \
void Step_##SPAN(ds_params_percision_t* _params, \
ds_params_percision_t* grads, \
ds_state_precision_t* _exp_avg_sq, \
size_t _param_size);

class Adagrad_Optimizer {
public:
Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0)
: _alpha(alpha), _eps(eps), _weight_decay(weight_decay)
{
#if defined(__ENABLE_CUDA__)
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_buf_index = false;
#endif
}
~Adagrad_Optimizer()
{
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
~Adagrad_Optimizer() {}
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
void Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
float* _exp_avg_sq,
size_t param_size,
ds_half_precision_t* dev_param = nullptr,
bool half_precision = false);
ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t param_size);
#endif
STEP(1)
STEP(4)
STEP(8)
#if defined(__ENABLE_CUDA__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step)
{
_step++;
@@ -107,29 +58,22 @@ private:
float _betta1_t;
float _betta2_t;
size_t _step;

#if defined(__ENABLE_CUDA__)
bool _buf_index;
float* _doubled_buffer[2];
cudaStream_t _streams[2];
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
#if !defined(__AVX512__)
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
return;
}
#endif
size_t new_rounded_size = 0;
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
@@ -145,24 +89,19 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
size_t copy_size = TILE;
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
AVX_Data grad_4[span];
simd_load<span>(grad_4, grads + i, half_precision);
simd_load<span>(grad_4, grads + i);

AVX_Data momentum_4[span];
simd_load<span>(momentum_4, grads + i, false);
simd_load<span>(momentum_4, grads + i);

AVX_Data variance_4[span];
simd_load<span>(variance_4, _exp_avg_sq + i, false);
simd_load<span>(variance_4, _exp_avg_sq + i);

AVX_Data param_4[span];
simd_load<span>(param_4, _params + i, half_precision);
simd_load<span>(param_4, _params + i);

if (_weight_decay > 0) { simd_fma<span>(grad_4, param_4, weight_decay4, grad_4); }

@@ -172,38 +111,9 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
simd_div<span>(grad_4, momentum_4, grad_4);
simd_fma<span>(param_4, grad_4, step_size_4, param_4);

simd_store<span>(_params + i, param_4, half_precision);
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
#endif
simd_store<span>(_exp_avg_sq + i, variance_4, false);
simd_store<span>(_params + i, param_4);
simd_store<span>(_exp_avg_sq + i, variance_4);
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
if (half_precision)
launch_param_update_half(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
else
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
}
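For reference, the reworked STEP macro above now declares templated member functions; as an illustration (the expansion itself is not part of the diff), STEP(1) in cpu_adagrad.h expands to:

template <typename ds_params_percision_t, typename ds_state_precision_t>
void Step_1(ds_params_percision_t* _params,
            ds_params_percision_t* grads,
            ds_state_precision_t* _exp_avg_sq,
            size_t _param_size);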
@@ -13,29 +13,13 @@
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif

#define STEP(SPAN) \
void Step_##SPAN(float* _params, \
float* grads, \
float* _exp_avg, \
float* _exp_avg_sq, \
size_t _param_size, \
ds_half_precision_t* dev_param = nullptr, \
bool half_precision = false);
#define STEP(SPAN) \
template <typename ds_params_percision_t, typename ds_state_precision_t> \
void Step_##SPAN(ds_params_percision_t* _params, \
ds_params_percision_t* grads, \
ds_state_precision_t* _exp_avg, \
ds_state_precision_t* _exp_avg_sq, \
size_t _param_size);

class Adam_Optimizer {
public:
@@ -55,56 +39,21 @@ public:
_step(0),
_adamw_mode(adamw_mode)
{
#if defined(__ENABLE_CUDA__)
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_buf_index = false;
#endif
}
~Adam_Optimizer()
{
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
~Adam_Optimizer() {}

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
void Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
ds_half_precision_t* dev_param = nullptr,
bool half_precision = false);
ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
ds_state_precision_t* _exp_avg_sq,
size_t param_size);
#endif
STEP(1)
STEP(4)
STEP(8)
#if defined(__ENABLE_CUDA__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif
inline void IncrementStep(size_t step, float beta1, float beta2)
{
if (beta1 != _betta1 || beta2 != _betta2) {
@@ -154,32 +103,24 @@ private:
float _bias_correction2;

bool _adamw_mode;

#if defined(__ENABLE_CUDA__)
float* _doubled_buffer[2];
cudaStream_t _streams[2];
bool _buf_index;
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
ds_state_precision_t* _exp_avg_sq,
size_t _param_size)
{
#if !defined(__AVX512__)
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
return;
}
#endif
size_t new_rounded_size = 0;
int rshft = half_precision ? 1 : 0;

AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
@@ -212,24 +153,19 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
size_t copy_size = TILE;
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
AVX_Data grad_4[span];
simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
simd_load<span>(grad_4, grads + i);

AVX_Data momentum_4[span];
simd_load<span>(momentum_4, _exp_avg + i, false);
simd_load<span>(momentum_4, _exp_avg + i);

AVX_Data variance_4[span];
simd_load<span>(variance_4, _exp_avg_sq + i, false);
simd_load<span>(variance_4, _exp_avg_sq + i);

AVX_Data param_4[span];
simd_load<span>(param_4, _params + (i >> rshft), half_precision);
simd_load<span>(param_4, _params + i);

if (_weight_decay > 0 && !_adamw_mode) {
simd_fma<span>(grad_4, param_4, weight_decay4, grad_4);
@@ -250,39 +186,10 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,

simd_fma<span>(param_4, grad_4, step_size_4, param_4);

simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
#endif
simd_store<span>(_exp_avg + i, momentum_4, false);
simd_store<span>(_exp_avg_sq + i, variance_4, false);
simd_store<span>(_params + i, param_4);
simd_store<span>(_exp_avg + i, momentum_4);
simd_store<span>(_exp_avg_sq + i, variance_4);
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
if (half_precision)
launch_param_update_half(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
else
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
}
@@ -310,18 +217,4 @@ int ds_adam_step(int optimizer_id,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq);

int ds_adam_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params);

int destroy_adam_optimizer(int optimizer_id);
@@ -13,28 +13,12 @@
#include <cassert>
#include "simd.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
typedef __half ds_half_precision_t;
#elif defined(__ENABLE_CANN__)
#include "acl/acl.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
typedef c10::Half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif

#define STEP(SPAN) \
void Step_##SPAN(float* _params, \
float* grads, \
float* _exp_avg, \
size_t _param_size, \
ds_half_precision_t* dev_param = nullptr, \
bool half_precision = false);
#define STEP(SPAN) \
template <typename ds_params_percision_t, typename ds_state_precision_t> \
void Step_##SPAN(ds_params_percision_t* _params, \
ds_params_percision_t* grads, \
ds_state_precision_t* _exp_avg, \
size_t _param_size);

class Lion_Optimizer {
public:
@@ -44,55 +28,21 @@ public:
float weight_decay = 0)
: _alpha(alpha), _betta1(betta1), _betta2(betta2), _weight_decay(weight_decay), _step(0)
{
#if defined(__ENABLE_CUDA__)
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_streams[0] = TrainingContext::Instance().GetCurrentStream();
_streams[1] = TrainingContext::Instance().GetNewStream();
_buf_index = false;
#elif defined(__ENABLE_CANN__)
aclrtMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
aclrtMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));

_buf_index = false;
#endif
}
~Lion_Optimizer()
{
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#elif defined(__ENABLE_CANN__)
aclrtFreeHost(_doubled_buffer[0]);
aclrtFreeHost(_doubled_buffer[1]);
#endif
}
~Lion_Optimizer() {}

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
void Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
float* _exp_avg,
size_t param_size,
ds_half_precision_t* dev_param = nullptr,
bool half_precision = false);
ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
size_t param_size);
#endif
STEP(1)
STEP(4)
STEP(8)
#if defined(__ENABLE_CUDA__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#elif defined(__ENABLE_CANN__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) aclrtSynchronizeStream(_streams[i].stream());
}
#endif

inline void IncrementStep(size_t step, float beta1, float beta2)
{
_step++;
@@ -114,31 +64,23 @@ private:
float _betta2;
float _weight_decay;
size_t _step;

#if defined(__ENABLE_CUDA__)
float* _doubled_buffer[2];
cudaStream_t _streams[2];
bool _buf_index;
#elif defined(__ENABLE_CANN__)
float* _doubled_buffer[2];
c10_npu::NPUStream _streams[2] = {c10_npu::getCurrentNPUStream(),
c10_npu::getNPUStreamFromPool()};
bool _buf_index;
#endif
};

#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
void Lion_Optimizer::Step_AVX(size_t* rounded_size,
float* _params,
float* grads,
float* _exp_avg,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
size_t _param_size)
{
#if !defined(__AVX512__)
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
return;
}
#endif
size_t new_rounded_size = 0;
int rshft = half_precision ? 1 : 0;

constexpr float neg1 = -1.0f;
AVX_Data neg1_4;
@@ -169,21 +111,17 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
size_t copy_size = TILE;
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif

#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
AVX_Data grad_4[span];
simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
simd_load<span>(grad_4, grads + i);

AVX_Data momentum_4[span];
simd_load<span>(momentum_4, _exp_avg + i, false);
simd_load<span>(momentum_4, _exp_avg + i);

AVX_Data param_4[span];
simd_load<span>(param_4, _params + (i >> rshft), half_precision);
simd_load<span>(param_4, _params + i);

AVX_Data tmp_4[span];

@@ -201,38 +139,9 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size,
simd_mul<span>(momentum_4, momentum_4, betta2_4);
simd_fma<span>(momentum_4, grad_4, betta2_minus1_4, momentum_4);

simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
#endif
simd_store<span>(_exp_avg + i, momentum_4, false);
simd_store<span>(_params + i, param_4);
simd_store<span>(_exp_avg + i, momentum_4);
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
if (half_precision)
launch_param_update_half(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
else
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
if (half_precision) memcpy_size /= 2;
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
}
@@ -255,15 +164,4 @@ int ds_lion_step(int optimizer_id,
torch::Tensor& grads,
torch::Tensor& exp_avg);

int ds_lion_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float weight_decay,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& gpu_params);

int destroy_lion_optimizer(int optimizer_id);
@@ -272,9 +272,6 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
int cols,
cudaStream_t stream);

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);

void launch_token_sort(int32_t* indices,
int layers,
int batch_size,
@@ -13,6 +13,19 @@
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__)

template <typename T>
inline T readAs(const void* src)
{
T res;
std::memcpy(&res, src, sizeof(T));
return res;
}
template <typename T>
inline void writeAs(void* dst, const T& val)
{
std::memcpy(dst, &val, sizeof(T));
}

#define ROUND_DOWN(size, step) ((size) & ~((step)-1))

#if defined(__AVX512__)
@@ -30,11 +43,52 @@
#define SIMD_XOR(x, y) _mm512_xor_ps(x, y)
#define SIMD_WIDTH 16

#define SIMD_LOAD2(x, h) \
((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm512_storeu_ps(x, d))
static __m512 load_16_bf16_as_f32(const void* data)
{
__m256i a = readAs<__m256i>(data); // use memcpy to avoid aliasing
__m512i b = _mm512_cvtepu16_epi32(a); // convert 8 u16 to 8 u32
__m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by
// 16 bits (representing bf16->f32)
return readAs<__m512>(&c); // use memcpy to avoid aliasing
}

static void store_16_f32_as_bf16_nearest(__m512 v, void* data)
{
__m512i u32 = readAs<__m512i>(&v);

// flow assuming non-nan:

// uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
__m512i b = _mm512_srli_epi32(u32, 16);
__m512i lsb_mask = _mm512_set1_epi32(0x00000001);
__m512i c = _mm512_and_si512(b, lsb_mask);
__m512i bias_constant = _mm512_set1_epi32(0x00007fff);
__m512i rounding_bias = _mm512_add_epi32(c, bias_constant);

// uint16_t res = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
__m512i d = _mm512_add_epi32(u32, rounding_bias);
__m512i e = _mm512_srli_epi32(d, 16);
__m256i non_nan_res = _mm512_cvtusepi32_epi16(e);

// handle nan (exp is all 1s and mantissa != 0)
// if ((x & 0x7fffffffU) > 0x7f800000U)
__m512i mask_out_sign = _mm512_set1_epi32(0x7fffffff);
__m512i non_sign_bits = _mm512_and_si512(u32, mask_out_sign);
__m512i nan_threshold = _mm512_set1_epi32(0x7f800000);
__mmask16 nan_mask = _mm512_cmp_epi32_mask(non_sign_bits, nan_threshold, _MM_CMPINT_GT);

// mix in results with nans as needed
__m256i nans = _mm256_set1_epi16(0x7fc0);
__m256i res = _mm256_mask_mov_epi16(non_nan_res, nan_mask, nans);

writeAs(data, res);
}
#define SIMD_LOAD_BF16(x) load_16_bf16_as_f32(x)
#define SIMD_STORE_BF16(x, d) store_16_f32_as_bf16_nearest(d, x)

#define SIMD_LOAD_FP16(x) _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x)))
#define SIMD_STORE_FP16(x, d) \
_mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

#define INTV __m256i
#elif defined(__AVX256__)
@@ -52,11 +106,11 @@
#define SIMD_XOR(x, y) _mm256_xor_ps(x, y)
#define SIMD_WIDTH 8

#define SIMD_LOAD2(x, h) \
((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x))) : _mm256_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm256_storeu_ps(x, d))
#define SIMD_LOAD_BF16(x) static_assert(false && "AVX256 does not support BFloat16")
#define SIMD_STORE_BF16(x, d) static_assert(false && "AVX256 does not support BFloat16")
#define SIMD_LOAD_FP16(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x))
#define SIMD_STORE_FP16(x, d) \
_mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

#define INTV __m128i
#endif
@@ -70,20 +124,66 @@ union AVX_Data {
// float data_f[16];
};

template <int span>
inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
template <int span, typename T>
inline typename std::enable_if_t<std::is_same_v<T, c10::Half>, void> simd_store(T* dst,
AVX_Data* src)
{
size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
size_t width = SIMD_WIDTH;
#pragma unroll
for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); }
for (size_t i = 0; i < span; ++i) { SIMD_STORE_FP16((float*)(dst + width * i), src[i].data); }
}
template <int span>
inline void simd_load(AVX_Data* dst, float* src, bool half_precision)

template <int span, typename T>
inline typename std::enable_if_t<std::is_same_v<T, c10::BFloat16>, void> simd_store(T* dst,
AVX_Data* src)
{
size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
#ifdef __AVX512__
size_t width = SIMD_WIDTH;
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); }
for (size_t i = 0; i < span; ++i) { SIMD_STORE_BF16((float*)(dst + width * i), src[i].data); }
#else
throw std::runtime_error("AVX512 required for BFloat16");
#endif
}

template <int span, typename T>
inline typename std::enable_if_t<std::is_same_v<T, float>, void> simd_store(T* dst, AVX_Data* src)
{
size_t width = SIMD_WIDTH;
#pragma unroll
for (size_t i = 0; i < span; ++i) { SIMD_STORE(dst + width * i, src[i].data); }
}

template <int span, typename T>
inline typename std::enable_if_t<std::is_same_v<T, c10::Half>, void> simd_load(AVX_Data* dst,
T* src)
{
size_t width = SIMD_WIDTH;
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_FP16((float*)(src + width * i)); }
}

template <int span, typename T>
inline typename std::enable_if_t<std::is_same_v<T, c10::BFloat16>, void> simd_load(AVX_Data* dst,
T* src)
{
#ifdef __AVX512__
size_t width = SIMD_WIDTH;
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD_BF16((float*)(src + width * i)); }
#else
throw std::runtime_error("AVX512 required for BFloat16");
#endif
}

template <int span, typename T>
inline typename std::enable_if_t<std::is_same_v<T, float>, void> simd_load(AVX_Data* dst, T* src)
{
size_t width = SIMD_WIDTH;
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD(src + width * i); }
}

template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a)
{
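As a scalar reference for the vectorized store in the simd.h hunk above (illustration only, not part of the diff), the same round-to-nearest-even float-to-bf16 conversion with NaN handling can be written as:

// Scalar sketch of the conversion that store_16_f32_as_bf16_nearest vectorizes.
#include <cstdint>
#include <cstring>

static uint16_t f32_to_bf16_nearest(float f)
{
    uint32_t u32;
    std::memcpy(&u32, &f, sizeof(u32));                     // read the bits without aliasing UB
    if ((u32 & 0x7fffffffu) > 0x7f800000u) return 0x7fc0;   // NaN -> quiet-NaN payload
    uint32_t rounding_bias = ((u32 >> 16) & 1u) + 0x7fffu;  // round to nearest, ties to even
    return static_cast<uint16_t>((u32 + rounding_bias) >> 16);
}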
@@ -8,9 +8,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("lion_update", &ds_lion_step, "DeepSpeed CPU Lion update (C++)");
m.def("lion_update_copy",
&ds_lion_step_plus_copy,
"DeepSpeed CPU Lion update and param copy (C++)");
m.def("create_lion", &create_lion_optimizer, "DeepSpeed CPU Lion (C++)");
m.def("destroy_lion", &destroy_lion_optimizer, "DeepSpeed CPU Lion destroy (C++)");
}
@ -6,34 +6,28 @@
#include <torch/extension.h>
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cpu_lion.h"

#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

using namespace std::string_literals;
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface

void Lion_Optimizer::Step_1(float* _params,
float* grads,
float* _exp_avg,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Lion_Optimizer::Step_1(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _param_size);
#endif
if (_param_size > rounded_size) {
float betta1_minus1 = 1 - _betta1;
@ -41,26 +35,15 @@ void Lion_Optimizer::Step_1(float* _params,

float alpha = _alpha;
float after_decay = 1 - alpha * _weight_decay;
ds_half_precision_t* grads_cast_h;
ds_half_precision_t* params_cast_h;
if (half_precision) {
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
}

for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#elif defined(__ENABLE_CANN__)
if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
float param = half_precision ? (float)params_cast_h[k] : _params[k];
float grad = (float)grads[k];
float param = (float)_params[k];
float momentum = _exp_avg[k];
float tmp = momentum * _betta1;
tmp = grad * betta1_minus1 + tmp;
@ -74,56 +57,28 @@ void Lion_Optimizer::Step_1(float* _params,
}
momentum = momentum * _betta2;
momentum = grad * betta2_minus1 + momentum;
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
#endif
if (half_precision)
params_cast_h[k] = (ds_half_precision_t)param;
else
_params[k] = param;
_params[k] = param;
_exp_avg[k] = momentum;
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

_buf_index = !_buf_index;
}
#elif defined(__ENABLE_CANN__)
if (dev_params) {
size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]);
aclrtMemcpy(dev_params + t,
memcpy_size,
_doubled_buffer[_buf_index],
memcpy_size,
aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE);

_buf_index = !_buf_index;
}
#endif
}
}
}

void Lion_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Lion_Optimizer::Step_4(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _param_size);
#endif
if (_param_size > rounded_size)
Step_1((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

int create_lion_optimizer(int optimizer_id,
@ -162,24 +117,76 @@ int create_lion_optimizer(int optimizer_id,
return 0;
}

void Lion_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
size_t _param_size,
ds_half_precision_t* dev_params,
bool half_precision)
template <typename ds_params_percision_t, typename ds_state_precision_t>
void Lion_Optimizer::Step_8(ds_params_percision_t* _params,
ds_params_percision_t* grads,
ds_state_precision_t* _exp_avg,
size_t _param_size)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size, dev_params, half_precision);
Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _param_size);
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params),
half_precision);
(_param_size - rounded_size));
}

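Step_8 / Step_4 / Step_1 keep the same cascade as before (process as much as possible with the widest AVX span, hand the remainder to the next-narrower step, finish scalar), but are now templated over the parameter and state dtypes instead of taking a `half_precision` flag and a device copy buffer. A rough, self-contained sketch of that cascade structure, with placeholder types and a trivial update standing in for Step_AVX and the Lion math:

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: param_t / state_t stand in for ds_params_percision_t /
// ds_state_precision_t, and process_span stands in for Step_AVX<span>.
template <int span, typename param_t, typename state_t>
size_t process_span(param_t* p, const param_t* g, state_t* m, size_t n)
{
    size_t rounded = (n / span) * span;  // only full "vectors" of `span` elements
    for (size_t i = 0; i < rounded; ++i) { m[i] = (state_t)g[i]; p[i] += (param_t)m[i]; }
    return rounded;
}

template <typename param_t, typename state_t>
void step_1(param_t* p, const param_t* g, state_t* m, size_t n)
{
    for (size_t i = 0; i < n; ++i) { m[i] = (state_t)g[i]; p[i] += (param_t)m[i]; }  // scalar tail
}

template <typename param_t, typename state_t>
void step_4(param_t* p, const param_t* g, state_t* m, size_t n)
{
    size_t done = process_span<4>(p, g, m, n);
    if (n > done) step_1(p + done, g + done, m + done, n - done);  // remainder
}

template <typename param_t, typename state_t>
void step_8(param_t* p, const param_t* g, state_t* m, size_t n)
{
    size_t done = process_span<8>(p, g, m, n);
    if (n > done) step_4(p + done, g + done, m + done, n - done);  // remainder
}

int main()
{
    std::vector<float> p(13, 1.f), g(13, 0.5f), m(13, 0.f);
    step_8(p.data(), g.data(), m.data(), p.size());
    std::cout << p[12] << "\n";  // last elements were handled by step_4 -> step_1
}
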
template <typename ds_params_percision_t, typename ds_state_precision_t>
void step_invoker(std::shared_ptr<Lion_Optimizer> opt,
void* _params,
void* grads,
void* _exp_avg,
size_t _param_size)
{
opt->Step_8((ds_params_percision_t*)(_params),
(ds_params_percision_t*)(grads),
(ds_state_precision_t*)(_exp_avg),
_param_size);
}

std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
std::function<void(std::shared_ptr<Lion_Optimizer>, void*, void*, void*, size_t)>>
invokers;

// Fill map with template functions for each type
template <class ds_params_percision_t, class ds_state_precision_t>
void create_invoker()
{
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
c10::CppTypeToScalarType<ds_state_precision_t>())] =
step_invoker<ds_params_percision_t, ds_state_precision_t>;
}
struct InvokerInitializer {
InvokerInitializer()
{
create_invoker<c10::Half, float>();
create_invoker<c10::Half, c10::Half>();
create_invoker<c10::BFloat16, float>();
create_invoker<c10::BFloat16, c10::BFloat16>();
create_invoker<float, float>();
}
} _invoker_initializer;

void invoke(std::shared_ptr<Lion_Optimizer> opt,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
size_t param_size)
{
c10::ScalarType params_type = at::typeMetaToScalarType(params.options().dtype());
c10::ScalarType state_type = at::typeMetaToScalarType(exp_avg.options().dtype());

auto it = invokers.find(std::tuple(params_type, state_type));
if (it == invokers.end()) {
throw std::runtime_error("Lion optimizer with param type "s + c10::toString(params_type) +
" and state type "s + c10::toString(state_type) +
" is not supported on current hardware"s);
}

it->second(opt, params.data_ptr(), grads.data_ptr(), exp_avg.data_ptr(), param_size);
}

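The map above is how the runtime dtype pair (parameters, optimizer state) is routed to the matching template instantiation: `create_invoker<P, S>()` registers a type-erased `std::function` keyed by the two `c10::ScalarType`s, and `invoke()` looks the pair up and throws if that combination was not compiled in. A stripped-down, standalone sketch of the same dispatch idea follows; `DType`, `step_invoker`, and the explicit enum arguments are placeholders for the c10 types and `c10::CppTypeToScalarType` used in the real code.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <map>
#include <stdexcept>
#include <tuple>
#include <typeinfo>

enum class DType { Float, Half, BFloat16 };  // stand-in for c10::ScalarType

// Type-erased entry point: the real code casts the void pointers and calls opt->Step_8.
template <typename param_t, typename state_t>
void step_invoker(void* params, void* state, size_t n)
{
    std::printf("step with %s params / %s state over %zu elems\n",
                typeid(param_t).name(), typeid(state_t).name(), n);
}

std::map<std::tuple<DType, DType>, std::function<void(void*, void*, size_t)>> invokers;

template <typename param_t, typename state_t>
void create_invoker(DType p, DType s)
{
    invokers[std::tuple(p, s)] = step_invoker<param_t, state_t>;  // register one instantiation
}

void invoke(DType p, DType s, void* params, void* state, size_t n)
{
    auto it = invokers.find(std::tuple(p, s));
    if (it == invokers.end()) throw std::runtime_error("dtype combination not registered");
    it->second(params, state, n);  // dispatch to the registered template instantiation
}

int main()
{
    // Registration mirrors what InvokerInitializer's constructor does once at load time.
    create_invoker<float, float>(DType::Float, DType::Float);
    create_invoker<unsigned short, float>(DType::Half, DType::Float);  // fp16 bits stand-in

    float params[4] = {0}, state[4] = {0};
    invoke(DType::Float, DType::Float, params, state, 4);
}
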
int ds_lion_step(int optimizer_id,
@ -196,67 +203,13 @@ int ds_lion_step(int optimizer_id,
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();

// assert(params.options().dtype() == grads.options().dtype());

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();

std::shared_ptr<Lion_Optimizer> opt =
std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, weight_decay);

opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
params_c.numel(),
nullptr,
(params.options().dtype() == at::kHalf));
invoke(opt, params_c, grads_c, exp_avg_c, params_c.numel());

#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
opt->SynchronizeStreams();
#endif
return 0;
}

int ds_lion_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float weight_decay,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__) or defined(__ENABLE_CANN__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto grads_c = grads.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();

std::shared_ptr<Lion_Optimizer> opt =
std::static_pointer_cast<Lion_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, weight_decay);
opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
params_c.numel(),
gpu_params_ptr,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
#else
assert(false);
#endif
return 0;
}

@ -34,7 +34,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
group.setdefault('amsgrad', False)

@torch.no_grad()
def step(self, closure=None, fp16_param_groups=None):
def step(self, closure=None):
"""Update the model parameters.

.. note::
@ -46,8 +46,6 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
Args:
closure (callable, optional): closure to compute the loss.
Defaults to ``None``.
fp16_param_groups: FP16 GPU parameters to update. Performing the
copy here reduces communication time. Defaults to ``None``.

Returns:
loss: if ``closure`` is provided. Otherwise ``None``.
@ -94,16 +92,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
sparse_exp_avg_sq.values())
p[sparse_param.indices()] = sparse_param.values()
state['exp_avg_sq'][sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values()
if fp16_param_groups is not None:
fp16_param_groups[group_id][param_id][sparse_param.indices()] = sparse_param.values()
else:
if fp16_param_groups is not None:
self.ds_opt_adagrad.adagrad_update_copy(self.opt_id, state['step'], group['lr'], group['eps'],
group['weight_decay'], p.data, p.grad.data,
state['exp_avg_sq'],
fp16_param_groups[group_id][param_id].data)
else:
self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'],
group['weight_decay'], p.data, p.grad.data,
state['exp_avg_sq'])
self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'],
group['weight_decay'], p.data, p.grad.data, state['exp_avg_sq'])
return loss
@ -107,7 +107,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
group.setdefault('amsgrad', False)

@torch.no_grad()
def step(self, closure=None, fp16_param_groups=None):
def step(self, closure=None):
"""Update the model parameters.

.. note::
@ -119,8 +119,6 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
Args:
closure (callable, optional): closure to compute the loss.
Defaults to ``None``.
fp16_param_groups: FP16 GPU parameters to update. Performing the
copy here reduces communication time. Defaults to ``None``.

Returns:
loss: if ``closure`` is provided. Otherwise ``None``.
@ -134,13 +132,6 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
# intended device for step
device = torch.device('cpu')

# converting the fp16 params to a group of parameter
if type(fp16_param_groups) is list:
if type(fp16_param_groups[0]) is not list:
fp16_param_groups = [fp16_param_groups]
elif fp16_param_groups is not None:
fp16_param_groups = [[fp16_param_groups]]

for group_id, group in enumerate(self.param_groups):
for param_id, p in enumerate(group['params']):

@ -169,13 +160,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
state['step'] += 1
beta1, beta2 = group['betas']

if fp16_param_groups is not None:
self.ds_opt_adam.adam_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2,
group['eps'], group['weight_decay'], group['bias_correction'],
p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'],
fp16_param_groups[group_id][param_id].data)
else:
self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'])
self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'])
return loss
@ -69,7 +69,7 @@ class DeepSpeedCPULion(torch.optim.Optimizer):
group.setdefault('amsgrad', False)

@torch.no_grad()
def step(self, closure=None, fp16_param_groups=None):
def step(self, closure=None):
"""Update the model parameters.

.. note::
@ -81,8 +81,6 @@ class DeepSpeedCPULion(torch.optim.Optimizer):
Args:
closure (callable, optional): closure to compute the loss.
Defaults to ``None``.
fp16_param_groups: FP16 GPU parameters to update. Performing the
copy here reduces communication time. Defaults to ``None``.

Returns:
loss: if ``closure`` is provided. Otherwise ``None``.
@ -96,13 +94,6 @@ class DeepSpeedCPULion(torch.optim.Optimizer):
# intended device for step
device = torch.device('cpu')

# converting the fp16 params to a group of parameter
if type(fp16_param_groups) is list:
if type(fp16_param_groups[0]) is not list:
fp16_param_groups = [fp16_param_groups]
elif fp16_param_groups is not None:
fp16_param_groups = [[fp16_param_groups]]

for group_id, group in enumerate(self.param_groups):
for param_id, p in enumerate(group['params']):

@ -131,11 +122,6 @@ class DeepSpeedCPULion(torch.optim.Optimizer):
state['step'] += 1
beta1, beta2 = group['betas']

if fp16_param_groups is not None:
self.ds_opt_lion.lion_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2,
group['weight_decay'], p.data, p.grad.data, state['exp_avg'],
fp16_param_groups[group_id][param_id].data)
else:
self.ds_opt_lion.lion_update(self.opt_id, state['step'], group['lr'], beta1, beta2,
group['weight_decay'], p.data, p.grad.data, state['exp_avg'])
self.ds_opt_lion.lion_update(self.opt_id, state['step'], group['lr'], beta1, beta2,
group['weight_decay'], p.data, p.grad.data, state['exp_avg'])
return loss
@ -545,6 +545,7 @@ class OpBuilder(ABC):
nvcc_args.append("-DBF16_AVAILABLE")
nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__")
nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__")
nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")

if self.is_rocm_pytorch():
cxx_args.append("-D__HIP_PLATFORM_AMD__=1")

@ -30,7 +30,11 @@ class CPUOpBuilder(OpBuilder):
return cpp_ext

def cxx_args(self):
return ['-O3', '-g', '-Wno-reorder']
args = ['-O3', '-g', '-Wno-reorder']
CPU_ARCH = self.cpu_arch()
SIMD_WIDTH = self.simd_width()
args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH]
return args

def libraries_args(self):
return []
@ -3,7 +3,6 @@

# DeepSpeed Team

import os
from .builder import TorchCPUOpBuilder

@ -18,26 +17,11 @@ class CPUAdagradBuilder(TorchCPUOpBuilder):
return f'deepspeed.ops.adagrad.{self.NAME}_op'

def sources(self):
if self.build_for_cpu:
return ['csrc/adagrad/cpu_adagrad.cpp']

return ['csrc/adagrad/cpu_adagrad.cpp', 'csrc/common/custom_cuda_kernel.cu']
return ['csrc/adagrad/cpu_adagrad.cpp']

def libraries_args(self):
args = super().libraries_args()
if self.build_for_cpu:
return args

if not self.is_rocm_pytorch():
args += ['curand']
return args

def include_paths(self):
import torch
if self.build_for_cpu:
CUDA_INCLUDE = []
elif not self.is_rocm_pytorch():
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
else:
CUDA_INCLUDE = []
return ['csrc/includes'] + CUDA_INCLUDE
return ['csrc/includes']

@ -3,7 +3,6 @@

# DeepSpeed Team

import os
from .builder import TorchCPUOpBuilder

@ -18,27 +17,11 @@ class CPUAdamBuilder(TorchCPUOpBuilder):
return f'deepspeed.ops.adam.{self.NAME}_op'

def sources(self):
if self.build_for_cpu:
return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']

return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp', 'csrc/common/custom_cuda_kernel.cu']
return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']

def libraries_args(self):
args = super().libraries_args()
if self.build_for_cpu:
return args

if not self.is_rocm_pytorch():
args += ['curand']

return args

def include_paths(self):
import torch
if self.build_for_cpu:
CUDA_INCLUDE = []
elif not self.is_rocm_pytorch():
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
else:
CUDA_INCLUDE = []
return ['csrc/includes'] + CUDA_INCLUDE
return ['csrc/includes']

@ -3,7 +3,6 @@

# DeepSpeed Team

import os
from .builder import TorchCPUOpBuilder

@ -18,31 +17,11 @@ class CPULionBuilder(TorchCPUOpBuilder):
return f'deepspeed.ops.lion.{self.NAME}_op'

def sources(self):
if self.build_for_cpu:
return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']

return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp', 'csrc/common/custom_cuda_kernel.cu']
return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']

def libraries_args(self):
args = super().libraries_args()
if self.build_for_cpu:
return args

if not self.is_rocm_pytorch():
args += ['curand']

return args

def include_paths(self):
import torch
if self.build_for_cpu:
CUDA_INCLUDE = []
elif not self.is_rocm_pytorch():
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
else:
CUDA_INCLUDE = [
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
]
return ['csrc/includes'] + CUDA_INCLUDE
return ['csrc/includes']
@ -31,7 +31,11 @@ class CPUOpBuilder(OpBuilder):
return cpp_ext

def cxx_args(self):
return ['-O3', '-g', '-Wno-reorder']
args = ['-O3', '-g', '-Wno-reorder']
CPU_ARCH = self.cpu_arch()
SIMD_WIDTH = self.simd_width()
args += [CPU_ARCH, '-fopenmp', SIMD_WIDTH]
return args

def libraries_args(self):
return []

@ -20,11 +20,6 @@ class CPUAdamBuilder(CPUOpBuilder):
def sources(self):
return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']

def cxx_args(self):
args = super().cxx_args()
args += ['-DENABLE_BFLOAT16']
return args

def libraries_args(self):
args = super().libraries_args()
return args

@ -6,12 +6,10 @@
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam
import time
from deepspeed.accelerator import get_accelerator

device = 'cpu'
model_size = 1 * 1024**3
param = torch.nn.Parameter(torch.ones(model_size, device=device))
param_fp16 = torch.nn.Parameter(torch.ones(model_size, dtype=torch.half, device=get_accelerator().device_name(0)))

optimizer = DeepSpeedCPUAdam([param])
#torch.set_num_threads(128)
@ -19,7 +17,7 @@ param.grad = torch.ones(model_size, device=device)
avg = 0
for i in range(100):
start = time.time()
optimizer.step(fp16_param_groups=[param_fp16])
optimizer.step()
stop = time.time()
avg += (stop - start)
param.grad = torch.ones(model_size, device=device) * 2
@ -82,8 +82,12 @@ def set_accelerator_visible():
if match:
num_accelerators += 1
elif get_accelerator().device_name() == 'hpu':
hl_smi = subprocess.check_output(['hl-smi', "-L"])
num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode())
try:
hl_smi = subprocess.check_output(['hl-smi', "-L"])
num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode())
except FileNotFoundError:
sim_list = subprocess.check_output(['ls', '-1', '/dev/accel'])
num_accelerators = re.findall(r"accel(\d+)", sim_list.decode())
num_accelerators = sorted(num_accelerators, key=int)
os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators)
elif get_accelerator().device_name() == 'npu':

@ -18,8 +18,8 @@ if not deepspeed.ops.__compatible_ops__[CPUAdagradBuilder.NAME]:

def check_equal(first, second, atol=1e-2, verbose=False):
x = first.detach().numpy()
y = second.detach().numpy()
x = first.detach().float().numpy()
y = second.detach().float().numpy()
if verbose:
print("x = {}".format(x.flatten()))
print("y = {}".format(y.flatten()))

@ -21,8 +21,8 @@ pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower()

def check_equal(first, second, atol=1e-2, verbose=False):
x = first.detach().numpy()
y = second.detach().numpy()
x = first.detach().float().numpy()
y = second.detach().float().numpy()
print("ATOL", atol)
if verbose:
print("x = {}".format(x.flatten()))
@ -43,7 +43,7 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2):
check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True)

@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"])
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"])
@pytest.mark.parametrize('model_size',
[
(64),
@ -65,6 +65,9 @@ class TestCPUAdam(DistributedTest):
@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME],
reason="FusedAdam is not compatible")
def test_fused_adam_equal(self, dtype, model_size):
if dtype not in get_accelerator().supported_dtypes():
pytest.skip(f"dtype {dtype} not supported in current accelerator")

if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")

@ -91,6 +94,8 @@ class TestCPUAdam(DistributedTest):

def test_torch_adamw_equal(self, dtype, model_size):
if get_accelerator().is_available():
if dtype == torch.half:
pytest.skip("torch.optim.AdamW with half precision inf/nan output.")
if ("amd" in pytest.cpu_vendor) and (dtype == torch.half):
pytest.skip("cpu-adam with half precision not supported on AMD CPUs")
ref_param_device = get_accelerator().device_name()
@ -99,20 +104,20 @@ class TestCPUAdam(DistributedTest):
pytest.skip("torch.optim.AdamW with half precision only supported in CUDA environments.")
ref_param_device = 'cpu'

from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam

cpu_data = torch.randn(model_size, device='cpu').to(dtype)
cpu_param = torch.nn.Parameter(cpu_data)
ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device))
cpu_data = torch.randn(model_size, device='cpu').to(dtype)
cpu_param = torch.nn.Parameter(cpu_data)
ref_param = torch.nn.Parameter(cpu_data.to(ref_param_device))

cpu_optimizer = DeepSpeedCPUAdam([cpu_param])
ref_optimizer = torch.optim.AdamW([ref_param])
cpu_optimizer = DeepSpeedCPUAdam([cpu_param])
ref_optimizer = torch.optim.AdamW([ref_param])

_compare_optimizers(model_size=model_size,
param1=cpu_param,
optimizer1=cpu_optimizer,
param2=ref_param,
optimizer2=ref_optimizer)
_compare_optimizers(model_size=model_size,
param1=cpu_param,
optimizer1=cpu_optimizer,
param2=ref_param,
optimizer2=ref_optimizer)


class TestCPUAdamGPUError(DistributedTest):

@ -22,8 +22,8 @@ pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower()

def check_equal(first, second, atol=1e-2, verbose=False):
x = first.detach().numpy()
y = second.detach().numpy()
x = first.detach().float().numpy()
y = second.detach().float().numpy()
print("ATOL", atol)
if verbose:
print("x = {}".format(x.flatten()))
@ -32,7 +32,7 @@ def check_equal(first, second, atol=1e-2, verbose=False):
np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol)


@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"])
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"])
@pytest.mark.parametrize('model_size', [8, 16])
class TestHybridAdam(DistributedTest):
world_size = 1

@ -18,8 +18,8 @@ pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower()

def check_equal(first, second, atol=1e-2, verbose=False):
x = first.detach().numpy()
y = second.detach().numpy()
x = first.detach().float().numpy()
y = second.detach().float().numpy()
print("ATOL", atol)
if verbose:
print("x = {}".format(x.flatten()))
@ -40,7 +40,7 @@ def _compare_optimizers(model_size, param1, optimizer1, param2, optimizer2):
check_equal(param1.float().norm(), param2.float().cpu().norm(), atol=tolerance, verbose=True)


@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"])
@pytest.mark.parametrize('dtype', [torch.half, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"])
@pytest.mark.parametrize('model_size',
[
(64),