[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)

Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
This commit is contained in:
Chendi.Xue
2025-06-20 09:44:56 -05:00
committed by GitHub
parent f1e840e842
commit 7e8977fcd4
7 changed files with 120 additions and 6 deletions

View File

@ -10,5 +10,7 @@ setup(
entry_points={
'vllm.platform_plugins': [
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
]
],
"vllm.general_plugins":
["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
})

View File

@ -6,3 +6,7 @@ from typing import Optional
def dummy_platform_plugin() -> Optional[str]:
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
def register_ops():
import vllm_add_dummy_platform.dummy_custom_ops # noqa

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.flash_attn import FlashAttentionBackend
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
class DummyAttentionBackend(FlashAttentionBackend):
class DummyAttentionBackend(PlaceholderAttentionBackend):
@staticmethod
def get_name() -> str:

View File

@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
# Register CustomRotaryEmbedding to CustomOP.
@RotaryEmbedding.register_oot
class DummyRotaryEmbedding(RotaryEmbedding):
"""Original rotary positional embedding."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.addition_config = True
def forward_oot(self, *args,
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
return super().forward_oot(*args, **kwargs)

View File

@ -1,12 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
from vllm import envs
class DummyPlatform(CudaPlatform):
class DummyPlatform(Platform):
_enum = PlatformEnum.OOT
device_name = "DummyDevice"
device_type: str = "privateuseone"
dispatch_key: str = "PrivateUse1"
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if envs.VLLM_USE_V1:
compilation_config = vllm_config.compilation_config
# Activate custom ops for v1.
compilation_config.custom_ops = ["all"]
def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501

View File

@ -5,6 +5,7 @@ import pytest
import torch
from vllm.attention.selector import get_attn_backend
from vllm.plugins import load_general_plugins
from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL
@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
backend = get_attn_backend(16, torch.float16, "auto", 16, False)
assert backend.get_name() == "Dummy_Backend"
def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
# simulate workload by running an example
load_general_plugins()
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
assert layer.__class__.__name__ == "DummyRotaryEmbedding", (
f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, "
"possibly because the custom op is not registered correctly.")
assert hasattr(layer, "addition_config"), (
"Expected DummyRotaryEmbedding to have an 'addition_config' attribute, "
"which is set by the custom op.")

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch.nn as nn
from vllm.config import get_current_vllm_config
@ -16,6 +18,24 @@ class CustomOp(nn.Module):
Dispatches the forward method to the appropriate backend.
"""
def __new__(cls, *args, **kwargs):
try:
op_name = cls.__name__
except AttributeError:
raise TypeError(
f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
f"was not set, possibly because it was not decorated with "
f"@CustomOp.register, or it's the CustomOp base class itself."
) from None
if op_name not in cls.op_registry_oot:
op_cls_to_instantiate = cls
else:
op_cls_to_instantiate = cls.op_registry_oot[op_name]
logger.debug("Instantiating custom op: %s using %s", op_name,
str(op_cls_to_instantiate))
return super().__new__(op_cls_to_instantiate)
def __init__(self):
super().__init__()
self._forward_method = self.dispatch_forward()
@ -138,6 +158,7 @@ class CustomOp(nn.Module):
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type['CustomOp']] = {}
op_registry_oot: dict[str, type['CustomOp']] = {}
# Decorator to register custom ops.
@classmethod
@ -150,3 +171,38 @@ class CustomOp(nn.Module):
return op_cls
return decorator
# Decorator to register out-of-tree(oot) custom ops.
# For OOT custom ops:
# if in-tree layer class is registered with an oot_custom_op layer,
# the oot_custom_op layer will be used instead.
# Example:
# - @UnquantizedFusedMoEMethod.register_oot
# class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
# or
# - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
@classmethod
def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None):
def decorator(op_cls):
reg_name = name if name is not None else cls.__name__
assert reg_name not in cls.op_registry_oot, \
f"Duplicate op name: {reg_name}"
op_cls.name = reg_name
cls.op_registry_oot[reg_name] = op_cls
return op_cls
if _decorated_op_cls is None:
# Called with parentheses: @CustomOP.register_oot()
# or @CustomOP.register_oot(name="...")
# So, _decorated_op_cls is None.
# We return the actual decorator function.
return decorator
elif isinstance(_decorated_op_cls, type): # Check if it's a class
# Called without parentheses: @CustomOP.register_oot
# The first argument is the class itself.
# We call the 'decorator' function immediately with the class.
return decorator(_decorated_op_cls)
else:
# Handle other unexpected cases if necessary
raise TypeError("Decorator can only be applied to classes.")