mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 21:53:54 +08:00
Cleans up the initial implementation of `w4a4_flatquant_dynamic` for better readability and maintainability. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
247 lines
12 KiB
Python
247 lines
12 KiB
Python
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm_ascend.quantization.w4a4_flatquant_dynamic import (
|
|
AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim,
|
|
pack_int4_weights)
|
|
|
|
|
|
class TestW4A4FlatQuantDynamic(unittest.TestCase):
    """
    Unit test suite for AscendW4A4FlatQuantDynamicLinearMethod and its helper functions.
    """

    def setUp(self):
        """Set up the test environment before each test."""
        self.method = AscendW4A4FlatQuantDynamicLinearMethod()
        self.output_size = 64
        self.input_size = 768  # 768 = 24 * 32, divisible by 8
        self.params_dtype = torch.float16

    ## Test Helper Functions
    ## --------------------

    def test_get_decompose_dim(self):
        """
        Tests the get_decompose_dim function with various inputs.
        """
        self.assertEqual(get_decompose_dim(1024), (32, 32))
        self.assertEqual(get_decompose_dim(768), (24, 32))
        self.assertEqual(get_decompose_dim(100), (10, 10))
        self.assertEqual(get_decompose_dim(99), (9, 11))

    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
    def test_pack_int4_weights_npu_success(self, mock_torch_npu):
        """
        Tests weight packing using the mocked NPU kernel.
        """
        weight_tensor = torch.randn(self.output_size, self.input_size)
        mock_packed_tensor = torch.randint(
            0,
            100, (self.output_size, self.input_size // 8),
            dtype=torch.int32)
        # pack_int4_weights is expected to call `.to(...)` on the kernel
        # output; have that return the tensor we compare against.
        mock_npu_tensor = MagicMock()
        mock_npu_tensor.to.return_value = mock_packed_tensor
        mock_torch_npu.npu_convert_weight_to_int4pack.return_value = mock_npu_tensor
        # Stub the device transfer so no NPU is required.
        with patch('torch.Tensor.npu', return_value=weight_tensor):
            result = pack_int4_weights(weight_tensor)

        mock_torch_npu.npu_convert_weight_to_int4pack.assert_called_once()
        self.assertTrue(torch.equal(result, mock_packed_tensor))

    ## Test AscendW4A4FlatQuantDynamicLinearMethod Class
    ## --------------------------------------------------

    def test_get_weight(self):
        """Tests the get_weight static method for correct output."""
        params = self.method.get_weight(self.input_size, self.output_size,
                                        self.params_dtype)
        self.assertIn("weight", params)
        self.assertEqual(params["weight"].shape,
                         (self.output_size, self.input_size))
        self.assertEqual(params["weight"].dtype, torch.int8)
        # get_weight records input_size on the class itself.
        self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.input_size,
                         self.input_size)

    def test_get_weight_value_error(self):
        """Tests that get_weight raises ValueError for invalid input_size."""
        with self.assertRaisesRegex(ValueError, "must be divisible by 8"):
            self.method.get_weight(127, self.output_size, self.params_dtype)

    def test_get_pertensor_param(self):
        """Tests the get_pertensor_param static method."""
        # get_weight is called first so the class-level input_size is set.
        self.method.get_weight(self.input_size, self.output_size,
                               self.params_dtype)
        params = self.method.get_pertensor_param(self.params_dtype)
        left_dim, right_dim = get_decompose_dim(self.input_size)
        self.assertIn("left_trans", params)
        self.assertIn("right_trans", params)
        self.assertIn("clip_ratio", params)
        self.assertEqual(params["left_trans"].shape, (left_dim, left_dim))
        self.assertEqual(params["right_trans"].shape, (right_dim, right_dim))
        self.assertEqual(params["clip_ratio"].shape, (1, ))
        self.assertEqual(params["left_trans"].dtype, self.params_dtype)
        self.assertEqual(params["clip_ratio"].dtype, torch.float32)

    def test_get_perchannel_param(self):
        """Tests the get_perchannel_param static method."""
        params = self.method.get_perchannel_param(self.output_size,
                                                  self.params_dtype)
        self.assertIn("weight_scale", params)
        self.assertIn("weight_offset", params)
        self.assertEqual(params["weight_scale"].shape, (self.output_size, 1))
        self.assertEqual(params["weight_offset"].shape, (self.output_size, 1))
        self.assertEqual(params["weight_scale"].dtype, torch.float32)
        self.assertEqual(params["weight_offset"].dtype, torch.float32)

    def test_get_pergroup_param(self):
        """Tests the get_pergroup_param method."""
        params = self.method.get_pergroup_param(self.input_size,
                                                self.output_size,
                                                self.params_dtype)
        self.assertEqual(params, {})

    def _prepare_apply_mocks_and_layer(self, batch_size):
        """Helper to create a mock layer and input tensor for apply tests.

        Returns ``(layer, x, m, n)`` where ``(m, n)`` is
        ``get_decompose_dim(self.input_size)``.
        """
        layer = nn.Module()
        m, n = get_decompose_dim(self.input_size)
        layer.left_trans = torch.randn(m, m, dtype=self.params_dtype)
        layer.right_trans = torch.randn(n, n, dtype=self.params_dtype)
        layer.aclnn_clip_ratio = 0.95
        # Same shape as the pack_int4_weights output used elsewhere:
        # (output_size, input_size // 8) int32.
        layer.weight_packed = torch.randint(
            -8, 7, (self.output_size, self.input_size // 8), dtype=torch.int32)
        layer.weight_scale = torch.randn(self.output_size,
                                         1,
                                         dtype=torch.float32)
        x = torch.randn(batch_size, self.input_size, dtype=self.params_dtype)
        return layer, x, m, n

    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
    def test_apply_small_batch(self, mock_torch_npu):
        """Tests the apply method with a batch size smaller than MAX_BATCH_SIZE.

        Assumes 128 is below the module's default
        KRONECKER_QUANT_MAX_BATCH_SIZE, so a single quant call is expected.
        """
        batch_size = 128
        layer, x, m, n = self._prepare_apply_mocks_and_layer(batch_size)
        mock_quant_x = torch.randint(0,
                                     255, (batch_size, self.input_size // 8),
                                     dtype=torch.int32)
        mock_act_scale = torch.randn(batch_size, 1, dtype=torch.float32)
        mock_torch_npu.npu_kronecker_quant.return_value = (mock_quant_x.view(
            batch_size, m, n // 8), mock_act_scale)
        mock_output = torch.randn(batch_size,
                                  self.output_size,
                                  dtype=self.params_dtype)
        mock_torch_npu.npu_quant_matmul.return_value = mock_output
        bias = torch.randn(self.output_size, dtype=self.params_dtype)

        output = self.method.apply(layer, x, bias=bias)

        mock_torch_npu.npu_kronecker_quant.assert_called_once()
        mock_torch_npu.npu_quant_matmul.assert_called_once()
        self.assertTrue(
            torch.allclose(output, mock_output + bias.to(self.params_dtype)))
        self.assertEqual(output.shape, (batch_size, self.output_size))

    @patch(
        'vllm_ascend.quantization.w4a4_flatquant_dynamic.KRONECKER_QUANT_MAX_BATCH_SIZE',
        10)
    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
    def test_apply_large_batch(self, mock_torch_npu):
        """Tests the apply method with a batch size larger than MAX_BATCH_SIZE."""
        batch_size = 25
        layer, x, m, n = self._prepare_apply_mocks_and_layer(batch_size)
        mock_quant_x = torch.randint(0,
                                     255, (batch_size, self.input_size // 8),
                                     dtype=torch.int32)
        mock_act_scale = torch.randn(batch_size, 1, dtype=torch.float32)
        # With the max batch size patched to 10, the 25 rows are expected to
        # be quantized in three chunks of 10, 10 and 5 rows.
        mock_torch_npu.npu_kronecker_quant.side_effect = [
            (mock_quant_x[:10].view(10, m, n // 8), mock_act_scale[:10]),
            (mock_quant_x[10:20].view(10, m, n // 8), mock_act_scale[10:20]),
            (mock_quant_x[20:].view(5, m, n // 8), mock_act_scale[20:]),
        ]
        mock_output = torch.randn(batch_size,
                                  self.output_size,
                                  dtype=self.params_dtype)
        mock_torch_npu.npu_quant_matmul.return_value = mock_output

        output = self.method.apply(layer, x, bias=None)

        self.assertEqual(mock_torch_npu.npu_kronecker_quant.call_count, 3)
        mock_torch_npu.npu_quant_matmul.assert_called_once()
        self.assertTrue(torch.equal(output, mock_output))
        self.assertEqual(output.shape, (batch_size, self.output_size))

    def test_apply_dimension_mismatch_error(self):
        """Tests that apply raises ValueError on transform matrix dimension mismatch."""
        layer, x, _, _ = self._prepare_apply_mocks_and_layer(16)
        layer.left_trans = torch.randn(20, 20)
        layer.right_trans = torch.randn(30, 30)  # 20 * 30 != 768
        with self.assertRaisesRegex(
                ValueError, "FlatQuant transform matrices dimension mismatch"):
            self.method.apply(layer, x)

    def _prepare_loading_layer(self):
        """Helper to build a freshly loaded layer for process_weights tests.

        Returns ``(layer, mock_packed)``. ``layer`` has an int8 weight,
        bfloat16 scale/offset, float32 transform matrices sized by
        ``get_decompose_dim(self.input_size)``, and clip ratio 0.9.
        ``mock_packed`` is the int32 tensor the patched pack_int4_weights
        should return.
        """
        layer = nn.Module()
        layer.weight = torch.randint(-8,
                                     7, (self.output_size, self.input_size),
                                     dtype=torch.int8)
        layer.weight_scale = torch.randn(self.output_size,
                                         1,
                                         dtype=torch.bfloat16)
        layer.weight_offset = torch.randn(self.output_size,
                                          1,
                                          dtype=torch.bfloat16)
        left_dim, right_dim = get_decompose_dim(self.input_size)
        layer.left_trans = torch.randn(left_dim, left_dim)
        layer.right_trans = torch.randn(right_dim, right_dim)
        layer.clip_ratio = torch.tensor([0.9])
        mock_packed = torch.randint(0,
                                    100,
                                    (self.output_size, self.input_size // 8),
                                    dtype=torch.int32)
        return layer, mock_packed

    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights')
    def test_process_weights_after_loading(self, mock_pack_weights):
        """Tests weight processing after loading, without transpose."""
        layer, mock_packed = self._prepare_loading_layer()
        mock_pack_weights.return_value = mock_packed
        self.method.transpose_weight = False

        self.method.process_weights_after_loading(layer)

        mock_pack_weights.assert_called_once()
        self.assertFalse(hasattr(layer, 'weight'))
        self.assertTrue(hasattr(layer, 'weight_packed'))
        self.assertTrue(torch.equal(layer.weight_packed.data, mock_packed))
        self.assertEqual(layer.weight_scale.dtype, torch.float32)
        self.assertEqual(layer.weight_offset.dtype, torch.float32)
        self.assertEqual(layer.clip_ratio.dtype, torch.float32)
        # Two-sided tolerance: the old one-sided `x - 0.9 < 0.01` accepted
        # any value below 0.91.
        self.assertAlmostEqual(layer.aclnn_clip_ratio, 0.9, delta=0.01)
        left_dim, _ = get_decompose_dim(self.input_size)
        self.assertEqual(layer.left_trans.shape, (left_dim, left_dim))
        self.assertTrue(layer.left_trans.is_contiguous())

    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights')
    def test_process_weights_after_loading_with_transpose(
            self, mock_pack_weights):
        """Tests weight processing after loading, with transpose."""
        layer, mock_packed = self._prepare_loading_layer()
        mock_pack_weights.return_value = mock_packed
        self.method.transpose_weight = True

        self.method.process_weights_after_loading(layer)

        mock_pack_weights.assert_called_once()
        self.assertTrue(hasattr(layer, 'weight_packed'))
        # The packed (output_size, input_size // 8) result is transposed.
        self.assertEqual(layer.weight_packed.shape,
                         (self.input_size // 8, self.output_size))
        self.assertTrue(layer.weight_packed.is_contiguous())
|
|
|
|
|
|
if __name__ == "__main__":
    # unittest treats argv[0] as the program name, so a one-element list
    # makes it ignore the real command line; exit=False keeps the
    # interpreter running after the tests finish.
    _argv = ["first-arg-is-ignored"]
    unittest.main(argv=_argv, exit=False)
|