Fix bugs in operator registration with the PyTorch Dispatcher (#2786)

**Background:**

There are two rules for operator registration in PyTorch (see the sketch below):
- A namespace can only be registered once via `TORCH_LIBRARY`
- An operator schema can only be defined once via `def`
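
Both rules surface as a `RuntimeError` when exercised from Python via `torch.library`. A minimal sketch (the namespace `demo_ns` is made up for illustration):

```python
from torch.library import Library

lib = Library("demo_ns", "DEF")          # first TORCH_LIBRARY for this namespace: OK
lib.define("foo(Tensor x) -> Tensor")    # first definition of the op schema: OK

try:
    Library("demo_ns", "DEF")            # rule 1: namespace registered a second time
except RuntimeError as err:
    print(f"duplicate namespace rejected: {err}")

try:
    lib.define("foo(Tensor x) -> Tensor")  # rule 2: same schema defined a second time
except RuntimeError as err:
    print(f"duplicate schema rejected: {err}")
```

Since vLLM already owns the `_C` namespace, loading a second extension that also executes `TORCH_LIBRARY(_C, ...)` violates rule 1; this is the failure the PR fixes.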

All custom operators defined in this repo are used only by Ascend, so there is
no need for vLLM to define a common operator schema that every accelerator
then implements for its own hardware (the approach that would favor
functional abstraction).

Therefore, we rename the operator registration namespace to an
Ascend-specific one (**_C_ascend**).
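
After the rename, call sites resolve the custom ops through the new namespace. A minimal sketch (assuming the vllm-ascend extension that registers `_C_ascend` has been built and imported):

```python
import torch

def resolve_ascend_op(op_name: str):
    # torch.ops creates the _C_ascend namespace object lazily; looking up an
    # op on it fails if the extension registering _C_ascend is not loaded.
    ns = getattr(torch.ops, "_C_ascend")
    return getattr(ns, op_name)

# e.g. the rotary embedding kernel touched throughout this diff:
# rotary_embedding = resolve_ascend_op("rotary_embedding")
```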

Related issue: https://github.com/vllm-project/vllm-ascend/issues/2742


- vLLM version: main
- vLLM main: f592b3174b

Signed-off-by: FFFrog <ljw1101.vip@gmail.com>
Author:    Jiawei Li
Date:      2025-09-13 11:58:52 +08:00
Committed: GitHub
Parent:    138e932630
Commit:    e57cca971c

16 changed files with 97 additions and 65 deletions

View File

@@ -112,7 +112,7 @@ def test_get_masked_input_and_mask(
     # Define custom function
     def custom_fn():
-        return torch.ops._C.get_masked_input_and_mask(
+        return torch.ops._C_ascend.get_masked_input_and_mask(
             input_tensor,
             test_case["org_start"],
             test_case["org_end"],

View File

@@ -384,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 }  // namespace vllm_ascend
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
 {
   // vLLM-Ascend custom ops
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");

View File

@@ -88,7 +88,7 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
 namespace {
 // Register the meta implementations of the custom kernels for symbolic tracing; this also
 // lets the custom kernels be captured into aclgraph
-TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
+TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
   // Rotary embedding meta implementation
   ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
   // Masked input and mask meta implementation

View File

@@ -33,8 +33,8 @@ def test_bgmv_expand():
     y_npu = y.npu()
     y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
-    y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0,
-                                         128)
+    y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu,
+                                                y_npu, 0, 128)
     # Compare the results.
     torch.testing.assert_close(y_out_npu.cpu(),

View File

@@ -33,7 +33,7 @@ def test_bgmv_shrink():
     y_npu = y.npu()
     y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
-    torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
+    torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
     # Compare the results.
     torch.testing.assert_close(y_npu.cpu(),

View File

@@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim(
     )
     ref_query, ref_key = rope.forward_native(positions, query, key)
-    query, key = torch.ops._C.rotary_embedding(
+    query, key = torch.ops._C_ascend.rotary_embedding(
         positions,
         query,
         key,
@@ -239,7 +239,7 @@ class ModelwithRotaryEmbedding(nn.Module):
         # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(3, dim=-1)
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
             positions,
             q,
             k,
@@ -299,7 +299,7 @@ def test_capture_rotary_embedding_in_aclgraph(
         # Validate if the rotary_embedding custom kernel is indeed inside the graph by
         # string match
         graph = str(gm.graph)
-        assert "_C.rotary_embedding" in graph
+        assert "_C_ascend.rotary_embedding" in graph
         return gm
     static_positions = torch.randint(0, max_position_embeddings,

View File

@@ -72,7 +72,7 @@ def test_get_masked_input_and_mask(
     # Get custom op result
     print("input_tensor:", input_tensor)
-    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
+    custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask(
         input_tensor, test_case["org_start"], test_case["org_end"],
         test_case["padding"], test_case["added_start"], test_case["added_end"])

View File

@@ -94,7 +94,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
         self.mock_self.cos_sin_cache = self.cos_sin_cache
         self.mock_self.is_neox_style = self.is_neox_style
-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
     @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
     @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
            return_value=True)

View File

@@ -104,7 +104,7 @@ class TestRopeForwardOot(TestBase):
         self.assertTrue(torch.equal(result_q, self.query))
         self.assertTrue(torch.equal(result_k, self.key))
-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
     @patch(
         'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',

View File

@@ -15,7 +15,8 @@ from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
-from vllm.utils import weak_ref_tensors
+
+from ..utils import weak_ref_tensors
 @dataclasses.dataclass

View File

@@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor,
                 output_tensor: torch.Tensor,
                 lora_indices_tensor: torch.Tensor,
                 scaling: float = 1.0):
-    return torch.ops._C.bgmv_shrink(
+    return torch.ops._C_ascend.bgmv_shrink(
         inputs,
         lora_a_weights,
         lora_indices_tensor,
@@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor,
                 output_tensor: torch.Tensor,
                 lora_indices_tensor: torch.Tensor,
                 add_inputs: bool = True):
-    return torch.ops._C.bgmv_expand(
+    return torch.ops._C_ascend.bgmv_expand(
         inputs,
         lora_b_weights,
         lora_indices_tensor,
@@ -52,7 +52,7 @@ def bgmv_expand_slice(inputs: torch.Tensor,
                       slice_offset: int,
                       slice_size: int,
                       add_inputs: bool = True):
-    return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
+    return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
                                     lora_indices_tensor, output_tensor,
                                     slice_offset, slice_size)
@@ -69,7 +69,7 @@ def sgmv_shrink(
     token_nums: int,
     scaling: float,
 ):
-    return torch.ops._C.sgmv_shrink(inputs, lora_a_weights,
+    return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
                                     lora_indices_tensor, seq_len_tensor,
                                     output_tensor, scaling)
@@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor,
                 max_seq_length: int,
                 token_nums: int,
                 add_inputs: bool = False):
-    return torch.ops._C.sgmv_expand(
+    return torch.ops._C_ascend.sgmv_expand(
         inputs,
         lora_b_weights,
         lora_indices_tensor,
@@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor,
                       slice_offset: int,
                       slice_size: int,
                       add_inputs: bool = False):
-    return torch.ops._C.sgmv_expand(inputs, lora_b_weights,
+    return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
                                     lora_indices_tensor, seq_len_tensor,
-                                    output_tensor, slice_offset, slice_size)
+                                    output_tensor, slice_offset,
+                                    slice_size)

View File

@@ -23,7 +23,7 @@ from torch.library import Library
 #    Do NOT perform any real computation or allocate device memory.
 #
 # 2. Register your meta function using `register_meta_if_necessary`, providing:
-#    - The namespace (usually "_C" for custom ops)
+#    - The namespace (usually "_C_ascend" for custom ops)
 #    - The operator name (as registered in C++)
 #    - The Python meta function
 #    - (Optional) The overload name, if your op has overloads
@@ -39,7 +39,7 @@ from torch.library import Library
 #
 # For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
-lib = Library("_C", "IMPL")
+lib = Library("_C_ascend", "IMPL")
 def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
@@ -97,8 +97,9 @@ def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
     return y_out
-register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
-register_meta_if_necessary("_C", "get_masked_input_and_mask",
+register_meta_if_necessary("_C_ascend", "rotary_embedding",
+                           rotary_embedding_meta)
+register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
                            get_masked_input_and_mask_meta)
-register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta)
-register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta)
+register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
+register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)

View File

@@ -35,19 +35,20 @@ class dummyFusionOp:
 def register_dummy_fusion_op() -> None:
-    torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
-    torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
-    torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm")
+    torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(
+        name="fused_add_rms_norm")
+    torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(
         name="static_scaled_fp8_quant")
-    torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(
         name="dynamic_scaled_fp8_quant")
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
         name="dynamic_per_token_scaled_fp8_quant")
-    torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(
         name="rms_norm_static_fp8_quant")
-    torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
         name="fused_add_rms_norm_static_fp8_quant")
-    torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(
         name="rms_norm_dynamic_per_token_quant")

View File

@@ -49,7 +49,7 @@ def _rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if _custom_rotary_embedding_enabled(query, is_neox_style,
                                         self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
             positions,
             query,
             key,

View File

@@ -62,7 +62,7 @@ def rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if custom_rotary_embedding_enabled(query, neox_style,
                                        self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
            positions,
            query,
            key,

View File

@@ -24,7 +24,7 @@ import os
 from contextlib import contextmanager
 from enum import Enum
 from threading import Lock
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 import torch
 import torch_npu  # noqa: F401
@@ -589,3 +589,31 @@ def dense_optim_enable() -> bool:
 def is_moe_model(vllm_config: VllmConfig):
     config = vllm_config.model_config.hf_config
     return any('experts' in key.lower() for key in config.to_dict())
+
+
+def weak_ref_tensor(tensor: Any) -> Any:
+    """
+    Create a weak reference to a tensor.
+    The new tensor will share the same data as the original tensor,
+    but will not keep the original tensor alive.
+    """
+    if isinstance(tensor, torch.Tensor):
+        return torch.ops._C_ascend.weak_ref_tensor(tensor)
+    else:
+        return tensor
+
+
+def weak_ref_tensors(
+    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
+) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
+    """
+    Convenience function to create weak references to tensors,
+    for single tensor, list of tensors or tuple of tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        return weak_ref_tensor(tensors)
+    if isinstance(tensors, list):
+        return [weak_ref_tensor(t) for t in tensors]
+    if isinstance(tensors, tuple):
+        return tuple(weak_ref_tensor(t) for t in tensors)
+    raise ValueError("Invalid type for tensors")