[core] Support capture custom ops into aclgraph (#2113)

### What this PR does / why we need it?
Thanks to the PR https://github.com/vllm-project/vllm-ascend/pull/426
make vllm-ascend support the aclgraph inference to reduce the host
overhead. However, the capability of aclgraph strongly relies on the
functionality provided by `torch.compile`, which is the key feature
supported in torch 2.x . Therefore, capture custom op into aclgraph is
only possible when it can be recognize and captured by `torch.compile`.

In this PR, we register the meta implementation of current custom ops to
enable the fx graph capture. And by doing that, insert those custom ops
into aclgraph become a natural thing to the ascend runtime.

### Does this PR introduce _any_ user-facing change?
No user face change.

### How was this patch tested?
Tested in unittest, we will integrate the `rotary_embedding` op into a
small custom model and use `torch.compile` and aclgraph to capture and
replay it to verify its functionality.

- vLLM version: v0.10.0
- vLLM main:
1b99028069

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
Pleaplusone
2025-08-11 15:59:42 +08:00
committed by GitHub
parent 1ab15414bb
commit c0f0b70813
6 changed files with 332 additions and 13 deletions

View File

@ -27,6 +27,17 @@
namespace vllm_ascend { namespace vllm_ascend {
AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
if (scalarType == at::ScalarType::Float) {
return AscendType::FP32;
} else if (scalarType == at::ScalarType::BFloat16) {
return AscendType::BF16;
} else {
return AscendType::FP16;
}
}
std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key, std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{ {

View File

@ -0,0 +1,86 @@
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/npu/Module.h>
#include "utils.h"
/*
* How to write a meta implementation for a custom operator (meta kernel):
*
* Meta implementations are used for shape and dtype inference, tracing, and export.
* They do NOT perform any real computation or allocate device memory.
* Instead, they return empty tensors with the correct shapes, dtypes, and device types.
*
* Steps to write a meta implementation:
* 1. The function signature should match the operator's schema, but only use the arguments
* necessary to infer output shapes and dtypes.
* 2. Use input tensor shapes, dtypes, and any relevant arguments to compute the output shapes.
* 3. Return empty tensors (e.g., at::empty_symint, at::empty_like) with the correct shape and dtype.
* 4. Do NOT perform any real computation or data movement.
* 5. Register the meta implementation with the "Meta" dispatch key using TORCH_LIBRARY_IMPL or similar.
*
* Example:
* std::tuple<at::Tensor, at::Tensor> my_op_meta(
* at::Tensor &input, int64_t some_param) {
* // Infer output shape based on input and parameters
* auto out_shape = ...;
* at::Tensor out = at::empty_symint(out_shape, input.options());
* // Return empty tensor(s) with correct shape/dtype
* return {out, ...};
* }
*
* See below for real examples.
*/
namespace vllm_ascend {
namespace meta {
std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
at::Tensor &positions,
at::Tensor &query,
at::Tensor &key,
int64_t head_size,
at::Tensor &cos_sin_cache,
bool is_neox) {
auto num_tokens = positions.sym_numel();
auto query_hidden_size = query.sym_numel() / num_tokens;
auto key_hidden_size = key.sym_numel() / num_tokens;
auto num_heads = query_hidden_size / head_size;
auto num_kv_heads = key_hidden_size / head_size;
at::Tensor query_dst = at::empty_symint({num_tokens, num_heads, head_size}, query.options());
at::Tensor key_dst = at::empty_symint({num_tokens, num_kv_heads, head_size}, key.options());
return {query_dst, key_dst};
}
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
at::Tensor &input,
const int64_t org_vocab_start_index,
const int64_t org_vocab_end_index,
const int64_t num_org_vocab_padding,
const int64_t added_vocab_start_index,
const int64_t added_vocab_end_index) {
at::Tensor masked_input = at::empty_like(input);
at::Tensor mask = at::empty_like(input, input.options().dtype(at::kBool));
return {masked_input, mask};
}
} // namespace meta
} // namespace vllm_ascend
namespace {
// Register the meta implementations of the custom kernels for symbolic tracing, this will also
// the custom kernel been captured into aclgraph
TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
// Rotary embedding meta implementation
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
// Masked input and mask meta implementation
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
}
}

View File

@ -29,15 +29,3 @@
} }
namespace vllm_ascend {
AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
if (scalarType == at::ScalarType::Float) {
return AscendType::FP32;
} else if (scalarType == at::ScalarType::BFloat16) {
return AscendType::BF16;
} else {
return AscendType::FP16;
}
}
} // namespace vllm_ascend

