mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Hi. Please review the following changes: I added BF16 support to CPU Adam. BF16, FP16, and float are supported at compile time; the correct template is selected at runtime according to the input params' dtype. --------- Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
14 lines
409 B
C++
14 lines
409 B
C++
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team
#include "cpu_adam.h"
|
|
|
|
// Python bindings for the DeepSpeed CPU Adam optimizer.
//
// Registers three Python-callable functions on the extension module:
//   adam_update  -> ds_adam_step
//   create_adam  -> create_adam_optimizer
//   destroy_adam -> destroy_adam_optimizer
// None of the bound C++ functions are defined in this file. Their declarations
// are expected to come from the "cpu_adam.h" include, so their signatures (and
// therefore the Python call signatures) are not visible here.
//
// NOTE(review): TORCH_EXTENSION_NAME is not defined in this file. Presumably
// the extension build passes it on the compiler command line, as PyTorch's
// cpp_extension tooling does. Confirm against the op builder.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // NOTE(review): per the accompanying change description, the update path
    // supports float, FP16 and BF16, with the template chosen at runtime from
    // the params dtype. Verify this in the ds_adam_step implementation.
    m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
    m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
    m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}
|