Files
DeepSpeed/op_builder/hpu/transformer_inference.py
Omar Elayan 645639bcf8 Rearrange inference OPS and stop using builder.load (#5490)
This PR mainly handles all places where InferenceBuilder is used to
access any op or a specific implementation for an op.
Instead, an op is defined and its proper implementation is selected
internally, so the usage is transparent to the user.
What was done in the PR:
1) Added missing ops (added a py file with fallback mechanism)
2) Added missing fallback implementations for existing ops
3) Removed all usages of builder.load and replaced them with ops.
4) Added a workspace op and inferenceContext, which contains all
workspace-related functions; inferenceContext is the Python fallback for
the inferenceContext in CUDA.
5) a small change to softmax_context signature to fit the fallback
signature.

---------

Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2024-10-09 01:22:28 +00:00

40 lines
1.3 KiB
Python

# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
import importlib
# DeepSpeed Team
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401
from op_builder.builder import OpBuilder
except ImportError:
from deepspeed.ops.op_builder.builder import OpBuilder
class InferenceBuilder(OpBuilder):
    """Op builder for the ``transformer_inference`` op (HPU variant).

    This builder declares no sources to compile (``sources()`` returns an
    empty list); ``load()`` only imports an op module that is already
    recorded as installed, and caches it in the class-level ``_loaded_ops``
    mapping.
    """

    BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
    NAME = "transformer_inference"

    def __init__(self, name=None):
        """Initialise the builder.

        Args:
            name: Optional op name. When ``None`` the class-level ``NAME``
                is used.
        """
        name = self.NAME if name is None else name
        # Forward the resolved name. Previously ``self.NAME`` was always
        # passed, which silently discarded a caller-supplied ``name``.
        super().__init__(name=name)

    def absolute_name(self):
        """Return the dotted module path of the op, derived from ``NAME``."""
        return f"deepspeed.ops.transformer.inference.{self.NAME}_op"

    def sources(self):
        """Return the list of source files to compile (none for this builder)."""
        return []

    def load(self, verbose=True):
        """Return the op module, importing it on first use.

        Args:
            verbose: Unused by this implementation.

        Returns:
            The cached or freshly imported op module, or ``None`` when
            ``installed_ops`` does not report this op as installed.
        """
        # Serve from the class-level cache when this op was loaded before.
        if self.name in __class__._loaded_ops:
            return __class__._loaded_ops[self.name]

        # Imported lazily, at call time rather than at module import time.
        from deepspeed.git_version_info import installed_ops  # noqa: F401

        if installed_ops.get(self.name, False):
            op_module = importlib.import_module(self.absolute_name())
            __class__._loaded_ops[self.name] = op_module
            return op_module

        # Op is not installed: callers receive None (the original fell off
        # the end of the function with the same result).
        return None