View File

@ -17,11 +17,12 @@ enable_custom_op()
# Only Neox style true scenario is supported for now # Only Neox style true scenario is supported for now
IS_NEOX_STYLE = [True] IS_NEOX_STYLE = [True]
DTYPES = [torch.half] DTYPES = [torch.half]
HEAD_SIZES = [64, 96, 128, 256] HEAD_SIZES = [64, 64, 96, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 4096] # Arbitrary values for testing SEQ_LENS = [11, 4096] # Arbitrary values for testing
NUM_TOKENS = [10, 21]
SEEDS = [0] SEEDS = [0]
DEVICES = [f"npu:{0}"] DEVICES = [f"npu:{0}"]
# Set tolerance to 1 for quant ops # Set tolerance to 1 for quant ops
@ -198,3 +199,146 @@ def test_rotary_embedding_quant_with_leading_dim(
ref_key, ref_key,
atol=DEFAULT_ATOL, atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL) rtol=DEFAULT_RTOL)
class ModelwithRotaryEmbedding(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3)
self.rope = RotaryEmbedding(
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
self.o_proj = nn.Linear(num_heads * head_size, hidden_size)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(3, dim=-1)
query, key = torch.ops._C.rotary_embedding(
positions,
q,
k,
self.rope.head_size,
self.rope.cos_sin_cache,
self.rope.is_neox_style,
)
query = query.view(q.shape)
key = key.view(k.shape)
o = self.o_proj(query)
return o
# The first graph seems will have some accuracy issue when directly run pytest on the ops folder,
# add a warmup graph replay for workaround
ACL_GRPAH_FIRST_RUN = True
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_capture_rotary_embedding_in_aclgraph(
is_neox_style: bool,
num_tokens: int,
num_heads: int,
head_size: int,
rotary_dim: int,
dtype: torch.dtype,
seed: int,
device: str,
max_position_embeddings: int = 8192,
base: int = 10000,
):
"""Test if the rotary embedding can be captured in aclgraph."""
torch.manual_seed(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
model = ModelwithRotaryEmbedding(
hidden_size=num_heads * head_size,
num_heads=num_heads,
head_size=head_size,
rotary_dim=rotary_dim,
max_position_embeddings=max_position_embeddings,
base=base,
is_neox_style=is_neox_style,
dtype=dtype,
)
def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input):
# Validate if the rotary_embedding custom kernel is indeed inside the graph by
# string match
graph = str(gm.graph)
assert "_C.rotary_embedding" in graph
return gm
static_positions = torch.randint(0, max_position_embeddings,
(num_tokens, ))
static_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
compiled_model = torch.compile(model, backend=custom_op_checking_backend)
stream = torch.npu.Stream()
stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(stream):
# warmup the fx graph before capture
for i in range(3):
static_output = compiled_model(static_positions,
static_hidden_states,
offsets=None)
stream.wait_stream(torch.npu.current_stream())
aclgraph = torch.npu.NPUGraph()
with torch.npu.graph(aclgraph):
# Capture the model in aclgraph.
static_output = compiled_model(static_positions, static_hidden_states)
# Capture the model in aclgraph.
random_filled_positions = torch.randint(0,
max_position_embeddings,
(num_tokens, ),
device="npu")
random_filled_hidden_states = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device="npu")
static_positions.copy_(random_filled_positions)
static_hidden_states.copy_(random_filled_hidden_states)
aclgraph.replay()
global ACL_GRPAH_FIRST_RUN
if ACL_GRPAH_FIRST_RUN:
ACL_GRPAH_FIRST_RUN = False
return
output_reference = model(static_positions, static_hidden_states)
torch.testing.assert_close(static_output,
output_reference,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)

