From e57cca971c4db11d8ce9da6b008bd655ada8c77e Mon Sep 17 00:00:00 2001 From: Jiawei Li Date: Sat, 13 Sep 2025 11:58:52 +0800 Subject: [PATCH] Fix the bugs about operator registration by PyTorch Dispatcher (#2786) **Background:** There are two principles about operator registration in PyTorch - The same namespace can be only registered once by `TORCH_LIBRARY` - The operator signatures can be only registered once by `def` Considering that all custom operators defined in the current repo are only used by Ascend, instead of defining a common operator schema by vLLM, all accelerators then follow this operator schema and complete the implementation based on their respective hardware, which is conducive to functional abstraction. Therefore, we can rename the operator registration namespace to an Ascend-specific namespace(**_C_ascend**). Related ISSUE: https://github.com/vllm-project/vllm-ascend/issues/2742 - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/f592b3174b39a7aeac52432d66d76e89ff0a80b4 Signed-off-by: FFFrog --- benchmarks/ops/ben_vocabparallelembedding.py | 2 +- csrc/torch_binding.cpp | 28 +++++++-------- csrc/torch_binding_meta.cpp | 8 ++--- tests/e2e/singlecard/ops/test_bgmv_expand.py | 4 +-- tests/e2e/singlecard/ops/test_bgmv_shrink.py | 2 +- .../singlecard/ops/test_rotary_embedding.py | 6 ++-- .../ops/test_vocabparallelembedding.py | 2 +- tests/ut/ops/test_rotary_embedding.py | 2 +- .../ops/test_torchair_rotary_embedding.py | 2 +- vllm_ascend/compilation/acl_graph.py | 13 +++---- vllm_ascend/lora/lora_ops.py | 25 +++++++------- vllm_ascend/meta_registration.py | 13 +++---- vllm_ascend/ops/__init__.py | 17 +++++----- vllm_ascend/ops/rotary_embedding.py | 2 +- .../torchair/ops/torchair_rotary_embedding.py | 2 +- vllm_ascend/utils.py | 34 +++++++++++++++++-- 16 files changed, 97 insertions(+), 65 deletions(-) diff --git a/benchmarks/ops/ben_vocabparallelembedding.py b/benchmarks/ops/ben_vocabparallelembedding.py index b3ef7ec50..5590c7337 100644 --- a/benchmarks/ops/ben_vocabparallelembedding.py +++ b/benchmarks/ops/ben_vocabparallelembedding.py @@ -112,7 +112,7 @@ def test_get_masked_input_and_mask( # Define custom function def custom_fn(): - return torch.ops._C.get_masked_input_and_mask( + return torch.ops._C_ascend.get_masked_input_and_mask( input_tensor, test_case["org_start"], test_case["org_end"], diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 1291a3921..5dd6988a9 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -141,7 +141,7 @@ std::tuple get_masked_input_and_mask( TP2, rank 1: |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | - index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | Parameters: org_vocab_start_index //base embeddings start org_vocab_end_index //base embeddings end @@ -164,22 +164,22 @@ std::tuple get_masked_input_and_mask( // Create output tensors at::Tensor masked_input = at::empty_like(input); at::Tensor mask = at::empty_like(input).to(at::kBool); - + // Get data pointers void *input_ptr = input.data_ptr(); void *masked_input_ptr = masked_input.data_ptr(); void *mask_ptr = mask.data_ptr(); - + // Get current stream aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); - + // Get scalar type at::ScalarType scalar_type = input.scalar_type(); - + // Create and configure OpCommand at_npu::native::OpCommand cmd; cmd.Name("get_masked_input_and_mask"); - cmd.SetCustomHandler([scalar_type, size, stream, + cmd.SetCustomHandler([scalar_type, size, stream, input_ptr, masked_input_ptr, mask_ptr, org_vocab_start_index, org_vocab_end_index, num_org_vocab_padding, added_vocab_start_index, @@ -193,7 +193,7 @@ std::tuple get_masked_input_and_mask( get_masked_input_and_mask_impl( stream, input_ptr, - masked_input_ptr, + masked_input_ptr, mask_ptr, org_vocab_start_index, org_vocab_end_index, @@ -203,7 +203,7 @@ std::tuple get_masked_input_and_mask( size, loop_cnt, aiv_num); - + return 0; }); cmd.Run(); @@ -320,8 +320,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("sgmv_shrink"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, - seq_len_ptr, seq_len_size, y_ptr, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, + seq_len_ptr, seq_len_size, y_ptr, batch_size, input_hidden_token, lora_rank, scale_f]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -330,7 +330,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, - y_ptr, batch_size, + y_ptr, batch_size, num_tokens_per_core, input_hidden_token, lora_rank, scale_f); return 0; }); @@ -367,7 +367,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); at_npu::native::OpCommand cmd; cmd.Name("sgmv_expand"); - cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, + cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int { auto dtype = get_dtype_from_torch(scalar_type); int device_id = 0; @@ -375,7 +375,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num; TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0"); - sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, + sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim); return 0; }); @@ -384,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic } } // namespace vllm_ascend -TORCH_LIBRARY_EXPAND(_C, ops) +TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) { // vLLM-Ascend custom ops ops.def("weak_ref_tensor(Tensor input) -> Tensor"); diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index d69254b49..4101ee71e 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -40,7 +40,7 @@ std::tuple rotary_embedding_meta( at::Tensor &positions, at::Tensor &query, at::Tensor &key, - int64_t head_size, + int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) { auto num_tokens = positions.sym_numel(); @@ -86,9 +86,9 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_ } // namespace vllm_ascend namespace { - // Register the meta implementations of the custom kernels for symbolic tracing, this will also + // Register the meta implementations of the custom kernels for symbolic tracing, this will also // the custom kernel been captured into aclgraph - TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) { + TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { // Rotary embedding meta implementation ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta); // Masked input and mask meta implementation @@ -99,4 +99,4 @@ namespace { ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta); } -} \ No newline at end of file +} diff --git a/tests/e2e/singlecard/ops/test_bgmv_expand.py b/tests/e2e/singlecard/ops/test_bgmv_expand.py index 0aca9cadc..9d82ab8fe 100644 --- a/tests/e2e/singlecard/ops/test_bgmv_expand.py +++ b/tests/e2e/singlecard/ops/test_bgmv_expand.py @@ -33,8 +33,8 @@ def test_bgmv_expand(): y_npu = y.npu() y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128) - y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0, - 128) + y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu, + y_npu, 0, 128) # Compare the results. torch.testing.assert_close(y_out_npu.cpu(), diff --git a/tests/e2e/singlecard/ops/test_bgmv_shrink.py b/tests/e2e/singlecard/ops/test_bgmv_shrink.py index 99bb8e890..6cb8127ae 100644 --- a/tests/e2e/singlecard/ops/test_bgmv_shrink.py +++ b/tests/e2e/singlecard/ops/test_bgmv_shrink.py @@ -33,7 +33,7 @@ def test_bgmv_shrink(): y_npu = y.npu() y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5) - torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5) + torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5) # Compare the results. torch.testing.assert_close(y_npu.cpu(), diff --git a/tests/e2e/singlecard/ops/test_rotary_embedding.py b/tests/e2e/singlecard/ops/test_rotary_embedding.py index 6f513b219..27e9b3b9e 100644 --- a/tests/e2e/singlecard/ops/test_rotary_embedding.py +++ b/tests/e2e/singlecard/ops/test_rotary_embedding.py @@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim( ) ref_query, ref_key = rope.forward_native(positions, query, key) - query, key = torch.ops._C.rotary_embedding( + query, key = torch.ops._C_ascend.rotary_embedding( positions, query, key, @@ -239,7 +239,7 @@ class ModelwithRotaryEmbedding(nn.Module): # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph qkv = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(3, dim=-1) - query, key = torch.ops._C.rotary_embedding( + query, key = torch.ops._C_ascend.rotary_embedding( positions, q, k, @@ -299,7 +299,7 @@ def test_capture_rotary_embedding_in_aclgraph( # Validate if the rotary_embedding custom kernel is indeed inside the graph by # string match graph = str(gm.graph) - assert "_C.rotary_embedding" in graph + assert "_C_ascend.rotary_embedding" in graph return gm static_positions = torch.randint(0, max_position_embeddings, diff --git a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py index 54d112723..64b974dfd 100644 --- a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py +++ b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py @@ -72,7 +72,7 @@ def test_get_masked_input_and_mask( # Get custom op result print("input_tensor:", input_tensor) - custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask( + custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask( input_tensor, test_case["org_start"], test_case["org_end"], test_case["padding"], test_case["added_start"], test_case["added_end"]) diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index de6f4efff..21d95bb71 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -94,7 +94,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase): self.mock_self.cos_sin_cache = self.cos_sin_cache self.mock_self.is_neox_style = self.is_neox_style - @patch('torch.ops._C') + @patch('torch.ops._C_ascend') @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False) @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled', return_value=True) diff --git a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py index ce74deed0..4adb59887 100644 --- a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py +++ b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py @@ -104,7 +104,7 @@ class TestRopeForwardOot(TestBase): self.assertTrue(torch.equal(result_q, self.query)) self.assertTrue(torch.equal(result_k, self.key)) - @patch('torch.ops._C') + @patch('torch.ops._C_ascend') @patch( 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p', diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index f8dfc24e1..cc124480c 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -15,7 +15,8 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.platforms import current_platform -from vllm.utils import weak_ref_tensors + +from ..utils import weak_ref_tensors @dataclasses.dataclass @@ -35,10 +36,10 @@ class ACLGraphWrapper: The workflow of this wrapper in the aclgraph dispatching is as follows: 1. At initialization, a runtime mode is assigned to the wrapper (FULL or - PIECEWISE). - 2. At runtime, the wrapper receives a runtime_mode and a + PIECEWISE). + 2. At runtime, the wrapper receives a runtime_mode and a batch_descriptor(key) from the forward context and blindly trust them - for aclgraph dispatching. + for aclgraph dispatching. 3. If runtime_mode is NONE or runtime_mode does not match the mode of the wrapper, just call the runnable directly. 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, @@ -47,9 +48,9 @@ class ACLGraphWrapper: Note: ACLGraphWrapper does not store persistent buffers or copy any runtime inputs into that buffers for replay. We assume implementing them - is done outside of the wrapper. That is because we do not make any + is done outside of the wrapper. That is because we do not make any assumption on the dynamic shape (batch size) of the runtime inputs, as a - trade-off for staying orthogonal to compilation logic. Nevertheless, + trade-off for staying orthogonal to compilation logic. Nevertheless, tracing and checking the input addresses to be consistent during replay is guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". """ diff --git a/vllm_ascend/lora/lora_ops.py b/vllm_ascend/lora/lora_ops.py index e8bf8ad97..58d0ea60e 100644 --- a/vllm_ascend/lora/lora_ops.py +++ b/vllm_ascend/lora/lora_ops.py @@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - return torch.ops._C.bgmv_shrink( + return torch.ops._C_ascend.bgmv_shrink( inputs, lora_a_weights, lora_indices_tensor, @@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - return torch.ops._C.bgmv_expand( + return torch.ops._C_ascend.bgmv_expand( inputs, lora_b_weights, lora_indices_tensor, @@ -52,9 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - return torch.ops._C.bgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, output_tensor, - slice_offset, slice_size) + return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, output_tensor, + slice_offset, slice_size) def sgmv_shrink( @@ -69,9 +69,9 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, scaling) + return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, scaling) def sgmv_expand(inputs: torch.Tensor, @@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor, max_seq_length: int, token_nums: int, add_inputs: bool = False): - return torch.ops._C.sgmv_expand( + return torch.ops._C_ascend.sgmv_expand( inputs, lora_b_weights, lora_indices_tensor, @@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = False): - return torch.ops._C.sgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, slice_offset, slice_size) + return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, slice_offset, + slice_size) diff --git a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py index 47c775887..9a58afd9c 100644 --- a/vllm_ascend/meta_registration.py +++ b/vllm_ascend/meta_registration.py @@ -23,7 +23,7 @@ from torch.library import Library # Do NOT perform any real computation or allocate device memory. # # 2. Register your meta function using `register_meta_if_necessary`, providing: -# - The namespace (usually "_C" for custom ops) +# - The namespace (usually "_C_ascend" for custom ops) # - The operator name (as registered in C++) # - The Python meta function # - (Optional) The overload name, if your op has overloads @@ -39,7 +39,7 @@ from torch.library import Library # # For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors -lib = Library("_C", "IMPL") +lib = Library("_C_ascend", "IMPL") def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""): @@ -97,8 +97,9 @@ def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor, return y_out -register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta) -register_meta_if_necessary("_C", "get_masked_input_and_mask", +register_meta_if_necessary("_C_ascend", "rotary_embedding", + rotary_embedding_meta) +register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask", get_masked_input_and_mask_meta) -register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta) -register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta) +register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta) +register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta) diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 5c8a79847..381c1b6df 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -35,19 +35,20 @@ class dummyFusionOp: def register_dummy_fusion_op() -> None: - torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm") - torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm") - torch.ops._C.static_scaled_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm") + torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp( + name="fused_add_rms_norm") + torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp( name="static_scaled_fp8_quant") - torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp( name="dynamic_scaled_fp8_quant") - torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp( name="dynamic_per_token_scaled_fp8_quant") - torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp( name="rms_norm_static_fp8_quant") - torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( + torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp( name="fused_add_rms_norm_static_fp8_quant") - torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp( + torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp( name="rms_norm_dynamic_per_token_quant") diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 4b76dceb0..9ddf2800a 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -49,7 +49,7 @@ def _rope_forward_oot( # adopt custom kernel path for rotary_embedding if _custom_rotary_embedding_enabled(query, is_neox_style, self.head_size) and not is_310p(): - query, key = torch.ops._C.rotary_embedding( + query, key = torch.ops._C_ascend.rotary_embedding( positions, query, key, diff --git a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py index 766ae5f4a..e64bd6f64 100644 --- a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py +++ b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py @@ -62,7 +62,7 @@ def rope_forward_oot( # adopt custom kernel path for rotary_embedding if custom_rotary_embedding_enabled(query, neox_style, self.head_size) and not is_310p(): - query, key = torch.ops._C.rotary_embedding( + query, key = torch.ops._C_ascend.rotary_embedding( positions, query, key, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index ca5132793..a16606122 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -24,7 +24,7 @@ import os from contextlib import contextmanager from enum import Enum from threading import Lock -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import torch import torch_npu # noqa: F401 # noqa: F401 @@ -188,7 +188,7 @@ def try_register_lib(lib_name: str, lib_info: str = ""): def enable_custom_op(): """ - Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component. + Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component. Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device(). """ global _CUSTOM_OP_ENABLED @@ -486,7 +486,7 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): """Register Ascend CustomOP - NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, + NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, and ensure this will execute after model config is initilazed. """ global _ASCEND_CUSTOMOP_IS_REIGISTERED @@ -589,3 +589,31 @@ def dense_optim_enable() -> bool: def is_moe_model(vllm_config: VllmConfig): config = vllm_config.model_config.hf_config return any('experts' in key.lower() for key in config.to_dict()) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return torch.ops._C_ascend.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors")