From bd3dedea6123c9c8c19fe83b6e05716f63b1285d Mon Sep 17 00:00:00 2001
From: 6lazijiamo <56385650+wenba0@users.noreply.github.com>
Date: Thu, 11 Sep 2025 16:40:51 +0800
Subject: [PATCH] support qwen25 vl w8a8 quantization (#2778)

### What this PR does / why we need it?
Support W8A8 quantization for Qwen2.5-VL.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
  https://github.com/vllm-project/vllm/commit/62f66be1f74378e5a22e266ad161023c324cf4f8

---------

Signed-off-by: lijiaojiao
Co-authored-by: lijiaojiao
---
 tests/ut/models/test_qwen2_5_vl.py       | 40 +++++++++++++++
 vllm_ascend/models/qwen2_5_vl.py         | 63 ++++++++++++++++++++++--
 vllm_ascend/quantization/quant_config.py |  3 ++
 3 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/tests/ut/models/test_qwen2_5_vl.py b/tests/ut/models/test_qwen2_5_vl.py
index d33f33761..e5ce511ea 100644
--- a/tests/ut/models/test_qwen2_5_vl.py
+++ b/tests/ut/models/test_qwen2_5_vl.py
@@ -353,6 +353,46 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
         cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
         assert cos_new.shape == (1, 32, 1, 2)
 
+    def test_pad_qkv_bias(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_bias(torch.rand((300)))
+        assert res.shape[0] == 384
+
+    def test_pad_qkv_weight(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_weight(torch.rand((300, 300)))
+        assert res.shape == (384, 300)
+
+    def test_pad_proj_weight(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_proj_weight(torch.rand((300, 300)))
+        assert res.shape == (300, 384)
+
+    def test_pad_qkv_weight_scale_offset(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_weight_scale_offset(torch.rand((300, 1)))
+        assert res.shape == (384, 1)
+
+    def test_pad_qkv_deq_scale_quant_bias(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_deq_scale_quant_bias(torch.rand((300)))
+        assert res.shape[0] == 384
+
     def test_forward(self, mocker: MockerFixture):
         vision_transformer = self.init_vision_transformer(mocker)
         mocker.patch("torch.nn.Module.__setattr__")
diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py
index 31ad2603a..b15946a11 100644
--- a/vllm_ascend/models/qwen2_5_vl.py
+++ b/vllm_ascend/models/qwen2_5_vl.py
@@ -291,6 +291,40 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                 self.hidden_size, -1)
         return out_weight
 
+    def pad_qkv_weight_scale_offset(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head, 1)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head, :]
+        data2 = reshaped_data[:, :, self.
+                              half_origin_hidden_size_per_attention_head:, :]
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1, 1)
+        return res
+
+    def pad_qkv_deq_scale_quant_bias(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head]
+        data2 = reshaped_data[:, :,
+                              self.half_origin_hidden_size_per_attention_head:]
+
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, self.half_pad_hidden_size_per_attention_head))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, self.half_pad_hidden_size_per_attention_head))
+
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1)
+        return res
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
@@ -318,11 +352,23 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if ("attn.proj.weight" in name) and self.enable_pad:
+                if ("attn.proj.weight_scale" in name or
+                        "attn.proj.weight_offset" in name) and self.enable_pad:
+                    continue
+                elif ("attn.proj.deq_scale" in name
+                      or "attn.proj.quant_bias" in name) and self.enable_pad:
+                    continue
+                elif ("attn.qkv.weight_scale" in name
+                      or "attn.qkv.weight_offset" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_weight_scale_offset(param.data)
+                elif ("attn.qkv.deq_scale" in name
+                      or "attn.qkv.quant_bias" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
+                elif ("attn.proj.weight" in name) and self.enable_pad:
                     param.data = self.pad_proj_weight(param.data)
-                if ("attn.qkv.weight" in name) and self.enable_pad:
+                elif ("attn.qkv.weight" in name) and self.enable_pad:
                     param.data = self.pad_qkv_weight(param.data)
-                if ("attn.qkv.bias" in name) and self.enable_pad:
+                elif ("attn.qkv.bias" in name) and self.enable_pad:
                     param.data = self.pad_qkv_bias(param.data)
             loaded_params.add(name)
         return loaded_params
@@ -445,6 +491,17 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
     dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class AscendQwen2_5_VLForConditionalGeneration(
         Qwen2_5_VLForConditionalGeneration):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
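
For context on the two helpers added above: the per-channel quantization parameters (`weight_scale`/`weight_offset` and `deq_scale`/`quant_bias`) have to be padded exactly like the weights they describe, so each head's q/k/v channels are split into their two halves (presumably to keep the rotary-embedding halves aligned) and each half is zero-padded up to the Ascend-friendly head size; that padding is also why the unit tests expect a 300-element input to come back with 384 elements. A minimal self-contained sketch of the pattern, using deliberately small made-up sizes rather than the real model dimensions:

```python
import torch
import torch.nn.functional as F

# Made-up toy sizes, for illustration only (not the real Qwen2.5-VL dims).
num_heads = 2
origin_head_dim = 8                                   # channels per head before padding
padded_head_dim = 12                                  # target (padded) channels per head
half_origin = origin_head_dim // 2                    # 4
half_pad = (padded_head_dim - origin_head_dim) // 2   # 2

# A per-channel parameter laid out as q|k|v blocks per head,
# e.g. deq_scale or quant_bias: one value per output channel.
data = torch.rand(num_heads * 3 * origin_head_dim)

# Split every head's channels into its two halves...
reshaped = data.reshape(-1, 3, origin_head_dim)
first_half = reshaped[:, :, :half_origin]
second_half = reshaped[:, :, half_origin:]

# ...pad each half on the right, mirroring pad_qkv_deq_scale_quant_bias...
first_half = F.pad(first_half, (0, half_pad))
second_half = F.pad(second_half, (0, half_pad))

# ...and flatten back: every head now contributes padded_head_dim channels.
padded = torch.cat([first_half, second_half], dim=2).reshape(-1)
assert padded.shape[0] == num_heads * 3 * padded_head_dim
```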
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 95cc02cff..fb644a13c 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -53,6 +53,7 @@ class AscendQuantConfig(QuantizationConfig):
     """
 
     def __init__(self, quant_config: Dict[str, Any]):
+        super().__init__()
         self.quant_description = quant_config
 
     def __repr__(self) -> str:
@@ -89,6 +90,8 @@
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention
+        if prefix.startswith("language_model"):
+            prefix = prefix.split('.', 1)[-1]
         if isinstance(layer, LinearBase):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
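
On the `quant_config.py` side, the added `super().__init__()` call lets vLLM's `QuantizationConfig` base class run its own initialization, and `get_quant_method` now strips the leading `language_model.` that the multimodal wrapper prepends to submodule prefixes, so layer lookups against the quantization description still match. A rough sketch of the lookup problem being solved, with hypothetical description keys and values (the real JSON produced by the quantization tool will differ):

```python
# Hypothetical quantization description; real keys/values come from the
# W8A8 quantization tool and will differ.
quant_description = {
    "model.layers.0.self_attn.qkv_proj.weight": "W8A8",
    "model.layers.0.mlp.gate_up_proj.weight": "FLOAT",
}


def is_layer_quantized(prefix: str) -> bool:
    # Mirrors the normalization step added in this patch: drop the
    # "language_model." prefix before consulting the description.
    if prefix.startswith("language_model"):
        prefix = prefix.split('.', 1)[-1]
    return quant_description.get(f"{prefix}.weight") == "W8A8"


# The multimodal model registers its text backbone under "language_model.",
# so without the strip the first lookup below would miss and the layer
# would silently be treated as unquantized.
assert is_layer_quantized("language_model.model.layers.0.self_attn.qkv_proj")
assert not is_layer_quantized("language_model.model.layers.0.mlp.gate_up_proj")
```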