mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
import os
|
|
from .builder import TorchCPUOpBuilder
|
|
|
|
|
|
class CPULionBuilder(TorchCPUOpBuilder):
|
|
BUILD_VAR = "DS_BUILD_CPU_LION"
|
|
NAME = "cpu_lion"
|
|
|
|
def __init__(self):
|
|
super().__init__(name=self.NAME)
|
|
|
|
def absolute_name(self):
|
|
return f'deepspeed.ops.lion.{self.NAME}_op'
|
|
|
|
def sources(self):
|
|
if self.build_for_cpu:
|
|
return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp']
|
|
|
|
return ['csrc/lion/cpu_lion.cpp', 'csrc/lion/cpu_lion_impl.cpp', 'csrc/common/custom_cuda_kernel.cu']
|
|
|
|
def libraries_args(self):
|
|
args = super().libraries_args()
|
|
if self.build_for_cpu:
|
|
return args
|
|
|
|
if not self.is_rocm_pytorch():
|
|
args += ['curand']
|
|
|
|
return args
|
|
|
|
def include_paths(self):
|
|
import torch
|
|
if self.build_for_cpu:
|
|
CUDA_INCLUDE = []
|
|
elif not self.is_rocm_pytorch():
|
|
CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
|
|
else:
|
|
CUDA_INCLUDE = [
|
|
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
|
|
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
|
|
os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
|
|
]
|
|
return ['csrc/includes'] + CUDA_INCLUDE
|