DeepSpeed/op_builder/cpu_adagrad.py
Liran Bachar 69af361167 CPUAdam fp16 and bf16 support (#5409)
Hi.
Please review the following changes.
I added BF16 support to CPU Adam. Kernels for BF16, FP16, and float are all
instantiated at compile time, and at runtime the matching template is called
according to the dtype of the input parameters.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-05-20 12:50:20 +00:00
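The commit message above describes compiling CPU Adam for BF16, FP16, and float and choosing the matching template at runtime from the parameters' dtype. As a rough pure-Python sketch of that dispatch pattern (not DeepSpeed's code), the block below creates one "instantiation" of a textbook Adagrad step per storage precision up front and picks one from a dtype string at call time. `to_fp32`, `to_fp16`, `to_bf16`, `make_adagrad_kernel`, `KERNELS`, and `adagrad_step` are all hypothetical names, the dtype strings merely mirror PyTorch's names, and rounding through `struct` only emulates the reduced-precision storage that compiled C++ templates would use.

```python
import math
import struct
from typing import Callable, Dict, List

# Hypothetical helpers: emulate each storage precision by rounding a Python float.
def to_fp32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_fp16(x: float) -> float:
    # struct's "e" format is IEEE 754 half precision.
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    # bfloat16 keeps the top 16 bits of the float32 pattern; round to nearest even.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

Kernel = Callable[..., None]

def make_adagrad_kernel(round_fn: Callable[[float], float]) -> Kernel:
    """Build one 'instantiation' of an Adagrad step for a given parameter precision."""
    def step(params: List[float], grads: List[float], state_sum: List[float],
             lr: float = 0.01, eps: float = 1e-10) -> None:
        for i, g in enumerate(grads):
            # In this sketch the squared-gradient state is always kept in fp32.
            state_sum[i] = to_fp32(state_sum[i] + g * g)
            update = lr * g / (math.sqrt(state_sum[i]) + eps)
            # Parameters are written back in the kernel's own precision.
            params[i] = round_fn(params[i] - update)
    return step

# One kernel per supported dtype, all created up front (the "compile time" analogue).
KERNELS: Dict[str, Kernel] = {
    "float32": make_adagrad_kernel(to_fp32),
    "float16": make_adagrad_kernel(to_fp16),
    "bfloat16": make_adagrad_kernel(to_bf16),
}

def adagrad_step(params, grads, state_sum, dtype: str, **kwargs) -> None:
    """Dispatch at runtime to the kernel matching the parameters' dtype."""
    kernel = KERNELS.get(dtype)
    if kernel is None:
        raise TypeError(f"unsupported dtype: {dtype}")
    kernel(params, grads, state_sum, **kwargs)

params = [to_bf16(1.0), to_bf16(-0.5)]
state = [0.0, 0.0]
adagrad_step(params, [0.1, -0.2], state, "bfloat16", lr=0.1)
print(params)  # [0.8984375, -0.400390625]
```

The unrounded results would be about 0.9 and -0.4; the printed values show the BF16 kernel storing them at bfloat16 precision, while the same call with `"float32"` stays within about 1e-7 of those values.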

```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import TorchCPUOpBuilder


class CPUAdagradBuilder(TorchCPUOpBuilder):
    # Setting DS_BUILD_CPU_ADAGRAD=1 when installing DeepSpeed pre-builds this op.
    BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
    NAME = "cpu_adagrad"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        # Import path of the extension module when it is pre-built.
        return f'deepspeed.ops.adagrad.{self.NAME}_op'

    def sources(self):
        # C++ sources, relative to the DeepSpeed source root.
        return ['csrc/adagrad/cpu_adagrad.cpp']

    def libraries_args(self):
        # No extra link libraries beyond the base-class defaults.
        args = super().libraries_args()
        return args

    def include_paths(self):
        return ['csrc/includes']
```
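`CPUAdagradBuilder` only fills in hooks (`BUILD_VAR`, `NAME`, `absolute_name()`, `sources()`, `libraries_args()`, `include_paths()`); the base class decides what to do with them. In DeepSpeed that means either pre-building the op at install time when `DS_BUILD_CPU_ADAGRAD=1` is set, or JIT-compiling it on first use via `CPUAdagradBuilder().load()`. The pure-Python sketch below illustrates that hook pattern with a hypothetical `MiniOpBuilder` base class. Its `prebuild_requested()` and `compile_command()` methods are invented for illustration, do not reflect `TorchCPUOpBuilder`'s real flags or build steps, and nothing is actually compiled.

```python
import os
from typing import List, Mapping, Optional


class MiniOpBuilder:
    """Hypothetical stand-in for an op-builder base class (not DeepSpeed's API)."""
    BUILD_VAR = ""
    NAME = ""

    def __init__(self, name: str):
        self.name = name

    # Hooks that subclasses override, mirroring the shape of CPUAdagradBuilder.
    def absolute_name(self) -> str:
        return self.name

    def sources(self) -> List[str]:
        return []

    def include_paths(self) -> List[str]:
        return []

    def libraries_args(self) -> List[str]:
        return []

    def prebuild_requested(self, env: Optional[Mapping[str, str]] = None) -> bool:
        """True if this op's build variable is set to "1" (defaults to os.environ)."""
        env = os.environ if env is None else env
        return env.get(self.BUILD_VAR, "0") == "1"

    def compile_command(self) -> List[str]:
        """Assemble an illustrative compiler command line from the hooks."""
        cmd = ["c++", "-shared", "-fPIC"]
        cmd += [f"-I{path}" for path in self.include_paths()]
        cmd += self.sources()
        cmd += [f"-l{lib}" for lib in self.libraries_args()]
        cmd += ["-o", f"{self.name}.so"]
        return cmd


class SketchCPUAdagradBuilder(MiniOpBuilder):
    # Same hook values as CPUAdagradBuilder, on the hypothetical base class.
    BUILD_VAR = "DS_BUILD_CPU_ADAGRAD"
    NAME = "cpu_adagrad"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self) -> str:
        return f"deepspeed.ops.adagrad.{self.NAME}_op"

    def sources(self) -> List[str]:
        return ["csrc/adagrad/cpu_adagrad.cpp"]

    def include_paths(self) -> List[str]:
        return ["csrc/includes"]


builder = SketchCPUAdagradBuilder()
print(builder.absolute_name())  # deepspeed.ops.adagrad.cpu_adagrad_op
print(builder.prebuild_requested({"DS_BUILD_CPU_ADAGRAD": "1"}))  # True
print(" ".join(builder.compile_command()))
# c++ -shared -fPIC -Icsrc/includes csrc/adagrad/cpu_adagrad.cpp -o cpu_adagrad.so
```

In this sketch the subclass stays purely declarative, so describing a new op mostly amounts to listing its sources and include paths.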