Support Qwen2.5-VL W8A8 quantization (#2778)

### What this PR does / why we need it?
Add W8A8 quantization support for Qwen2.5-VL on the Ascend backend.
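
For reference, a minimal offline-inference sketch of how a W8A8-quantized Qwen2.5-VL checkpoint would be loaded. The model path is a placeholder, and `quantization="ascend"` reflects the usual vllm-ascend quantization flow rather than anything introduced by this PR:

```python
from vllm import LLM, SamplingParams

# Placeholder path to a Qwen2.5-VL checkpoint already quantized to W8A8
# (e.g. with msModelSlim); adjust to your environment.
model_path = "/path/to/Qwen2.5-VL-7B-Instruct-w8a8"

llm = LLM(
    model=model_path,
    quantization="ascend",  # assumption: select the Ascend W8A8 quant method
    max_model_len=4096,
)
outputs = llm.generate(["Describe what W8A8 quantization changes for inference."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```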
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Unit tests for the new padding helpers are added in `TestAscendQwen2_5_VisionTransformer` (see the test diff below).
- vLLM version: v0.10.1.1
- vLLM main: 62f66be1f7

---------

Signed-off-by: lijiaojiao <lijiaojiao990304@163.com>
Co-authored-by: lijiaojiao <lijiaojiao990304@163.com>
Author: 6lazijiamo
Date: 2025-09-11 16:40:51 +08:00
Committed by: GitHub
Parent: 2b9269b581
Commit: bd3dedea61
3 changed files with 103 additions and 3 deletions


@@ -353,6 +353,46 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
        cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
        assert cos_new.shape == (1, 32, 1, 2)

    def test_pad_qkv_bias(self, mocker: MockerFixture):
        attention = self.init_vision_transformer(mocker)
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        res = attention.pad_qkv_bias(torch.rand((300)))
        assert res.shape[0] == 384

    def test_pad_qkv_weight(self, mocker: MockerFixture):
        attention = self.init_vision_transformer(mocker)
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        res = attention.pad_qkv_weight(torch.rand((300, 300)))
        assert res.shape == (384, 300)

    def test_pad_proj_weight(self, mocker: MockerFixture):
        attention = self.init_vision_transformer(mocker)
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        res = attention.pad_proj_weight(torch.rand((300, 300)))
        assert res.shape == (300, 384)

    def test_pad_qkv_weight_scale_offset(self, mocker: MockerFixture):
        attention = self.init_vision_transformer(mocker)
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        res = attention.pad_qkv_weight_scale_offset(torch.rand((300, 1)))
        assert res.shape == (384, 1)

    def test_pad_qkv_deq_scale_quant_bias(self, mocker: MockerFixture):
        attention = self.init_vision_transformer(mocker)
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        res = attention.pad_qkv_deq_scale_quant_bias(torch.rand((300)))
        assert res.shape[0] == 384

    def test_forward(self, mocker: MockerFixture):
        vision_transformer = self.init_vision_transformer(mocker)
        mocker.patch("torch.nn.Module.__setattr__")
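
The 300 -> 384 expectation in these tests is head-dimension padding arithmetic: with the mocked transformer, the original hidden size per attention head appears to be 100, padded up to 128, and the pad is applied to each half of the head separately. A rough sketch of that arithmetic (the 100/128 figures are inferred from the expected shapes, not read from the fixture):

```python
import torch
import torch.nn.functional as F

# Inferred toy numbers: 3 (q, k, v) x 1 head x 100 dims, padded to 128 per head.
origin_head_dim, padded_head_dim = 100, 128
half_pad = (padded_head_dim - origin_head_dim) // 2   # 14 zeros per half

bias = torch.rand(3 * origin_head_dim)                 # the 300-element input
halves = bias.reshape(-1, 3, 2, origin_head_dim // 2)  # split every head into two halves
padded = F.pad(halves, (0, half_pad))                  # pad the tail of each half
assert padded.reshape(-1).shape[0] == 384              # matches the tests' expectation
```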


@@ -291,6 +291,40 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                                        self.hidden_size, -1)
        return out_weight

    def pad_qkv_weight_scale_offset(self, data):
        # Pad per-output-channel weight_scale / weight_offset (shape [N, 1]) to the
        # padded head size, padding each half of every head separately so the layout
        # matches the padded qkv weight rows.
        reshaped_data = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, 1)
        data1 = reshaped_data[:, :, :self.
                              half_origin_hidden_size_per_attention_head, :]
        data2 = reshaped_data[:, :, self.
                              half_origin_hidden_size_per_attention_head:, :]
        data1_paded = torch.nn.functional.pad(
            data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
                    0, 0, 0))
        data2_paded = torch.nn.functional.pad(
            data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
                    0, 0, 0))
        res = torch.cat([data1_paded, data2_paded], dim=2)
        res = res.reshape(-1, 1)
        return res

    def pad_qkv_deq_scale_quant_bias(self, data):
        # Pad per-output-channel deq_scale / quant_bias (1-D, shape [N]) with the same
        # half-by-half head padding, keeping them aligned with the padded weights.
        reshaped_data = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head)
        data1 = reshaped_data[:, :, :self.
                              half_origin_hidden_size_per_attention_head]
        data2 = reshaped_data[:, :,
                              self.half_origin_hidden_size_per_attention_head:]
        data1_paded = torch.nn.functional.pad(
            data1, (0, self.half_pad_hidden_size_per_attention_head))
        data2_paded = torch.nn.functional.pad(
            data2, (0, self.half_pad_hidden_size_per_attention_head))
        res = torch.cat([data1_paded, data2_paded], dim=2)
        res = res.reshape(-1)
        return res

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
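
To make the new helpers concrete, here is a toy run of the same split-into-halves, pad-each-half, re-concatenate pattern that `pad_qkv_weight_scale_offset` applies, with tiny made-up sizes (head dim 4 padded to 6, one head) so the resulting layout is easy to read; this is a sketch, not a call into the real class:

```python
import torch
import torch.nn.functional as F

origin_head, padded_head = 4, 6              # toy sizes, not the real model's
half_origin = origin_head // 2               # 2
half_pad = (padded_head - origin_head) // 2  # 1

# Toy per-output-channel column for fused qkv: shape (3 * heads * head_dim, 1).
scale = torch.arange(1, 1 + 3 * origin_head).reshape(-1, 1)

x = scale.reshape(-1, 3, origin_head, 1)
first, second = x[:, :, :half_origin, :], x[:, :, half_origin:, :]
# Pad each half along the head dimension, then stitch the halves back together.
first = F.pad(first, (0, 0, 0, half_pad))
second = F.pad(second, (0, 0, 0, half_pad))
out = torch.cat([first, second], dim=2).reshape(-1, 1)

print(out.squeeze(-1).tolist())
# [1, 2, 0, 3, 4, 0,  5, 6, 0, 7, 8, 0,  9, 10, 0, 11, 12, 0]
```

Each head contributes two padded halves, so the inserted zeros land inside every head rather than only at the end of the tensor.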
@@ -318,11 +352,23 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-            if ("attn.proj.weight" in name) and self.enable_pad:
+            if ("attn.proj.weight_scale" in name or
+                    "attn.proj.weight_offset" in name) and self.enable_pad:
+                continue
+            elif ("attn.proj.deq_scale" in name
+                  or "attn.proj.quant_bias" in name) and self.enable_pad:
+                continue
+            elif ("attn.qkv.weight_scale" in name
+                  or "attn.qkv.weight_offset" in name) and self.enable_pad:
+                param.data = self.pad_qkv_weight_scale_offset(param.data)
+            elif ("attn.qkv.deq_scale" in name
+                  or "attn.qkv.quant_bias" in name) and self.enable_pad:
+                param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
+            elif ("attn.proj.weight" in name) and self.enable_pad:
                 param.data = self.pad_proj_weight(param.data)
-            if ("attn.qkv.weight" in name) and self.enable_pad:
+            elif ("attn.qkv.weight" in name) and self.enable_pad:
                 param.data = self.pad_qkv_weight(param.data)
-            if ("attn.qkv.bias" in name) and self.enable_pad:
+            elif ("attn.qkv.bias" in name) and self.enable_pad:
                 param.data = self.pad_qkv_bias(param.data)
             loaded_params.add(name)
         return loaded_params
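
The name-based routing above keeps every per-output-channel quantization parameter in the same row layout as the padded qkv weight. A rough illustration of why that alignment matters, assuming one common per-channel W8A8 dequantization convention, y = (x_q @ W_q.T + quant_bias) * deq_scale (an assumption for illustration, not a statement of the exact Ascend kernel contract):

```python
import torch
import torch.nn.functional as F

out_ch, in_ch = 6, 4
w_q = torch.randint(-128, 127, (out_ch, in_ch))
x_q = torch.randint(-128, 127, (1, in_ch))
deq_scale = torch.rand(out_ch)
quant_bias = torch.randint(-8, 8, (out_ch,))

# Reference per-channel dequant (assumed convention, illustration only).
y_ref = (x_q @ w_q.T + quant_bias).float() * deq_scale

# Pad two zero rows onto the weight: the per-channel params must be padded
# identically, otherwise scale/bias would be applied to the wrong channels.
w_pad = F.pad(w_q, (0, 0, 0, 2))
deq_pad = F.pad(deq_scale, (0, 2))
bias_pad = F.pad(quant_bias, (0, 2))
y_pad = (x_q @ w_pad.T + bias_pad).float() * deq_pad

assert torch.allclose(y_ref, y_pad[:, :out_ch])  # original channels unchanged
```

Since `attn.proj` is padded along its input dimension (the tests expect a (300, 384) result), its per-output-channel scale/offset and deq_scale/quant_bias entries need no padding, which is presumably why they are simply skipped above.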
@@ -445,6 +491,17 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                                        dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
        Qwen2_5_VLForConditionalGeneration):

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
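
The `packed_modules_mapping` lets the quantization config relate the fused `qkv_proj` / `gate_up_proj` linears back to the per-module names (`q_proj`, `k_proj`, ...) that the quantization description was generated against, for example when deciding whether a fused layer is skipped. A rough sketch of that kind of lookup (the helper name and logic are illustrative, not vLLM's actual implementation):

```python
from typing import Dict, List

packed_modules_mapping: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand_prefix(prefix: str) -> List[str]:
    """Map a fused module prefix back to the checkpoint's unfused names
    (hypothetical helper, for illustration only)."""
    for fused, shards in packed_modules_mapping.items():
        if prefix.endswith(fused):
            return [prefix.replace(fused, shard) for shard in shards]
    return [prefix]

print(expand_prefix("model.layers.0.self_attn.qkv_proj"))
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']
```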


@@ -53,6 +53,7 @@ class AscendQuantConfig(QuantizationConfig):
    """

    def __init__(self, quant_config: Dict[str, Any]):
        super().__init__()
        self.quant_description = quant_config

    def __repr__(self) -> str:

@@ -89,6 +90,8 @@ class AscendQuantConfig(QuantizationConfig):
    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention
        if prefix.startswith("language_model"):
            prefix = prefix.split('.', 1)[-1]
        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
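
The `language_model` prefix handling is needed because the multimodal wrapper registers the text backbone under `language_model.`, while the quantization description for the standalone LLM presumably does not carry that prefix. A small illustration of the normalization (the example prefix string is illustrative):

```python
def normalize_prefix(prefix: str) -> str:
    # Drop the leading "language_model" component so the layer name matches
    # the keys used in the quantization description of the text backbone.
    if prefix.startswith("language_model"):
        prefix = prefix.split('.', 1)[-1]
    return prefix

print(normalize_prefix("language_model.model.layers.0.mlp.gate_up_proj"))
# -> "model.layers.0.mlp.gate_up_proj"
```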