[Feat] Unquantized Linear to nz and control all nz-cast (#3356)

### What this PR does / why we need it?
Currently, when executing to the Linear layer of models in vLLM-Ascend,
the weights format is ND in unquantized case and skipped ascend case.
This PR supplements the execution logic for Linear layer. We use a new
global variable: VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and
CANN version is 8.3, the weights of the Linear layer will be converted
to FRACTAL_NZ, in both unquantized case and skipped ascend case. We also
use VLLM_ASCEND_ENABLE_NZ to control the existing NZ conversion, such as
w8a8-quantized case.

### Does this PR introduce _any_ user-facing change?
Add a new global variable VLLM_ASCEND_ENABLE_NZ. If you want to use NZ
format, you should set VLLM_ASCEND_ENABLE_NZ=1.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-10-14 17:39:26 +08:00
committed by GitHub
parent 5c45c227dc
commit 07e39620ea
22 changed files with 413 additions and 49 deletions

View File

@ -376,7 +376,8 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
def test_process_weights_after_loading(self):
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading(self, mock_format_cast):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
@ -389,6 +390,7 @@ class TestAscendMLAImpl(TestBase):
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
mock_format_cast.return_value = layer.weight
self.impl.process_weights_after_loading(torch.bfloat16)
self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)

View File

@ -12,7 +12,7 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
@ -20,6 +20,7 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm_ascend import ascend_config
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
CustomDeepseekV2RowParallelLinear)
@ -46,6 +47,13 @@ def test_row_parallel_linear(cls, mock_distributed):
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
mock_distributed, base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
# Make a fake ascend config because of the AscendLinearBase
vllm_config = MagicMock()
vllm_config.additional_config = None
vllm_config.parallel_config.enable_expert_parallel = False
vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config)
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
@ -78,6 +86,7 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
ascend_config._ASCEND_CONFIG = None
def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
@ -90,6 +99,14 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
config = SimpleConfig()
# Make a fake ascend config because of the AscendLinearBase
vllm_config = MagicMock()
vllm_config.additional_config = None
vllm_config.parallel_config.enable_expert_parallel = False
vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config)
# 直接创建lmhead和logits_processor
lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
logits_processor = LogitsProcessor(config.vocab_size)
@ -105,3 +122,4 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
return_value=mock_logits):
logits = logits_processor(lmhead, mock_output)
assert logits.shape == (2, 4, config.vocab_size)
ascend_config._ASCEND_CONFIG = None

View File

@ -5,10 +5,13 @@ from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
AscendReplicatedLinear,
AscendRowParallelLinear,
AscendUnquantizedLinearMethod)
class BaseLinearTest(unittest.TestCase):
@ -49,6 +52,47 @@ class BaseLinearTest(unittest.TestCase):
p.stop()
class TestAscendUnquantizedLinearMethod(TestBase):
def setUp(self):
self.method = AscendUnquantizedLinearMethod()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_enable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 1
self.method.process_weights_after_loading(layer)
mock_format_cast.assert_called_once()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
@mock.patch("torch.version")
def test_process_weights_after_loading_is_8_3_disable_nz(
self, mock_version, mock_format_cast, mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.3.RC1"
mock_is_nz.return_value = 0
self.method.process_weights_after_loading(layer)
mock_format_cast.assert_not_called()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch.version")
def test_process_weights_after_loading_not_8_3(self, mock_version,
mock_is_nz):
layer = mock.MagicMock()
mock_version.cann = "8.2.RC1"
mock_is_nz.return_value = 1
# Should not raise exception
self.method.process_weights_after_loading(layer)
class TestAscendRowParallelLinear(BaseLinearTest):
def test_mlp_optimize(self):
@ -92,5 +136,24 @@ class TestAscendMergedColumnParallelLinear(BaseLinearTest):
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
class TestAscendReplicatedLinear(BaseLinearTest):
def test_init_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
def test_init_without_disable_tp(self):
linear = AscendReplicatedLinear(
input_size=16,
output_size=8,
)
self.assertTrue(
isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
if __name__ == '__main__':
unittest.main()

View File

@ -4,10 +4,10 @@ import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
AscendQuantConfig)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@ -82,7 +82,7 @@ class TestAscendQuantConfig(TestBase):
'is_layer_skipped_ascend',
return_value=True):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIsInstance(method, UnquantizedLinearMethod)
self.assertIsInstance(method, AscendUnquantizedLinearMethod)
# Test quantized layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \

View File

