Files
vllm-ascend/tests/ut/quantization/test_quant_config.py
anon189Ty 07e39620ea [Feat] Unquantized Linear to nz and control all nz-cast (#3356)
### What this PR does / why we need it?
Currently, when vLLM-Ascend executes the Linear layers of a model, the
weights stay in ND format in both the unquantized case and the
skipped-ascend case. This PR supplements the execution logic for the
Linear layer. We introduce a new global variable, VLLM_ASCEND_ENABLE_NZ.
When VLLM_ASCEND_ENABLE_NZ=1 and the CANN version is 8.3, the weights of
the Linear layer are converted to FRACTAL_NZ in both the unquantized case
and the skipped-ascend case. VLLM_ASCEND_ENABLE_NZ also controls the
existing NZ conversions, such as the w8a8-quantized case.

### Does this PR introduce _any_ user-facing change?
This PR adds a new global variable, VLLM_ASCEND_ENABLE_NZ. To use the NZ
format, set VLLM_ASCEND_ENABLE_NZ=1.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
2025-10-14 17:39:26 +08:00

234 lines
10 KiB
Python

from unittest.mock import MagicMock, patch
import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
AscendQuantConfig)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
class TestAscendQuantConfig(TestBase):
    """Unit tests for ``AscendQuantConfig``."""

    def setUp(self):
        """Build an ``AscendQuantConfig`` from a small quant description."""
        self.sample_config = {
            "weight": "INT8",
            "fa_quant_type": "C8",
            "kv_quant_type": "C8",
            "layer1.weight": "INT8",
            "layer2.weight": "FLOAT",
            "fused_layer.weight": "FLOAT",
            "fused_layer.shard1.weight": "FLOAT",
            "fused_layer.shard2.weight": "FLOAT",
            "shard1.weight": "FLOAT",
            "shard2.weight": "FLOAT",
        }
        self.ascend_config = AscendQuantConfig(self.sample_config)
        self.ascend_config.packed_modules_mapping = None

    def test_init(self):
        """The constructor stores the description it was given."""
        self.assertEqual(self.ascend_config.quant_description,
                         self.sample_config)

    def test_repr(self):
        """``repr`` starts with the ``AscendQuantConfig:`` header line."""
        repr_str = repr(self.ascend_config)
        self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))

    def test_get_name(self):
        """``get_name`` returns the Ascend quantization method constant."""
        self.assertEqual(AscendQuantConfig.get_name(),
                         ASCEND_QUANTIZATION_METHOD)

    def test_get_supported_act_dtypes(self):
        """Three activation dtypes are reported as supported."""
        supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
        self.assertEqual(len(supported_dtypes), 3)

    def test_get_min_capability(self):
        """``get_min_capability`` raises ``NotImplementedError``."""
        with self.assertRaises(NotImplementedError):
            AscendQuantConfig.get_min_capability()

    def test_get_config_filenames(self):
        """Only ``quant_model_description.json`` is listed."""
        filenames = AscendQuantConfig.get_config_filenames()
        self.assertEqual(filenames, ["quant_model_description.json"])

    def test_from_config(self):
        """``from_config`` builds an instance carrying the same description."""
        config = AscendQuantConfig.from_config(self.sample_config)
        self.assertIsInstance(config, AscendQuantConfig)
        self.assertEqual(config.quant_description, self.sample_config)

    @patch('torch.npu.is_available')
    def test_override_quantization_method(self, mock_is_available):
        """The override depends on the patched ``torch.npu.is_available``."""
        # Test when NPU is available
        mock_is_available.return_value = True
        result = AscendQuantConfig.override_quantization_method(None, None)
        self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)

        # Test when NPU is not available
        mock_is_available.return_value = False
        result = AscendQuantConfig.override_quantization_method(None, None)
        self.assertIsNone(result)

    def test_get_quant_method_for_linear(self):
        """Linear layers get the unquantized or the Ascend linear method."""
        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None
        linear_layer = MagicMock(spec=LinearBase)

        # Test skipped layer
        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch.object(self.ascend_config,
                             'is_layer_skipped_ascend',
                             return_value=True):
            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIsInstance(method, AscendUnquantizedLinearMethod)

        # Test quantized layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIs(method, mock_ascend_linear.return_value)
            mock_ascend_linear.assert_called_once_with(
                self.ascend_config, ".attn",
                self.ascend_config.packed_modules_mapping)

    def test_get_quant_method_for_attention(self):
        """Attention layers get the (patched) ``AscendKVCacheMethod``."""
        attention_layer = MagicMock(spec=Attention)
        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None

        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
                      return_value=MagicMock()) as mock_ascend_kvcache:
            # Test with fa_quant_type
            method = self.ascend_config.get_quant_method(
                attention_layer, ".attn")
            self.assertIs(method, mock_ascend_kvcache.return_value)

        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
                      return_value=MagicMock()) as mock_ascend_kvcache:
            # Test with kv_quant_type
            modified_config = {"kv_quant_type": "C8"}
            config = AscendQuantConfig(modified_config)
            config.packed_modules_mapping = None
            method = config.get_quant_method(attention_layer, "attn")
            self.assertIs(method, mock_ascend_kvcache.return_value)

    def test_get_quant_method_for_fused_moe(self):
        """FusedMoE layers get the unquantized or the Ascend MoE method."""
        fused_moe_layer = MagicMock(spec=FusedMoE)
        fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
        fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None

        # Test skipped layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
            method = self.ascend_config.get_quant_method(
                fused_moe_layer, "moe_layer")
            self.assertIs(method, mock_ascend_moe.return_value)

        # Test quantized layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
            method = self.ascend_config.get_quant_method(
                fused_moe_layer, "moe_layer")
            self.assertIs(method, mock_ascend_moe.return_value)

    def test_is_layer_skipped_ascend(self):
        """Check plain layers, a fused layer, and mismatched fused shards."""
        # Test non-fused layer that should be quantized
        self.assertFalse(self.ascend_config.is_layer_skipped_ascend("layer1"))

        # Test non-fused layer that should be skipped
        self.assertTrue(self.ascend_config.is_layer_skipped_ascend("layer2"))

        # Test fused layer
        fused_mapping = {"fused_layer": ["shard1", "shard2"]}
        self.assertTrue(
            self.ascend_config.is_layer_skipped_ascend("fused_layer",
                                                       fused_mapping))

        # Test inconsistent fused layer shards
        bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
        config = AscendQuantConfig(bad_config)
        with self.assertRaises(ValueError):
            config.is_layer_skipped_ascend("fused_layer", fused_mapping)

    def test_get_scaled_act_names(self):
        """``get_scaled_act_names`` returns an empty list."""
        self.assertEqual(self.ascend_config.get_scaled_act_names(), [])
