DeepSpeed/csrc/adam/cpu_adam.cpp
Liran Bachar 69af361167 CPUAdam fp16 and bf16 support (#5409)
Hi,
Please review the following changes.
This change adds BF16 support to CPU Adam. BF16, FP16, and float are all
compiled in, and the matching template instantiation is selected at
runtime according to the dtype of the input params.
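
The compile-time/runtime split described above can be sketched roughly as follows. This is an illustrative sketch only, not the code in cpu_adam.h: the names `Half`, `BFloat16`, `DType`, `TypeName`, `step_impl`, and `dispatch_step` are hypothetical, and the "kernel" is a placeholder that just reports which instantiation ran and how many bytes it would touch.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical 16-bit storage types standing in for real fp16/bf16 types.
struct Half { std::uint16_t bits; };
struct BFloat16 { std::uint16_t bits; };

// Hypothetical runtime tag describing the dtype of the parameter buffer.
enum class DType { Float32, Float16, BFloat16 };

// Compile-time label for each supported element type.
template <typename T> struct TypeName;
template <> struct TypeName<float> { static const char* value() { return "float"; } };
template <> struct TypeName<Half> { static const char* value() { return "fp16"; } };
template <> struct TypeName<BFloat16> { static const char* value() { return "bf16"; } };

// Placeholder for a templated optimizer step. One instantiation exists per
// element type; a real kernel would reinterpret `params` as T* and update it.
template <typename T>
std::string step_impl(const void* params, std::size_t count)
{
    (void)params;  // unused in this sketch
    return std::string(TypeName<T>::value()) + ":" + std::to_string(count * sizeof(T));
}

// Runtime dispatch: choose the already-compiled instantiation that matches
// the dtype reported for the input buffer.
std::string dispatch_step(DType dtype, const void* params, std::size_t count)
{
    switch (dtype) {
        case DType::Float32: return step_impl<float>(params, count);
        case DType::Float16: return step_impl<Half>(params, count);
        case DType::BFloat16: return step_impl<BFloat16>(params, count);
    }
    throw std::invalid_argument("unsupported dtype");
}
```

All three instantiations are generated when the extension is built, so the only per-call cost of supporting several dtypes in this sketch is the `switch`.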

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-05-20 12:50:20 +00:00

14 lines
409 B
C++

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "cpu_adam.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
    m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
    m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}