[Refactor] Clean up w4a4_flatquant_dynamic implementation (#3440)
Cleans up the initial implementation of `w4a4_flatquant_dynamic` for better readability and maintainability.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
@@ -6,7 +6,7 @@ import torch.nn as nn

 from vllm_ascend.quantization.w4a4_flatquant_dynamic import (
     AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim,
-    pack_int4_to_int32, pack_int4_weights)
+    pack_int4_weights)


 class TestW4A4FlatQuantDynamic(unittest.TestCase):
@@ -33,25 +33,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
         self.assertEqual(get_decompose_dim(100), (10, 10))
         self.assertEqual(get_decompose_dim(99), (9, 11))

-    def test_pack_int4_to_int32(self):
-        """
-        Tests manual packing of an int4 tensor into an int32 tensor.
-        """
-        int4_tensor = torch.arange(-8, 8, dtype=torch.int8).view(2, 8)
-        expected_packed = torch.tensor([[1985229328], [-19088744]],
-                                       dtype=torch.int32)
-        packed_tensor = pack_int4_to_int32(int4_tensor)
-        self.assertTrue(torch.equal(packed_tensor, expected_packed))
-
-    def test_pack_int4_to_int32_value_error(self):
-        """
-        Tests that pack_int4_to_int32 raises ValueError for invalid input shapes.
-        """
-        invalid_tensor = torch.zeros((1, 7), dtype=torch.int8)
-        with self.assertRaisesRegex(
-                ValueError, "The last dimension must be a multiple of 8."):
-            pack_int4_to_int32(invalid_tensor)
-
     @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
     def test_pack_int4_weights_npu_success(self, mock_torch_npu):
         """
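
Note: the expected constants in the deleted test can be re-derived without the removed helper. A minimal standalone check, mirroring the nibble packing the fallback performed (values offset by +8 into unsigned range, then packed little-nibble-first, eight per int32):

import torch

# Re-derive the expected constants from the deleted test: eight int4 values
# per row, shifted into unsigned range and packed little-nibble-first.
int4 = torch.arange(-8, 8, dtype=torch.int8).view(2, 8)
uint4 = (torch.clamp(int4, -8, 7).to(torch.uint8) + 8).view(2, 1, 8)
packed = torch.zeros(2, 1, dtype=torch.int32)
for i in range(8):
    # Nibble i lands at bit offset 4 * i; int32 overflow wraps (two's complement).
    packed += uint4[..., i].to(torch.int32) << (i * 4)
print(packed)  # tensor([[1985229328], [-19088744]], dtype=torch.int32)
# Row 0 is 0x76543210; row 1 is 0xFEDCBA98 read as a signed int32.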
@@ -71,23 +52,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
         mock_torch_npu.npu_convert_weight_to_int4pack.assert_called_once()
         self.assertTrue(torch.equal(result, mock_packed_tensor))

-    @patch(
-        'vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_to_int32')
-    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
-    def test_pack_int4_weights_fallback(self, mock_torch_npu,
-                                        mock_pack_manual):
-        """
-        Tests the fallback mechanism when the NPU kernel fails.
-        """
-        with patch('torch.Tensor.npu',
-                   side_effect=Exception("NPU not available")):
-            weight_tensor = torch.randn(self.output_size, self.input_size)
-            mock_pack_manual.return_value = "fallback success"
-            result = pack_int4_weights(weight_tensor)
-            mock_torch_npu.npu_convert_weight_to_int4pack.assert_not_called()
-            mock_pack_manual.assert_called_once_with(weight_tensor)
-            self.assertEqual(result, "fallback success")
-
     ## Test AscendW4A4FlatQuantDynamicLinearMethod Class
     ## --------------------------------------------------

@@ -101,8 +65,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
         self.assertEqual(params["weight"].dtype, torch.int8)
-        self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.input_size,
-                         self.input_size)
-        self.assertEqual(AscendW4A4FlatQuantDynamicLinearMethod.output_size,
-                         self.output_size)

     def test_get_weight_value_error(self):
         """Tests that get_weight raises ValueError for invalid input_size."""
@@ -15,61 +15,20 @@
 # limitations under the License.
 #
 import math
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple

 import torch
 import torch_npu

-KRONECKER_QUANT_MAX_BATCH_SIZE = 8192
-
-
-def pack_int4_to_int32(int4_tensor: torch.Tensor) -> torch.Tensor:
-    """
-    Packs a tensor of 4-bit integers into a tensor of 32-bit integers.
-
-    This function serves as a manual, device-agnostic fallback when a more
-    optimized hardware-specific kernel (like for an NPU) is not available.
-    It processes the tensor along its last dimension.
-
-    Args:
-        int4_tensor: A tensor with a dtype that can be represented as int4.
-            The size of its last dimension must be a multiple of 8.
-
-    Returns:
-        A new tensor of dtype torch.int32 where every 8 values from the
-        original tensor's last dimension are packed into a single int32 value.
-    """
-    if int4_tensor.shape[-1] % 8 != 0:
-        raise ValueError("The last dimension must be a multiple of 8.")
-    int4_clamped = torch.clamp(int4_tensor, -8, 7)
-    uint4_tensor = int4_clamped.to(torch.uint8) + 8
-    original_shape = uint4_tensor.shape
-    packed_shape = list(original_shape[:-1]) + [original_shape[-1] // 8]
-    uint4_reshaped = uint4_tensor.view(*original_shape[:-1], -1, 8)
-    packed_tensor = torch.zeros(*packed_shape,
-                                dtype=torch.int32,
-                                device=uint4_tensor.device)
-    for i in range(8):
-        packed_tensor += (uint4_reshaped[..., i].to(torch.int32) << (i * 4))
-    return packed_tensor
+KRONECKER_QUANT_MAX_BATCH_SIZE = 32768


 def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
     """
     Packs a weight tensor from int4 to int32, using an NPU-accelerated
     kernel if available, otherwise falling back to a manual implementation.
     """
-    try:
-        original_device = weight_tensor.device
-        weight_tensor_npu = weight_tensor.npu()
-        weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(
-            weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
-        return weight_int4_packed.to(original_device)
-    except Exception as e:
-        print(
-            f"Warning: NPU kernel 'npu_convert_weight_to_int4pack' is not available. "
-            f"Falling back to a manual packing implementation. Error: {e}")
-        return pack_int4_to_int32(weight_tensor)
+    original_device = weight_tensor.device
+    weight_tensor_npu = weight_tensor.npu()
+    weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(
+        weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
+    return weight_int4_packed.to(original_device)


 def get_decompose_dim(n):
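
Note: get_decompose_dim itself is unchanged and its body sits outside this diff's context lines. Judging from the `return a - b, a + b` line shown in the next hunk and the unit-test expectations above (100 -> (10, 10), 99 -> (9, 11)), it searches for a, b with a*a - b*b == n. A plausible reconstruction, for orientation only (the name suffix marks it as our sketch, not repo code):

import math

def get_decompose_dim_sketch(n: int):
    # Find the smallest a >= ceil(sqrt(n)) such that a*a - n is a perfect
    # square b*b; then n == (a - b) * (a + b).
    a = math.isqrt(n)
    if a * a < n:
        a += 1
    while True:
        diff = a * a - n
        b = math.isqrt(diff)
        if b * b == diff:
            return a - b, a + b
        a += 1

assert get_decompose_dim_sketch(100) == (10, 10)
assert get_decompose_dim_sketch(99) == (9, 11)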
@@ -85,6 +44,37 @@ def get_decompose_dim(n):
     return a - b, a + b


+# TODO: This function is a temporary workaround for the npu_kronecker_quant operator,
+# which has a limitation on the maximum batch size (dim0). This wrapper should be
+# removed once the operator supports larger inputs natively.
+def batched_kronecker_quant(
+    x: torch.Tensor,
+    left_trans: torch.Tensor,
+    right_trans: torch.Tensor,
+    clip_ratio: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch_tokens = x.shape[0]
+    if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
+        return torch_npu.npu_kronecker_quant(x,
+                                             left_trans,
+                                             right_trans,
+                                             clip_ratio=clip_ratio,
+                                             dst_dtype=torch.int32)
+    x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0)
+    processed_chunks = [
+        torch_npu.npu_kronecker_quant(chunk,
+                                      left_trans,
+                                      right_trans,
+                                      clip_ratio=clip_ratio,
+                                      dst_dtype=torch.int32)
+        for chunk in x_chunks
+    ]
+    quantized_list, scale_list = zip(*processed_chunks)
+    x_quantized_int4 = torch.cat(quantized_list, dim=0)
+    activation_scale = torch.cat(scale_list, dim=0)
+    return x_quantized_int4, activation_scale
+
+
 class AscendW4A4FlatQuantDynamicLinearMethod:
     """Linear method for Ascend W4A4_FLATQUANT_DYNAMIC.

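
Note: the new batched_kronecker_quant helper is a plain split/concat wrapper. The same pattern can be exercised off-device; a minimal sketch with a stand-in quant function (the stand-in, its dummy scale, and MAX_BATCH are assumptions for illustration, since torch_npu.npu_kronecker_quant only runs on NPU):

from typing import Callable, Tuple

import torch

MAX_BATCH = 4  # stands in for KRONECKER_QUANT_MAX_BATCH_SIZE

def batched_call(
    x: torch.Tensor,
    quant_fn: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same contract as batched_kronecker_quant: split on dim 0, run the
    # kernel per chunk, then re-concatenate both outputs in order.
    if x.shape[0] <= MAX_BATCH:
        return quant_fn(x)
    outs = [quant_fn(chunk) for chunk in torch.split(x, MAX_BATCH, dim=0)]
    q_list, s_list = zip(*outs)
    return torch.cat(q_list, dim=0), torch.cat(s_list, dim=0)

def fake_quant(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in "kernel": int32 copy plus one scale per dim-0 row.
    return t.to(torch.int32), t.abs().amax(dim=(1, 2))

q, s = batched_call(torch.randn(10, 2, 3), fake_quant)
assert q.shape == (10, 2, 3) and s.shape == (10,)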
@@ -94,7 +84,6 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
     - Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights
     """
-    input_size = 0
-    output_size = 0

     def __init__(self):
         self.transpose_weight = False
@@ -108,7 +97,6 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
                 f"input_size ({input_size}) must be divisible by 8 for int4 packing"
             )
-        AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size
-        AscendW4A4FlatQuantDynamicLinearMethod.output_size = output_size
         params_dict = {
             "weight": torch.empty(output_size, input_size, dtype=torch.int8)
         }
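
Note on the divisibility check: weights arrive as one int4 value per int8 element, and pack_int4_weights later folds eight of them into each int32, so the packed weight's last dimension is input_size // 8 (following the manual packing semantics shown above; the NPU kernel may use a different internal layout). A quick shape illustration with arbitrary sizes:

import torch

output_size, input_size = 16, 64          # arbitrary; input_size % 8 == 0
weight = torch.empty(output_size, input_size, dtype=torch.int8)
packed_cols = input_size // 8             # eight int4 values per int32
print(tuple(weight.shape), "->", (output_size, packed_cols))  # (16, 64) -> (16, 8)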
@@ -156,42 +144,21 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
         original_dtype = x.dtype
         input_shape = x.shape
         in_features = input_shape[-1]
-        M = layer.left_trans.shape[0]
-        N = layer.right_trans.shape[0]
-        if M * N != in_features:
+        left_dim = layer.left_trans.shape[0]
+        right_dim = layer.right_trans.shape[0]
+        if left_dim * right_dim != in_features:
             raise ValueError(
-                f"FlatQuant transform matrices dimension mismatch: M({M}) * N({N}) != in_features({in_features})"
+                f"FlatQuant transform matrices dimension mismatch: "
+                f"left_dim({left_dim}) * right_dim({right_dim}) != in_features({in_features})"
             )
         left_trans_matched = layer.left_trans.to(original_dtype)
         right_trans_matched = layer.right_trans.to(original_dtype)
-        x_reshaped = x.view(-1, M, N)
-        batch_tokens = x_reshaped.shape[0]
-        if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
-            x_quantized_int4, activation_scale = torch_npu.npu_kronecker_quant(
-                x_reshaped,
-                left_trans_matched,
-                right_trans_matched,
-                clip_ratio=layer.aclnn_clip_ratio,
-                dst_dtype=torch.int32)
-        else:
-            x_quantized_int4_list = []
-            activation_scale_list = []
-            for start_idx in range(0, batch_tokens,
-                                   KRONECKER_QUANT_MAX_BATCH_SIZE):
-                end_idx = min(start_idx + KRONECKER_QUANT_MAX_BATCH_SIZE,
-                              batch_tokens)
-                x_batch = x_reshaped[start_idx:end_idx]
-                x_quantized_batch, activation_scale_batch = torch_npu.npu_kronecker_quant(
-                    x_batch,
-                    left_trans_matched,
-                    right_trans_matched,
-                    clip_ratio=layer.aclnn_clip_ratio,
-                    dst_dtype=torch.int32)
-                x_quantized_int4_list.append(x_quantized_batch)
-                activation_scale_list.append(activation_scale_batch)
-            x_quantized_int4 = torch.cat(x_quantized_int4_list, dim=0)
-            activation_scale = torch.cat(activation_scale_list, dim=0)
-        x_quantized_reshaped = x_quantized_int4.view(-1, M * N // 8)
+        x_reshaped = x.view(-1, left_dim, right_dim)
+        x_quantized_int4, activation_scale = batched_kronecker_quant(
+            x_reshaped, left_trans_matched, right_trans_matched,
+            layer.aclnn_clip_ratio)
+        x_quantized_reshaped = x_quantized_int4.view(-1,
+                                                     left_dim * right_dim // 8)
         pertoken_scale = activation_scale.view(-1).to(torch.float32)
         output = torch_npu.npu_quant_matmul(x_quantized_reshaped,
                                             layer.weight_packed.t(),
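
Note on the reshapes in apply: with in_features == left_dim * right_dim, activations fold to (tokens, left_dim, right_dim) for the Kronecker transform, and the packed int4 result flattens to (tokens, in_features // 8) for npu_quant_matmul. A plain-shape sketch (sizes arbitrary; the intermediate packed shape is an assumption consistent with the final view):

import torch

left_dim, right_dim = 32, 64
tokens = 5
in_features = left_dim * right_dim                   # 2048
x = torch.randn(tokens, in_features)
x_reshaped = x.view(-1, left_dim, right_dim)         # (5, 32, 64)
# npu_kronecker_quant would return packed int4 data here; only the shapes
# that the subsequent .view() relies on are reproduced:
x_quantized_int4 = torch.empty(tokens, left_dim, right_dim // 8,
                               dtype=torch.int32)
x_for_matmul = x_quantized_int4.view(-1, left_dim * right_dim // 8)
assert x_for_matmul.shape == (tokens, in_features // 8)  # (5, 256)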