DeepSpeed/op_builder/hpu/cpu_adam.py
Liran Bachar 69af361167 CPUAdam fp16 and bf16 support (#5409)
Hi.
Please review the following changes.
I added BF16 support to CPU Adam. BF16, FP16, and float are all instantiated
at compile time, and the matching template is selected at runtime based on
the dtype of the input params.
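The commit message describes two steps: each dtype is instantiated at compile time, and the matching instantiation is chosen at runtime from the param dtype. The Python sketch models that idea with one generic Adam step whose only dtype-dependent part is rounding the updated param back to its storage precision. The dtype names, the `CASTS` table, and `dispatch_step` are hypothetical illustrations, not DeepSpeed's C++ API. Keeping the moment buffers in full precision is also an assumption of this sketch.

```python
import math
import struct
from typing import Callable, Dict, List


def to_fp32(x: float) -> float:
    # Round a Python float (double) to the nearest IEEE float32 value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_fp16(x: float) -> float:
    # struct's "e" format is IEEE 754 half precision.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    # bfloat16 is the top 16 bits of a float32; round to nearest, ties to even.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Hypothetical dispatch table: one "instantiation" per supported param dtype.
CASTS: Dict[str, Callable[[float], float]] = {
    "float32": to_fp32,
    "float16": to_fp16,
    "bfloat16": to_bf16,
}


def adam_step(params: List[float], grads: List[float], exp_avg: List[float],
              exp_avg_sq: List[float], step: int, cast: Callable[[float], float],
              lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
              eps: float = 1e-8) -> None:
    # Generic Adam update; only the param write-back depends on the dtype.
    bias1 = 1.0 - beta1 ** step
    bias2 = 1.0 - beta2 ** step
    for i, g in enumerate(grads):
        exp_avg[i] = beta1 * exp_avg[i] + (1.0 - beta1) * g
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0 - beta2) * g * g
        denom = math.sqrt(exp_avg_sq[i] / bias2) + eps
        params[i] = cast(params[i] - lr * (exp_avg[i] / bias1) / denom)


def dispatch_step(dtype: str, params: List[float], grads: List[float],
                  exp_avg: List[float], exp_avg_sq: List[float], step: int) -> None:
    # Pick the implementation at runtime from the params' dtype.
    cast = CASTS.get(dtype)
    if cast is None:
        raise TypeError(f"unsupported param dtype: {dtype}")
    adam_step(params, grads, exp_avg, exp_avg_sq, step, cast)


for dtype in ("float32", "float16", "bfloat16"):
    p = [1.0]
    dispatch_step(dtype, p, [0.5], [0.0], [0.0], step=1)
    print(dtype, p[0])
```

Starting from 1.0 with lr=1e-3, one step moves a float32 param to about 0.999 and a float16 param to 0.9990234375. A bfloat16 param stays at 1.0 because its spacing just below 1.0 (2**-8) is more than twice the size of the update, so the write-back rounds it away.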

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-05-20 12:50:20 +00:00


# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CPUOpBuilder


class CPUAdamBuilder(CPUOpBuilder):
    BUILD_VAR = "DS_BUILD_CPU_ADAM"
    NAME = "cpu_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']

    def libraries_args(self):
        args = super().libraries_args()
        return args

    def include_paths(self):
        return ['csrc/includes']
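This self-contained sketch gives a rough idea of how the overridden hooks (`BUILD_VAR`, `NAME`, `sources()`, `include_paths()`) could feed a build. `MiniOpBuilder`, `compile_command()`, and `prebuild_requested()` are hypothetical and do not mirror `CPUOpBuilder`. DeepSpeed itself compiles ops through PyTorch's C++ extension machinery rather than a hand-assembled `g++` command. Treating `DS_BUILD_CPU_ADAM=1` as an ahead-of-time build request is only a loose model of its `DS_BUILD_*` install-time switches.

```python
import os
from typing import List


class MiniOpBuilder:
    """Hypothetical op-builder base class; not DeepSpeed's CPUOpBuilder."""
    BUILD_VAR = ""
    NAME = ""

    def sources(self) -> List[str]:
        raise NotImplementedError

    def include_paths(self) -> List[str]:
        return []

    def prebuild_requested(self) -> bool:
        # Loose model: BUILD_VAR=1 in the environment asks for an ahead-of-time build.
        return os.environ.get(self.BUILD_VAR, "0") == "1"

    def compile_command(self, compiler: str = "g++") -> List[str]:
        # Assemble (but do not run) a shared-library compile command.
        cmd = [compiler, "-shared", "-fPIC", "-O3"]
        cmd += [f"-I{path}" for path in self.include_paths()]
        cmd += self.sources()
        cmd += ["-o", f"{self.NAME}.so"]
        return cmd


class SketchCPUAdamBuilder(MiniOpBuilder):
    # Same hook values as CPUAdamBuilder in cpu_adam.py.
    BUILD_VAR = "DS_BUILD_CPU_ADAM"
    NAME = "cpu_adam"

    def sources(self) -> List[str]:
        return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp']

    def include_paths(self) -> List[str]:
        return ['csrc/includes']


if __name__ == "__main__":
    builder = SketchCPUAdamBuilder()
    print(" ".join(builder.compile_command()))
    print("prebuild requested:", builder.prebuild_requested())
```

Keeping the build recipe in small overridable methods lets each op subclass declare only what differs (its sources, include paths, and environment switch) while the base class owns the compile logic.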