View File

@ -0,0 +1,86 @@
import torch
from torch.library import Library
# This file provides a template and registration utilities for writing "meta" implementations
# of custom operators in Python for the vllm_ascend project.
#
# We offer two ways to implement meta implementations for custom ops:
# 1. Python meta implementation (as shown in this file): Write a Python function that
# takes the same arguments as your operator and returns empty tensors with the correct
# shapes and dtypes. This is useful for rapid prototyping and for ops that are only
# used in Python.
# 2. C++ meta implementation: You can also implement the meta function in C++ for better
# performance or to match the C++ op logic more closely. See `torch_binding_meta.cpp`
# for examples of C++ meta implementations and how to register them.
#
# Both approaches enable tracing, export, and shape inference in PyTorch and vLLM, which
# is essential for supporting `torch.compile` and aclgraph.
# How to add a new meta implementation in Python:
# -------------------------------------
# 1. Write a Python function that takes the same arguments as your operator, and returns
# empty tensors (using torch.empty_like, torch.empty, etc.) with the correct shapes and dtypes.
# Do NOT perform any real computation or allocate device memory.
#
# 2. Register your meta function using `register_meta_if_necessary`, providing:
# - The namespace (usually "_C" for custom ops)
# - The operator name (as registered in C++)
# - The Python meta function
# - (Optional) The overload name, if your op has overloads
#
# 3. The registration utility will check if a meta implementation already exists for your op,
# and only register if necessary. This avoids duplicate registrations.
#
# 4. Example meta implementations are provided below for rotary_embedding and get_masked_input_and_mask.
#
# 5. When developing new custom ops, always provide a meta implementation to enable tracing,
# export, and shape inference in PyTorch and vLLM to enable the capture of `torch.compile`
# and aclgraph.
#
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
lib = Library("_C", "IMPL")
def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
if overload != "":
op_name = op_name + "." + overload
schema_to_find = ns + "::" + op_name
meta_impl_list = torch._C._dispatch_get_registrations_for_dispatch_key(
"Meta")
if schema_to_find in meta_impl_list:
return
lib.impl(op_name, fn, "Meta")
def rotary_embedding_meta(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool):
num_tokens = positions.numel()
query_hidden_size = query.numel() // num_tokens
key_hidden_size = key.numel() // num_tokens
num_heads = query_hidden_size // head_size
num_kv_heads = key_hidden_size // head_size
query_dst = torch.empty_like(query).view(num_tokens, num_heads, head_size)
key_dst = torch.empty_like(key).view(num_tokens, num_kv_heads, head_size)
return query_dst, key_dst
def get_masked_input_and_mask_meta(input: torch.Tensor,
org_vocab_start_index: int,
org_vocab_end_index: int,
num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int):
masked_input = torch.empty_like(input)
mask = torch.empty_like(input).to(torch.bool)
return masked_input, mask
register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C", "get_masked_input_and_mask",
get_masked_input_and_mask_meta)

View File

@ -214,8 +214,12 @@ def enable_custom_op():
if _CUSTOM_OP_ENABLED is not None: if _CUSTOM_OP_ENABLED is not None:
return _CUSTOM_OP_ENABLED return _CUSTOM_OP_ENABLED
try: try:
# isort: off
# register custom ops into torch_library here # register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
# register the meta implementation for custom kernel if necessary
import vllm_ascend.meta_registration # type: ignore # noqa: F401
# isort: on
_CUSTOM_OP_ENABLED = True _CUSTOM_OP_ENABLED = True
except ImportError: except ImportError:
_CUSTOM_OP_ENABLED = False _CUSTOM_OP_ENABLED = False