Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
Fix operator registration bugs with the PyTorch Dispatcher (#2786)
**Background:**
PyTorch enforces two rules for operator registration:
- A namespace can be claimed by `TORCH_LIBRARY` only once.
- An operator schema can be registered with `def` only once.

All custom operators defined in this repo are used only by Ascend. Rather than having vLLM define a common operator schema that every accelerator would then implement for its own hardware (which would be the cleaner functional abstraction), we can simply rename the registration namespace to an Ascend-specific one (**_C_ascend**), so these ops no longer conflict with other libraries that register the generic `_C` namespace or the same operator schemas.

Related issue: https://github.com/vllm-project/vllm-ascend/issues/2742
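The two rules above also surface from Python through `torch.library`; a minimal sketch of both failure modes (the `_demo_ascend` namespace and the `echo` schema are hypothetical, used purely for illustration):

```python
from torch.library import Library

# Rule 1: a namespace may be claimed in TORCH_LIBRARY ("DEF") mode only once.
lib = Library("_demo_ascend", "DEF")      # hypothetical namespace
try:
    Library("_demo_ascend", "DEF")        # a second claim is rejected
except RuntimeError as err:
    print("duplicate namespace:", err)

# Rule 2: an operator schema may be def()-ed only once per namespace.
lib.define("echo(Tensor x) -> Tensor")
try:
    lib.define("echo(Tensor x) -> Tensor")
except RuntimeError as err:
    print("duplicate schema:", err)
```

Keeping vllm-ascend's ops under the generic `_C` namespace risks exactly these collisions once another extension claims `_C`, which is what the rename to `_C_ascend` avoids.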
- vLLM version: main
- vLLM main: f592b3174b
Signed-off-by: FFFrog <ljw1101.vip@gmail.com>
@@ -112,7 +112,7 @@ def test_get_masked_input_and_mask(
 
     # Define custom function
     def custom_fn():
-        return torch.ops._C.get_masked_input_and_mask(
+        return torch.ops._C_ascend.get_masked_input_and_mask(
            input_tensor,
            test_case["org_start"],
            test_case["org_end"],
@@ -384,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend
 
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
 {
     // vLLM-Ascend custom ops
     ops.def("weak_ref_tensor(Tensor input) -> Tensor");
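With the C++ library now claiming `CONCAT(_C, _ascend)`, the same schemas surface in Python under `torch.ops._C_ascend`. A small, hypothetical probe (not part of this patch) for checking that the extension has actually registered an op before calling it:

```python
import torch

def resolve_ascend_op(name: str):
    # torch.ops._C_ascend is a lazily created namespace object; looking the op
    # up only succeeds once vllm_ascend_C has registered it with the dispatcher.
    return getattr(torch.ops._C_ascend, name, None)

op = resolve_ascend_op("weak_ref_tensor")
if op is None:
    print("vllm_ascend_C not loaded; custom op unavailable")
```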
@@ -86,9 +86,9 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
 } // namespace vllm_ascend
 
 namespace {
 // Register the meta implementations of the custom kernels for symbolic tracing, this will also
 // the custom kernel been captured into aclgraph
-TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
+TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
     // Rotary embedding meta implementation
     ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
     // Masked input and mask meta implementation
@@ -33,8 +33,8 @@ def test_bgmv_expand():
     y_npu = y.npu()
 
     y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
-    y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0,
-                                         128)
+    y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu,
+                                                y_npu, 0, 128)
 
     # Compare the results.
     torch.testing.assert_close(y_out_npu.cpu(),
@@ -33,7 +33,7 @@ def test_bgmv_shrink():
     y_npu = y.npu()
 
     y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
-    torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
+    torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
 
     # Compare the results.
     torch.testing.assert_close(y_npu.cpu(),
|
@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim(
|
|||||||
)
|
)
|
||||||
|
|
||||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||||
query, key = torch.ops._C.rotary_embedding(
|
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@ -239,7 +239,7 @@ class ModelwithRotaryEmbedding(nn.Module):
|
|||||||
# we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
|
# we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
|
||||||
qkv = self.qkv_proj(hidden_states)
|
qkv = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.chunk(3, dim=-1)
|
q, k, v = qkv.chunk(3, dim=-1)
|
||||||
query, key = torch.ops._C.rotary_embedding(
|
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@ -299,7 +299,7 @@ def test_capture_rotary_embedding_in_aclgraph(
|
|||||||
# Validate if the rotary_embedding custom kernel is indeed inside the graph by
|
# Validate if the rotary_embedding custom kernel is indeed inside the graph by
|
||||||
# string match
|
# string match
|
||||||
graph = str(gm.graph)
|
graph = str(gm.graph)
|
||||||
assert "_C.rotary_embedding" in graph
|
assert "_C_ascend.rotary_embedding" in graph
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
static_positions = torch.randint(0, max_position_embeddings,
|
static_positions = torch.randint(0, max_position_embeddings,
|
||||||
|
@@ -72,7 +72,7 @@ def test_get_masked_input_and_mask(
 
    # Get custom op result
    print("input_tensor:", input_tensor)
-    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
+    custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask(
        input_tensor, test_case["org_start"], test_case["org_end"],
        test_case["padding"], test_case["added_start"], test_case["added_end"])
 
@@ -94,7 +94,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
        self.mock_self.cos_sin_cache = self.cos_sin_cache
        self.mock_self.is_neox_style = self.is_neox_style
 
-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
    @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=True)
@@ -104,7 +104,7 @@ class TestRopeForwardOot(TestBase):
        self.assertTrue(torch.equal(result_q, self.query))
        self.assertTrue(torch.equal(result_k, self.key))
 
-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',
@@ -15,7 +15,8 @@ from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
-from vllm.utils import weak_ref_tensors
+
+from ..utils import weak_ref_tensors
 
 
 @dataclasses.dataclass
@@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
-    return torch.ops._C.bgmv_shrink(
+    return torch.ops._C_ascend.bgmv_shrink(
        inputs,
        lora_a_weights,
        lora_indices_tensor,
@@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
-    return torch.ops._C.bgmv_expand(
+    return torch.ops._C_ascend.bgmv_expand(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
@@ -52,9 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
-    return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
-                                    lora_indices_tensor, output_tensor,
-                                    slice_offset, slice_size)
+    return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
+                                           lora_indices_tensor, output_tensor,
+                                           slice_offset, slice_size)
 
 
 def sgmv_shrink(
@@ -69,9 +69,9 @@ def sgmv_shrink(
    token_nums: int,
    scaling: float,
 ):
-    return torch.ops._C.sgmv_shrink(inputs, lora_a_weights,
-                                    lora_indices_tensor, seq_len_tensor,
-                                    output_tensor, scaling)
+    return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
+                                           lora_indices_tensor, seq_len_tensor,
+                                           output_tensor, scaling)
 
 
 def sgmv_expand(inputs: torch.Tensor,
@@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
-    return torch.ops._C.sgmv_expand(
+    return torch.ops._C_ascend.sgmv_expand(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
@@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
-    return torch.ops._C.sgmv_expand(inputs, lora_b_weights,
-                                    lora_indices_tensor, seq_len_tensor,
-                                    output_tensor, slice_offset, slice_size)
+    return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
+                                           lora_indices_tensor, seq_len_tensor,
+                                           output_tensor, slice_offset,
+                                           slice_size)
@@ -23,7 +23,7 @@ from torch.library import Library
 # Do NOT perform any real computation or allocate device memory.
 #
 # 2. Register your meta function using `register_meta_if_necessary`, providing:
-#    - The namespace (usually "_C" for custom ops)
+#    - The namespace (usually "_C_ascend" for custom ops)
 #    - The operator name (as registered in C++)
 #    - The Python meta function
 #    - (Optional) The overload name, if your op has overloads
@@ -39,7 +39,7 @@ from torch.library import Library
 #
 # For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
 
-lib = Library("_C", "IMPL")
+lib = Library("_C_ascend", "IMPL")
 
 
 def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
@@ -97,8 +97,9 @@ def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
     return y_out
 
 
-register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
-register_meta_if_necessary("_C", "get_masked_input_and_mask",
-                           get_masked_input_and_mask_meta)
-register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta)
-register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta)
+register_meta_if_necessary("_C_ascend", "rotary_embedding",
+                           rotary_embedding_meta)
+register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
+                           get_masked_input_and_mask_meta)
+register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
+register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)
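These Python-side registrations mirror the C++ Meta registration above: a shape-only kernel lets symbolic tracing (and hence aclgraph capture) run the op without touching device memory. A sketch of what a single guarded registration amounts to, assuming the `_C_ascend::rotary_embedding` schema has already been defined by the C++ extension (the guard here is illustrative, not the repo's `register_meta_if_necessary` implementation):

```python
import torch
from torch.library import Library

meta_lib = Library("_C_ascend", "IMPL")

def rotary_embedding_meta(positions, query, key, head_size, cos_sin_cache, is_neox):
    # On the Meta backend only shapes and dtypes matter; no NPU work happens.
    return torch.empty_like(query), torch.empty_like(key)

try:
    meta_lib.impl("rotary_embedding", rotary_embedding_meta, "Meta")
except RuntimeError:
    # Either the op is not defined in this process, or a Meta kernel already
    # exists (e.g. from the C++ TORCH_LIBRARY_IMPL registration); skip quietly.
    pass
```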
@@ -35,19 +35,20 @@ class dummyFusionOp:
 
 
 def register_dummy_fusion_op() -> None:
-    torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
-    torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
-    torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm")
+    torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(
+        name="fused_add_rms_norm")
+    torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(
         name="static_scaled_fp8_quant")
-    torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(
         name="dynamic_scaled_fp8_quant")
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
         name="dynamic_per_token_scaled_fp8_quant")
-    torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(
         name="rms_norm_static_fp8_quant")
-    torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
         name="fused_add_rms_norm_static_fp8_quant")
-    torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(
         name="rms_norm_dynamic_per_token_quant")
 
 
@@ -49,7 +49,7 @@ def _rope_forward_oot(
    # adopt custom kernel path for rotary_embedding
    if _custom_rotary_embedding_enabled(query, is_neox_style,
                                        self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
            positions,
            query,
            key,
@@ -62,7 +62,7 @@ def rope_forward_oot(
    # adopt custom kernel path for rotary_embedding
    if custom_rotary_embedding_enabled(query, neox_style,
                                       self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
            positions,
            query,
            key,
@@ -24,7 +24,7 @@ import os
 from contextlib import contextmanager
 from enum import Enum
 from threading import Lock
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
 import torch
 import torch_npu  # noqa: F401 # noqa: F401
@@ -589,3 +589,31 @@ def dense_optim_enable() -> bool:
 def is_moe_model(vllm_config: VllmConfig):
     config = vllm_config.model_config.hf_config
     return any('experts' in key.lower() for key in config.to_dict())
+
+
+def weak_ref_tensor(tensor: Any) -> Any:
+    """
+    Create a weak reference to a tensor.
+    The new tensor will share the same data as the original tensor,
+    but will not keep the original tensor alive.
+    """
+    if isinstance(tensor, torch.Tensor):
+        return torch.ops._C_ascend.weak_ref_tensor(tensor)
+    else:
+        return tensor
+
+
+def weak_ref_tensors(
+    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
+) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
+    """
+    Convenience function to create weak references to tensors,
+    for single tensor, list of tensors or tuple of tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        return weak_ref_tensor(tensors)
+    if isinstance(tensors, list):
+        return [weak_ref_tensor(t) for t in tensors]
+    if isinstance(tensors, tuple):
+        return tuple(weak_ref_tensor(t) for t in tensors)
+    raise ValueError("Invalid type for tensors")
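A brief usage sketch of the vendored helpers (the module path `vllm_ascend.utils` is inferred from the `from ..utils import weak_ref_tensors` change above; running it requires an Ascend device with `vllm_ascend_C` loaded):

```python
import torch
from vllm_ascend.utils import weak_ref_tensors

# Keep replay-time bookkeeping references to static buffers without extending
# their lifetime; lists and tuples are handled element-wise.
static_buffers = [torch.zeros(8, device="npu"), torch.zeros(8, device="npu")]
refs = weak_ref_tensors(static_buffers)
```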