mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Fix Type Name Inconsistency & Typo in cpu_adam (#6732)
There is a typing error & inconsistency in cpu-adam code, while not affecting functionality, impacts code readability. Specifically, the type name `ds_params_percision_t` contains a typo ('percision'), whereas the related type name `ds_state_precision_t` is spelled correctly. I think it is beneficial to fix this typo&inconsistency to improve code readability, maintainability and further development. I have tested the corrected version of cpu_adam, and it compiles and runs successfully. Compilation Log: <img width="2560" alt="image" src="https://github.com/user-attachments/assets/b7bc307d-9c9d-4ab7-8671-34e565903ca5"> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
@ -17,9 +17,9 @@ static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
|||||||
|
|
||||||
// C++ interface
|
// C++ interface
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adagrad_Optimizer::Step_1(ds_params_percision_t* _params,
|
void Adagrad_Optimizer::Step_1(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
@ -56,9 +56,9 @@ void Adagrad_Optimizer::Step_1(ds_params_percision_t* _params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adagrad_Optimizer::Step_4(ds_params_percision_t* _params,
|
void Adagrad_Optimizer::Step_4(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
@ -104,9 +104,9 @@ int create_adagrad_optimizer(int optimizer_id,
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adagrad_Optimizer::Step_8(ds_params_percision_t* _params,
|
void Adagrad_Optimizer::Step_8(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
@ -121,15 +121,15 @@ void Adagrad_Optimizer::Step_8(ds_params_percision_t* _params,
|
|||||||
(_param_size - rounded_size));
|
(_param_size - rounded_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void step_invoker(std::shared_ptr<Adagrad_Optimizer> opt,
|
void step_invoker(std::shared_ptr<Adagrad_Optimizer> opt,
|
||||||
void* _params,
|
void* _params,
|
||||||
void* grads,
|
void* grads,
|
||||||
void* _exp_avg_sq,
|
void* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
opt->Step_8((ds_params_percision_t*)(_params),
|
opt->Step_8((ds_params_precision_t*)(_params),
|
||||||
(ds_params_percision_t*)(grads),
|
(ds_params_precision_t*)(grads),
|
||||||
(ds_state_precision_t*)(_exp_avg_sq),
|
(ds_state_precision_t*)(_exp_avg_sq),
|
||||||
_param_size);
|
_param_size);
|
||||||
}
|
}
|
||||||
@ -139,12 +139,12 @@ std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
|
|||||||
invokers;
|
invokers;
|
||||||
|
|
||||||
// Fill map with template functions for each type
|
// Fill map with template functions for each type
|
||||||
template <class ds_params_percision_t, class ds_state_precision_t>
|
template <class ds_params_precision_t, class ds_state_precision_t>
|
||||||
void create_invoker()
|
void create_invoker()
|
||||||
{
|
{
|
||||||
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
|
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_precision_t>(),
|
||||||
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
||||||
step_invoker<ds_params_percision_t, ds_state_precision_t>;
|
step_invoker<ds_params_precision_t, ds_state_precision_t>;
|
||||||
}
|
}
|
||||||
struct InvokerInitializer {
|
struct InvokerInitializer {
|
||||||
InvokerInitializer()
|
InvokerInitializer()
|
||||||
|
@ -18,9 +18,9 @@ static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
|||||||
|
|
||||||
// C++ interface
|
// C++ interface
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adam_Optimizer::Step_1(ds_params_percision_t* _params,
|
void Adam_Optimizer::Step_1(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
@ -67,9 +67,9 @@ void Adam_Optimizer::Step_1(ds_params_percision_t* _params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adam_Optimizer::Step_4(ds_params_percision_t* _params,
|
void Adam_Optimizer::Step_4(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
@ -126,9 +126,9 @@ int create_adam_optimizer(int optimizer_id,
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adam_Optimizer::Step_8(ds_params_percision_t* _params,
|
void Adam_Optimizer::Step_8(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
@ -145,7 +145,7 @@ void Adam_Optimizer::Step_8(ds_params_percision_t* _params,
|
|||||||
(_param_size - rounded_size));
|
(_param_size - rounded_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void step_invoker(std::shared_ptr<Adam_Optimizer> opt,
|
void step_invoker(std::shared_ptr<Adam_Optimizer> opt,
|
||||||
void* _params,
|
void* _params,
|
||||||
void* grads,
|
void* grads,
|
||||||
@ -153,8 +153,8 @@ void step_invoker(std::shared_ptr<Adam_Optimizer> opt,
|
|||||||
void* _exp_avg_sq,
|
void* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
opt->Step_8((ds_params_percision_t*)(_params),
|
opt->Step_8((ds_params_precision_t*)(_params),
|
||||||
(ds_params_percision_t*)(grads),
|
(ds_params_precision_t*)(grads),
|
||||||
(ds_state_precision_t*)(_exp_avg),
|
(ds_state_precision_t*)(_exp_avg),
|
||||||
(ds_state_precision_t*)(_exp_avg_sq),
|
(ds_state_precision_t*)(_exp_avg_sq),
|
||||||
_param_size);
|
_param_size);
|
||||||
@ -165,12 +165,12 @@ std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
|
|||||||
invokers;
|
invokers;
|
||||||
|
|
||||||
// Fill map with template functions for each type
|
// Fill map with template functions for each type
|
||||||
template <class ds_params_percision_t, class ds_state_precision_t>
|
template <class ds_params_precision_t, class ds_state_precision_t>
|
||||||
void create_invoker()
|
void create_invoker()
|
||||||
{
|
{
|
||||||
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
|
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_precision_t>(),
|
||||||
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
||||||
step_invoker<ds_params_percision_t, ds_state_precision_t>;
|
step_invoker<ds_params_precision_t, ds_state_precision_t>;
|
||||||
}
|
}
|
||||||
struct InvokerInitializer {
|
struct InvokerInitializer {
|
||||||
InvokerInitializer()
|
InvokerInitializer()
|
||||||
|
@ -14,9 +14,9 @@
|
|||||||
#include "simd.h"
|
#include "simd.h"
|
||||||
|
|
||||||
#define STEP(SPAN) \
|
#define STEP(SPAN) \
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t> \
|
template <typename ds_params_precision_t, typename ds_state_precision_t> \
|
||||||
void Step_##SPAN(ds_params_percision_t* _params, \
|
void Step_##SPAN(ds_params_precision_t* _params, \
|
||||||
ds_params_percision_t* grads, \
|
ds_params_precision_t* grads, \
|
||||||
ds_state_precision_t* _exp_avg_sq, \
|
ds_state_precision_t* _exp_avg_sq, \
|
||||||
size_t _param_size);
|
size_t _param_size);
|
||||||
|
|
||||||
@ -28,10 +28,10 @@ public:
|
|||||||
}
|
}
|
||||||
~Adagrad_Optimizer() {}
|
~Adagrad_Optimizer() {}
|
||||||
#if defined(__AVX512__) or defined(__AVX256__)
|
#if defined(__AVX512__) or defined(__AVX256__)
|
||||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Step_AVX(size_t* rounded_size,
|
void Step_AVX(size_t* rounded_size,
|
||||||
ds_params_percision_t* _params,
|
ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t param_size);
|
size_t param_size);
|
||||||
#endif
|
#endif
|
||||||
@ -61,15 +61,15 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
#if defined(__AVX512__) or defined(__AVX256__)
|
#if defined(__AVX512__) or defined(__AVX256__)
|
||||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
|
||||||
ds_params_percision_t* _params,
|
ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
#if !defined(__AVX512__)
|
#if !defined(__AVX512__)
|
||||||
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
|
if (std::is_same_v<ds_params_precision_t, c10::BFloat16> ||
|
||||||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -14,9 +14,9 @@
|
|||||||
#include "simd.h"
|
#include "simd.h"
|
||||||
|
|
||||||
#define STEP(SPAN) \
|
#define STEP(SPAN) \
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t> \
|
template <typename ds_params_precision_t, typename ds_state_precision_t> \
|
||||||
void Step_##SPAN(ds_params_percision_t* _params, \
|
void Step_##SPAN(ds_params_precision_t* _params, \
|
||||||
ds_params_percision_t* grads, \
|
ds_params_precision_t* grads, \
|
||||||
ds_state_precision_t* _exp_avg, \
|
ds_state_precision_t* _exp_avg, \
|
||||||
ds_state_precision_t* _exp_avg_sq, \
|
ds_state_precision_t* _exp_avg_sq, \
|
||||||
size_t _param_size);
|
size_t _param_size);
|
||||||
@ -43,10 +43,10 @@ public:
|
|||||||
~Adam_Optimizer() {}
|
~Adam_Optimizer() {}
|
||||||
|
|
||||||
#if defined(__AVX512__) or defined(__AVX256__)
|
#if defined(__AVX512__) or defined(__AVX256__)
|
||||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Step_AVX(size_t* rounded_size,
|
void Step_AVX(size_t* rounded_size,
|
||||||
ds_params_percision_t* _params,
|
ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t param_size);
|
size_t param_size);
|
||||||
@ -106,16 +106,16 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
#if defined(__AVX512__) or defined(__AVX256__)
|
#if defined(__AVX512__) or defined(__AVX256__)
|
||||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
void Adam_Optimizer::Step_AVX(size_t* rounded_size,
|
||||||
ds_params_percision_t* _params,
|
ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
ds_state_precision_t* _exp_avg_sq,
|
ds_state_precision_t* _exp_avg_sq,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
#if !defined(__AVX512__)
|
#if !defined(__AVX512__)
|
||||||
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
|
if (std::is_same_v<ds_params_precision_t, c10::BFloat16> ||
|
||||||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -14,9 +14,9 @@
|
|||||||
#include "simd.h"
|
#include "simd.h"
|
||||||
|
|
||||||
#define STEP(SPAN) \
|
#define STEP(SPAN) \
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t> \
|
template <typename ds_params_precision_t, typename ds_state_precision_t> \
|
||||||
void Step_##SPAN(ds_params_percision_t* _params, \
|
void Step_##SPAN(ds_params_precision_t* _params, \
|
||||||
ds_params_percision_t* grads, \
|
ds_params_precision_t* grads, \
|
||||||
ds_state_precision_t* _exp_avg, \
|
ds_state_precision_t* _exp_avg, \
|
||||||
size_t _param_size);
|
size_t _param_size);
|
||||||
|
|
||||||
@ -32,10 +32,10 @@ public:
|
|||||||
~Lion_Optimizer() {}
|
~Lion_Optimizer() {}
|
||||||
|
|
||||||
#if defined(__AVX512__) or defined(__AVX256__)
|
#if defined(__AVX512__) or defined(__AVX256__)
|
||||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Step_AVX(size_t* rounded_size,
|
void Step_AVX(size_t* rounded_size,
|
||||||
ds_params_percision_t* _params,
|
ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
size_t param_size);
|
size_t param_size);
|
||||||
#endif
|
#endif
|
||||||
@ -67,15 +67,15 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
#if defined(__AVX512__) or defined(__AVX256__)
|
#if defined(__AVX512__) or defined(__AVX256__)
|
||||||
template <int span, typename ds_params_percision_t, typename ds_state_precision_t>
|
template <int span, typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
void Lion_Optimizer::Step_AVX(size_t* rounded_size,
|
||||||
ds_params_percision_t* _params,
|
ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
#if !defined(__AVX512__)
|
#if !defined(__AVX512__)
|
||||||
if (std::is_same_v<ds_params_percision_t, c10::BFloat16> ||
|
if (std::is_same_v<ds_params_precision_t, c10::BFloat16> ||
|
||||||
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
std::is_same_v<ds_state_precision_t, c10::BFloat16>) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -19,9 +19,9 @@ static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
|
|||||||
|
|
||||||
// C++ interface
|
// C++ interface
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Lion_Optimizer::Step_1(ds_params_percision_t* _params,
|
void Lion_Optimizer::Step_1(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
@ -64,9 +64,9 @@ void Lion_Optimizer::Step_1(ds_params_percision_t* _params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Lion_Optimizer::Step_4(ds_params_percision_t* _params,
|
void Lion_Optimizer::Step_4(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
@ -117,9 +117,9 @@ int create_lion_optimizer(int optimizer_id,
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void Lion_Optimizer::Step_8(ds_params_percision_t* _params,
|
void Lion_Optimizer::Step_8(ds_params_precision_t* _params,
|
||||||
ds_params_percision_t* grads,
|
ds_params_precision_t* grads,
|
||||||
ds_state_precision_t* _exp_avg,
|
ds_state_precision_t* _exp_avg,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
@ -134,15 +134,15 @@ void Lion_Optimizer::Step_8(ds_params_percision_t* _params,
|
|||||||
(_param_size - rounded_size));
|
(_param_size - rounded_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ds_params_percision_t, typename ds_state_precision_t>
|
template <typename ds_params_precision_t, typename ds_state_precision_t>
|
||||||
void step_invoker(std::shared_ptr<Lion_Optimizer> opt,
|
void step_invoker(std::shared_ptr<Lion_Optimizer> opt,
|
||||||
void* _params,
|
void* _params,
|
||||||
void* grads,
|
void* grads,
|
||||||
void* _exp_avg,
|
void* _exp_avg,
|
||||||
size_t _param_size)
|
size_t _param_size)
|
||||||
{
|
{
|
||||||
opt->Step_8((ds_params_percision_t*)(_params),
|
opt->Step_8((ds_params_precision_t*)(_params),
|
||||||
(ds_params_percision_t*)(grads),
|
(ds_params_precision_t*)(grads),
|
||||||
(ds_state_precision_t*)(_exp_avg),
|
(ds_state_precision_t*)(_exp_avg),
|
||||||
_param_size);
|
_param_size);
|
||||||
}
|
}
|
||||||
@ -152,12 +152,12 @@ std::map<std::tuple<c10::ScalarType, c10::ScalarType>,
|
|||||||
invokers;
|
invokers;
|
||||||
|
|
||||||
// Fill map with template functions for each type
|
// Fill map with template functions for each type
|
||||||
template <class ds_params_percision_t, class ds_state_precision_t>
|
template <class ds_params_precision_t, class ds_state_precision_t>
|
||||||
void create_invoker()
|
void create_invoker()
|
||||||
{
|
{
|
||||||
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_percision_t>(),
|
invokers[std::tuple(c10::CppTypeToScalarType<ds_params_precision_t>(),
|
||||||
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
c10::CppTypeToScalarType<ds_state_precision_t>())] =
|
||||||
step_invoker<ds_params_percision_t, ds_state_precision_t>;
|
step_invoker<ds_params_precision_t, ds_state_precision_t>;
|
||||||
}
|
}
|
||||||
struct InvokerInitializer {
|
struct InvokerInitializer {
|
||||||
InvokerInitializer()
|
InvokerInitializer()
|
||||||
|
Reference in New Issue
Block a user