class TestAscendKVCacheMethod(TestBase):
    """Unit tests for ``AscendKVCacheMethod`` delegation to its quant method."""

    def setUp(self):
        """Create an ``AscendKVCacheMethod`` with ``get_quant_method`` patched."""
        # Setup common test fixtures
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
        self.prefix = "layer.attn"

        # Mock quant_method
        self.mock_quant_method = MagicMock()
        self.patcher = patch(
            'vllm_ascend.quantization.quant_config.get_quant_method')
        self.mock_get_quant_method = self.patcher.start()
        # Register the undo immediately: cleanups run even when the rest of
        # setUp raises, whereas tearDown would be skipped and the patch
        # would leak into later tests.
        self.addCleanup(self.patcher.stop)
        self.mock_get_quant_method.return_value = self.mock_quant_method

        # Create instance
        self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
                                                   self.prefix)

    def test_create_weights(self):
        """Test create_weights delegates to quant_method."""
        mock_layer = MagicMock()
        self.kv_cache_method.create_weights(mock_layer)
        self.mock_quant_method.create_weights.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_with_method(self):
        """Test process_weights when quant_method has the method."""
        mock_layer = MagicMock()
        self.kv_cache_method.process_weights_after_loading(mock_layer)
        self.mock_quant_method.process_weights_after_loading.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_without_method(self):
        """Test process_weights when quant_method lacks the method."""
        # Deleting the attribute makes the MagicMock raise AttributeError
        # for it, so the quant method appears not to implement the hook.
        del self.mock_quant_method.process_weights_after_loading
        mock_layer = MagicMock()
        # Should not raise exception
        self.kv_cache_method.process_weights_after_loading(mock_layer)

    def test_apply_delegation(self):
        """Test apply properly delegates to quant_method."""
        mock_layer = MagicMock()
        mock_query = torch.randn(1, 32, 128)
        mock_key = torch.randn(1, 32, 128)
        mock_value = torch.randn(1, 32, 128)
        mock_kv_cache = MagicMock()
        mock_attn_metadata = MagicMock()
        mock_scale = 1.0
        mock_output = torch.zeros(1, 32, 128)
        mock_attn_type = MagicMock()
        expected_result = torch.randn(1, 32, 128)
        self.mock_quant_method.apply.return_value = expected_result

        result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key,
                                            mock_value, mock_kv_cache,
                                            mock_attn_metadata, mock_attn_type,
                                            mock_scale, mock_output)

        self.mock_quant_method.apply.assert_called_once_with(
            mock_layer, mock_query, mock_key, mock_value, mock_kv_cache,
            mock_attn_metadata, mock_attn_type, mock_scale, mock_output)
        self.assertTrue(torch.equal(result, expected_result))