@ -137,8 +137,10 @@ class TestAscendW8A8LinearMethod(TestBase):
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading(self, mock_npu_format_cast):
def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
mock_is_nz):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
@ -148,6 +150,7 @@ class TestAscendW8A8LinearMethod(TestBase):
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)
mock_is_nz.return_value = 0
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)
@ -160,6 +163,35 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_not_called()
@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
mock_is_nz):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)
mock_is_nz.return_value = 1
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)
expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
self.assertTrue(
torch.equal(layer.aclnn_input_offset.data, expected_offset))
self.assertFalse(layer.aclnn_input_offset.requires_grad)
self.assertFalse(layer.deq_scale.requires_grad)
self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_called_once()
class TestAscendW8A8FusedMoEMethod(TestBase):

View File

@ -39,6 +39,14 @@ class TestUtils(TestBase):
"Ascend910P1"):
self.assertFalse(utils.is_310p())
def test_is_enable_nz(self):
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
1):
self.assertTrue(utils.is_enable_nz())
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
0):
self.assertFalse(utils.is_enable_nz())
def test_sleep_mode_enabled(self):
utils._SLEEP_MODE_ENABLED = None
with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",

View File

@ -96,15 +96,17 @@ class TestTorchairUtils(TestBase):
self.assertEqual(args[0], expected_name)
self.assertEqual(args[1], expected_path)
@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format(self, mock_npu_cast,
mock_get_format):
def test_converting_weight_acl_format_to_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
@ -137,3 +139,26 @@ class TestTorchairUtils(TestBase):
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()
@mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_no_nz(self, mock_npu_cast,
mock_get_format, mock_is_nz):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
mock_is_nz.return_value = 0
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()

View File

@ -27,6 +27,8 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_enable_nz)
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@ -595,6 +597,10 @@ class AscendMLAImpl(MLAAttentionImpl):
del eye
# standardize to (output, input)
return dequant_weights.T
# Weight will be reshaped next. To be on the safe side, the format
# of the weight should be reverted to FRACTAL_AND.
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_ND)
return layer.weight
# we currently do not have quantized bmm's which are needed for
@ -623,6 +629,12 @@ class AscendMLAImpl(MLAAttentionImpl):
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
# Function `get_and_maybe_dequant_weights` will cast the weights to
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
if is_enable_nz():
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)

View File

@ -169,6 +169,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
"VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
"VLLM_ASCEND_ENABLE_NZ":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
}
# end-env-vars-definition

View File

@ -32,13 +32,15 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -57,16 +59,81 @@ from vllm.model_executor.models.deepseek_v2 import (
from vllm.model_executor.models.utils import (PPMissingLayer,
is_pp_missing_parameter,
maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
AscendSparseFlashAttention, Indexer)
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = nn.Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def forward(
self,
input_,

View File

@ -37,7 +37,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
npu_stream_switch)
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@ -83,7 +84,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
if not is_310p():
if not is_310p() and is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(

View File

@ -24,17 +24,29 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import torch_npu
from torch.nn.parameter import Parameter
from vllm.distributed import divide
from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
RowParallelLinear, UnquantizedLinearMethod)
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ops.linear_op import get_parallel_op
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
"""Linear method without quantization"""
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if is_enable_nz() and torch.version.cann.startswith("8.3"):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@ -65,7 +77,7 @@ class AscendLinearBase(LinearBase):
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
@ -364,3 +376,81 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
return self.custom_op.apply(input_)
return super().forward(input_)
class AscendReplicatedLinear(ReplicatedLinear):
"""Ascend Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Take no effect for replicated linear layers.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op = get_replicated_op(disable_tp, prefix, self)
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
if self.custom_op is not None:
self.custom_op.update_attrs()
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.custom_op is not None:
return self.custom_op.apply(input_)
return super().forward(input_)

View File

@ -17,16 +17,16 @@ This file extends the functionality of linear operations by encapsulating custom
communication groups and forward functions into classes (linear ops).
Current class inheritance structure:
CustomTensorParallelOp
CustomLinearOp
├── CustomColumnParallelOp
│ ├── MLPColumnParallelOp
│ ├── SequenceColumnParallelOp
└── CustomRowParallelOp
├── MLPRowParallelOp
├── OProjRowParallelOp
├── MatmulAllreduceRowParallelOp
└── SequenceRowParallelOp
├── MLPRowParallelOp
├── OProjRowParallelOp
├── MatmulAllreduceRowParallelOp
└── SequenceRowParallelOp
└── CustomReplicatedOp
How to extend a new linear op? Taking column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
@ -52,7 +52,7 @@ from vllm_ascend.utils import (dense_optim_enable, enable_sp,
oproj_tp_enable)
class CustomTensorParallelOp:
class CustomLinearOp:
def __init__(self, layer):
self.layer = layer
@ -95,7 +95,7 @@ class CustomTensorParallelOp:
return output, output_bias
class CustomColumnParallelOp(CustomTensorParallelOp):
class CustomColumnParallelOp(CustomLinearOp):
def __init__(self, layer):
super().__init__(layer)
@ -106,7 +106,7 @@ class CustomColumnParallelOp(CustomTensorParallelOp):
self.gather_output = self.layer.gather_output
class CustomRowParallelOp(CustomTensorParallelOp):
class CustomRowParallelOp(CustomLinearOp):
def __init__(self, layer):
super().__init__(layer)
@ -129,6 +129,18 @@ class CustomRowParallelOp(CustomTensorParallelOp):
return output, output_bias
class CustomReplicatedOp(CustomLinearOp):
def apply_impl(self, input_):
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self.layer, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class MLPColumnParallelOp(CustomColumnParallelOp):
def __init__(self, layer):
@ -422,3 +434,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
return custom_op, custom_op.tp_rank, custom_op.tp_size
return None, get_tp_group().rank_in_group, get_tp_group().world_size
def get_replicated_op(disable_tp, prefix,
layer) -> Optional[Union[CustomReplicatedOp]]:
if disable_tp:
return None
return CustomReplicatedOp(layer)

View File

@ -24,8 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear,
UnquantizedLinearMethod)
RowParallelLinear)
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
@ -39,6 +38,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable)
@ -101,7 +101,7 @@ class AscendQuantConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedLinearMethod()
return AscendUnquantizedLinearMethod()
return AscendLinearMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, Attention) and \

View File

@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
class AscendW4A8DynamicLinearMethod:
@ -393,9 +393,10 @@ class AscendW4A8DynamicFusedMoEMethod:
self.update_bias(layer, w13_bias, w2_bias)
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)

View File

@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
def quant_per_tensor(in_tensor: torch.Tensor,
@ -156,8 +156,9 @@ class AscendW8A8LinearMethod:
requires_grad=False).to(layer.aclnn_input_scale.dtype)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
@ -340,7 +341,7 @@ class AscendW8A8FusedMoEMethod:
# converting ACL_FORMAT_FRACTAL_NZ.
# npu_quant_grouped_matmul_dequant in eager mode does not accept
# ACL_FORMAT_FRACTAL_NZ.
if not is_310p():
if not is_310p() and is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
layer.w2_weight.data = torch_npu.npu_format_cast(

View File

@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
class AscendW8A8DynamicLinearMethod:
@ -101,8 +101,9 @@ class AscendW8A8DynamicLinearMethod:
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
# cast quantized weight tensors in NZ format for higher inference speed
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
@ -267,8 +268,9 @@ class AscendW8A8DynamicFusedMoEMethod:
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(

View File

@ -29,6 +29,7 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version,
is_enable_nz,
is_hierarchical_communication_enabled)
@ -829,7 +830,9 @@ class TorchairAscendW8A8DynamicLinearMethod:
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
# cast quantized weight tensors in NZ format (29) for higher inference speed
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, 29)
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
@ -1048,7 +1051,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
if is_enable_nz():
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(

View File

@ -24,6 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import is_enable_nz
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@ -841,7 +842,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
if is_enable_nz():
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
@ -874,7 +876,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
if is_enable_nz():
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data.clone()
qb_deq_scl = qb_deq_scl.reshape(

View File

@ -14,6 +14,7 @@ try:
except ImportError:
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
@ -141,6 +142,9 @@ def converting_weight_acl_format(model, format):
if isinstance(module, FusedMoE):
if torch_npu.get_npu_format(module.w13_weight.data) == format:
return
if format == ACL_FORMAT_FRACTAL_NZ \
and not is_enable_nz():
return
module.w13_weight.data = torch_npu.npu_format_cast(
module.w13_weight.data, format)
module.w2_weight.data = torch_npu.npu_format_cast(

View File

@ -65,6 +65,10 @@ def is_310p():
return _IS_310P
def is_enable_nz():
return envs_ascend.VLLM_ASCEND_ENABLE_NZ
def sleep_mode_enabled():
global _SLEEP_MODE_ENABLED
if _SLEEP_MODE_ENABLED is None:
@ -508,6 +512,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear)
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
@ -526,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
"MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
"QKVParallelLinear": AscendQKVParallelLinear,
"ReplicatedLinear": AscendReplicatedLinear,
"DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,
"VocabParallelEmbedding": AscendVocabParallelEmbedding,
"ParallelLMHead": AscendParallelLMHead,

View File

@ -97,6 +97,7 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
@ -125,7 +126,7 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p,
get_ascend_soc_version, is_310p, is_enable_nz,
lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@ -137,8 +138,6 @@ else:
import torch_npu
import vllm_ascend.envs as envs_ascend
# if true, allow tensor initialization and casting with internal format (e.g., NZ)
torch.npu.config.allow_internal_format = True
@ -2609,6 +2608,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
runtime_mode=CUDAGraphMode.FULL)
def _convert_torch_format(self, tensor):
if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
and not is_enable_nz():
return tensor
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor