Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
[Feat] Unquantized Linear to nz and control all nz-cast (#3356)
### What this PR does / why we need it?
Currently, when execution reaches the Linear layer of a model in vLLM-Ascend, the weight format is ND in both the unquantized case and the skipped-ascend case. This PR supplements the weight-processing logic of the Linear layer. We introduce a new environment variable, VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and the CANN version is 8.3, the weights of the Linear layer are converted to FRACTAL_NZ, in both the unquantized case and the skipped-ascend case. VLLM_ASCEND_ENABLE_NZ also controls the existing NZ conversions, such as the w8a8-quantized case.

### Does this PR introduce _any_ user-facing change?
Adds a new environment variable, VLLM_ASCEND_ENABLE_NZ. If you want to use the NZ format, set VLLM_ASCEND_ENABLE_NZ=1.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
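For reference, the gating introduced by this PR reduces to the minimal sketch below. It mirrors the new `AscendUnquantizedLinearMethod.process_weights_after_loading` and the `is_enable_nz()` helper from the diff; the standalone function name `maybe_cast_linear_weight_to_nz` and the inlined `ACL_FORMAT_FRACTAL_NZ = 29` constant are illustrative only, and the snippet assumes an Ascend environment with `torch_npu` installed:

```python
import os

import torch
import torch_npu  # requires an Ascend NPU environment

ACL_FORMAT_FRACTAL_NZ = 29  # ACL enum value for the FRACTAL_NZ layout


def is_enable_nz() -> int:
    # Mirrors the new envs_ascend.VLLM_ASCEND_ENABLE_NZ entry (default: 1).
    return int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1))


def maybe_cast_linear_weight_to_nz(layer: torch.nn.Module) -> None:
    # Cast only when NZ is enabled and the CANN toolkit is 8.3.x;
    # otherwise the weight keeps the default ND format.
    if is_enable_nz() and torch.version.cann.startswith("8.3"):
        layer.weight.data = torch_npu.npu_format_cast(
            layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
```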
@@ -376,7 +376,8 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)

def test_process_weights_after_loading(self):
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading(self, mock_format_cast):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
@@ -389,6 +390,7 @@ class TestAscendMLAImpl(TestBase):
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
mock_format_cast.return_value = layer.weight
self.impl.process_weights_after_loading(torch.bfloat16)

self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
@@ -12,7 +12,7 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch
@@ -20,6 +20,7 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

from vllm_ascend import ascend_config
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
CustomDeepseekV2RowParallelLinear)
@@ -46,6 +47,13 @@ def test_row_parallel_linear(cls, mock_distributed):
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
mock_distributed, base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
# Make a fake ascend config because of the AscendLinearBase
vllm_config = MagicMock()
vllm_config.additional_config = None
vllm_config.parallel_config.enable_expert_parallel = False
vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config)

attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
@@ -78,6 +86,7 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
ascend_config._ASCEND_CONFIG = None


def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
@@ -90,6 +99,14 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):

config = SimpleConfig()

# Make a fake ascend config because of the AscendLinearBase
vllm_config = MagicMock()
vllm_config.additional_config = None
vllm_config.parallel_config.enable_expert_parallel = False
vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config)

# Directly create the lmhead and logits_processor
lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
logits_processor = LogitsProcessor(config.vocab_size)
@@ -105,3 +122,4 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
return_value=mock_logits):
logits = logits_processor(lmhead, mock_output)
assert logits.shape == (2, 4, config.vocab_size)
ascend_config._ASCEND_CONFIG = None
@@ -5,10 +5,13 @@ from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
AscendReplicatedLinear,
AscendRowParallelLinear,
AscendUnquantizedLinearMethod)


class BaseLinearTest(unittest.TestCase):
@@ -49,6 +52,47 @@ class BaseLinearTest(unittest.TestCase):
p.stop()


class TestAscendUnquantizedLinearMethod(TestBase):

def setUp(self):
self.method = AscendUnquantizedLinearMethod()

@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_enable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()

mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 1
self.method.process_weights_after_loading(layer)
mock_format_cast.assert_called_once()

@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_disable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()

mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 0
self.method.process_weights_after_loading(layer)
mock_format_cast.assert_not_called()

@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch.version")
def test_process_weights_after_loading_not_8_3(self, mock_version,
mock_is_nz):
layer = mock.MagicMock()

mock_version.cann = "8.2.RC1"
mock_is_nz.return_value = 1
# Should not raise exception
self.method.process_weights_after_loading(layer)


class TestAscendRowParallelLinear(BaseLinearTest):

def test_mlp_optimize(self):
@@ -92,5 +136,24 @@ class TestAscendMergedColumnParallelLinear(BaseLinearTest):
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)


class TestAscendReplicatedLinear(BaseLinearTest):

def test_init_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))

def test_init_without_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))


if __name__ == '__main__':
unittest.main()
@@ -4,10 +4,10 @@ import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import LinearBase

from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
AscendQuantConfig)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -82,7 +82,7 @@ class TestAscendQuantConfig(TestBase):
'is_layer_skipped_ascend',
return_value=True):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIsInstance(method, UnquantizedLinearMethod)
self.assertIsInstance(method, AscendUnquantizedLinearMethod)

# Test quantized layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
@@ -137,8 +137,10 @@ class TestAscendW8A8LinearMethod(TestBase):
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))

@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading(self, mock_npu_format_cast):
def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
mock_is_nz):
layer = MagicMock()

layer.weight.data = torch.randn(128, 256)
@@ -148,6 +150,7 @@ class TestAscendW8A8LinearMethod(TestBase):
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

mock_is_nz.return_value = 0
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)
@@ -160,6 +163,35 @@ class TestAscendW8A8LinearMethod(TestBase):

self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_not_called()

@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
mock_is_nz):
layer = MagicMock()

layer.weight.data = torch.randn(128, 256)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

mock_is_nz.return_value = 1
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)

expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
self.assertTrue(
torch.equal(layer.aclnn_input_offset.data, expected_offset))
self.assertFalse(layer.aclnn_input_offset.requires_grad)

self.assertFalse(layer.deq_scale.requires_grad)

self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_called_once()


class TestAscendW8A8FusedMoEMethod(TestBase):
@@ -39,6 +39,14 @@ class TestUtils(TestBase):
"Ascend910P1"):
self.assertFalse(utils.is_310p())

def test_is_enable_nz(self):
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
1):
self.assertTrue(utils.is_enable_nz())
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
0):
self.assertFalse(utils.is_enable_nz())

def test_sleep_mode_enabled(self):
utils._SLEEP_MODE_ENABLED = None
with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
@@ -96,15 +96,17 @@ class TestTorchairUtils(TestBase):
self.assertEqual(args[0], expected_name)
self.assertEqual(args[1], expected_path)

@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format(self, mock_npu_cast,
mock_get_format):
def test_converting_weight_acl_format_to_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 1

fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
@@ -137,3 +139,26 @@ class TestTorchairUtils(TestBase):

utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()

@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_no_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 0

fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]

utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()
@@ -27,6 +27,8 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_enable_nz)
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
@@ -595,6 +597,10 @@ class AscendMLAImpl(MLAAttentionImpl):
del eye
# standardize to (output, input)
return dequant_weights.T
# Weight will be reshaped next. To be on the safe side, the format
# of the weight should be reverted to FRACTAL_ND.
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_ND)
return layer.weight

# we currently do not have quantized bmm's which are needed for
@@ -623,6 +629,12 @@ class AscendMLAImpl(MLAAttentionImpl):
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()

# Function `get_and_maybe_dequant_weights` will cast the weights to
# FRACTAL_ND. So we need to cast to FRACTAL_NZ again.
if is_enable_nz():
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)

# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
@@ -169,6 +169,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
"VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
"VLLM_ASCEND_ENABLE_NZ":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
}

# end-env-vars-definition
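The hunk above defines the new switch (read lazily through `vllm_ascend.envs`, default `1`). As a usage note, a minimal offline-inference sketch of disabling NZ casting might look like the following; the model name and prompt are placeholders, and an Ascend build of vLLM with vllm-ascend installed is assumed:

```python
import os

# Opt out of NZ casting for this process; the default ("1") keeps it enabled.
os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"

from vllm import LLM  # noqa: E402

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # placeholder model
print(llm.generate(["Hello, Ascend!"]))      # placeholder prompt
```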
@@ -32,13 +32,15 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -57,16 +59,81 @@ from vllm.model_executor.models.deepseek_v2 import (
from vllm.model_executor.models.utils import (PPMissingLayer,
is_pp_missing_parameter,
maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
AscendSparseFlashAttention, Indexer)
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase


class CustomDeepseekV2RowParallelLinear(RowParallelLinear):

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]

AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results

assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")

if bias:
self.bias = nn.Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()

def forward(
self,
input_,
@@ -37,7 +37,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
npu_stream_switch)


class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -83,7 +84,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

if not is_310p():
if not is_310p() and is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
@@ -24,17 +24,29 @@ from typing import Optional, Union

import torch
import torch.nn as nn
import torch_npu
from torch.nn.parameter import Parameter
from vllm.distributed import divide
from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
RowParallelLinear, UnquantizedLinearMethod)
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ops.linear_op import get_parallel_op
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
"""Linear method without quantization"""

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if is_enable_nz() and torch.version.cann.startswith("8.3"):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)


# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@@ -65,7 +77,7 @@ class AscendLinearBase(LinearBase):
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
@@ -364,3 +376,81 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
return self.custom_op.apply(input_)

return super().forward(input_)


class AscendReplicatedLinear(ReplicatedLinear):
"""Ascend Replicated linear layer.

Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Take no effect for replicated linear layers.
"""

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op = get_replicated_op(disable_tp, prefix, self)
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]

AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)

# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)

if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)

if self.custom_op is not None:
self.custom_op.update_attrs()

def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.custom_op is not None:
return self.custom_op.apply(input_)

return super().forward(input_)
@@ -17,16 +17,16 @@ This file extends the functionality of linear operations by encapsulating custom
communication groups and forward functions into classes (linear ops).

Current class inheritance structure:
CustomTensorParallelOp
CustomLinearOp
├── CustomColumnParallelOp
│ ├── MLPColumnParallelOp
│ ├── SequenceColumnParallelOp
└── CustomRowParallelOp
├── MLPRowParallelOp
├── OProjRowParallelOp
├── MatmulAllreduceRowParallelOp
└── SequenceRowParallelOp

│ ├── MLPRowParallelOp
│ ├── OProjRowParallelOp
│ ├── MatmulAllreduceRowParallelOp
│ └── SequenceRowParallelOp
└── CustomReplicatedOp

How to extend a new linear op? Taking column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
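As a hedged illustration of steps 1 and 2 in the docstring above, a new op could subclass `CustomColumnParallelOp` and point `comm_group` at a custom communication group. `get_mlp_tp_group` is just one example of such a group; whether `comm_group` is an attribute, property, or method on the real base class is not shown in this hunk, so the property form below is an assumption:

```python
# Illustrative sketch only; it assumes CustomColumnParallelOp exposes an
# overridable `comm_group` and that get_mlp_tp_group() returns the desired
# custom communication group.
from vllm_ascend.distributed.parallel_state import get_mlp_tp_group
from vllm_ascend.ops.linear_op import CustomColumnParallelOp


class MyColumnParallelOp(CustomColumnParallelOp):

    @property
    def comm_group(self):
        # Step 2: replace the default TP group with a custom one.
        return get_mlp_tp_group()
```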
@@ -52,7 +52,7 @@ from vllm_ascend.utils import (dense_optim_enable, enable_sp,
oproj_tp_enable)


class CustomTensorParallelOp:
class CustomLinearOp:

def __init__(self, layer):
self.layer = layer
@@ -95,7 +95,7 @@ class CustomTensorParallelOp:
return output, output_bias


class CustomColumnParallelOp(CustomTensorParallelOp):
class CustomColumnParallelOp(CustomLinearOp):

def __init__(self, layer):
super().__init__(layer)
@@ -106,7 +106,7 @@ class CustomColumnParallelOp(CustomTensorParallelOp):
self.gather_output = self.layer.gather_output


class CustomRowParallelOp(CustomTensorParallelOp):
class CustomRowParallelOp(CustomLinearOp):

def __init__(self, layer):
super().__init__(layer)
@@ -129,6 +129,18 @@ class CustomRowParallelOp(CustomTensorParallelOp):
return output, output_bias


class CustomReplicatedOp(CustomLinearOp):

def apply_impl(self, input_):
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None

output = self.quant_method.apply(self.layer, input_, bias)
output_bias = self.bias if self.skip_bias_add else None

return output, output_bias


class MLPColumnParallelOp(CustomColumnParallelOp):

def __init__(self, layer):
@@ -422,3 +434,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
return custom_op, custom_op.tp_rank, custom_op.tp_size

return None, get_tp_group().rank_in_group, get_tp_group().world_size


def get_replicated_op(disable_tp, prefix,
layer) -> Optional[Union[CustomReplicatedOp]]:
if disable_tp:
return None

return CustomReplicatedOp(layer)
@@ -24,8 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear,
UnquantizedLinearMethod)
RowParallelLinear)
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
@@ -39,6 +38,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable)
@@ -101,7 +101,7 @@ class AscendQuantConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedLinearMethod()
return AscendUnquantizedLinearMethod()
return AscendLinearMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, Attention) and \
@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


class AscendW4A8DynamicLinearMethod:
@@ -393,9 +393,10 @@ class AscendW4A8DynamicFusedMoEMethod:

self.update_bias(layer, w13_bias, w2_bias)

layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz


def quant_per_tensor(in_tensor: torch.Tensor,
@@ -156,8 +156,9 @@ class AscendW8A8LinearMethod:
requires_grad=False).to(layer.aclnn_input_scale.dtype)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
@@ -340,7 +341,7 @@ class AscendW8A8FusedMoEMethod:
# converting ACL_FORMAT_FRACTAL_NZ.
# npu_quant_grouped_matmul_dequant in eager mode does not accept
# ACL_FORMAT_FRACTAL_NZ.
if not is_310p():
if not is_310p() and is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
layer.w2_weight.data = torch_npu.npu_format_cast(
@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


class AscendW8A8DynamicLinearMethod:
@@ -101,8 +101,9 @@ class AscendW8A8DynamicLinearMethod:
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
# cast quantized weight tensors in NZ format for higher inference speed
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -267,8 +268,9 @@ class AscendW8A8DynamicFusedMoEMethod:
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
@@ -29,6 +29,7 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version,
is_enable_nz,
is_hierarchical_communication_enabled)
@@ -829,7 +830,9 @@ class TorchairAscendW8A8DynamicLinearMethod:
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
# cast quantized weight tensors in NZ format (29) for higher inference speed
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, 29)
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -1048,7 +1051,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
@@ -24,6 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import is_enable_nz
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
@@ -841,7 +842,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
if is_enable_nz():
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
@@ -874,7 +876,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
if is_enable_nz():
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

qb_deq_scl = self.q_proj.deq_scale.data.clone()
qb_deq_scl = qb_deq_scl.reshape(
@@ -14,6 +14,7 @@ try:
except ImportError:
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
@@ -141,6 +142,9 @@ def converting_weight_acl_format(model, format):
if isinstance(module, FusedMoE):
if torch_npu.get_npu_format(module.w13_weight.data) == format:
return
if format == ACL_FORMAT_FRACTAL_NZ \
and not is_enable_nz():
return
module.w13_weight.data = torch_npu.npu_format_cast(
module.w13_weight.data, format)
module.w2_weight.data = torch_npu.npu_format_cast(
@@ -65,6 +65,10 @@ def is_310p():
return _IS_310P


def is_enable_nz():
return envs_ascend.VLLM_ASCEND_ENABLE_NZ


def sleep_mode_enabled():
global _SLEEP_MODE_ENABLED
if _SLEEP_MODE_ENABLED is None:
@@ -508,6 +512,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear)
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
@@ -526,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
"MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
"QKVParallelLinear": AscendQKVParallelLinear,
"ReplicatedLinear": AscendReplicatedLinear,
"DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,
"VocabParallelEmbedding": AscendVocabParallelEmbedding,
"ParallelLMHead": AscendParallelLMHead,
@@ -97,6 +97,7 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
@@ -125,7 +126,7 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p,
get_ascend_soc_version, is_310p, is_enable_nz,
lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -137,8 +138,6 @@ else:

import torch_npu

import vllm_ascend.envs as envs_ascend

# if true, allow tensor initialization and casting with internal format (e.g., NZ)
torch.npu.config.allow_internal_format = True
@@ -2609,6 +2608,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
runtime_mode=CUDAGraphMode.FULL)

def _convert_torch_format(self, tensor):
if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
and not is_enable_nz():
return tensor
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor