Support Qwen2.5-VL W8A8 quantization (#2778)
### What this PR does / why we need it?
Support W8A8 quantization for Qwen2.5-VL.
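
As a rough usage sketch (the model path below is illustrative and assumes a checkpoint already quantized to W8A8 for Ascend, e.g. with msModelSlim), the quantized model would then be loaded through the usual `quantization="ascend"` path:

```python
# Hedged sketch: load a W8A8-quantized Qwen2.5-VL checkpoint on Ascend.
# The local path is a placeholder, not an artifact produced by this PR.
from vllm import LLM

llm = LLM(
    model="/path/to/Qwen2.5-VL-7B-Instruct-w8a8",  # assumed pre-quantized weights
    quantization="ascend",  # quantization backend registered by vllm-ascend
    max_model_len=4096,
)
```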
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: 62f66be1f7
---------
Signed-off-by: lijiaojiao <lijiaojiao990304@163.com>
Co-authored-by: lijiaojiao <lijiaojiao990304@163.com>
```diff
@@ -353,6 +353,46 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
         cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
         assert cos_new.shape == (1, 32, 1, 2)
 
+    def test_pad_qkv_bias(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_bias(torch.rand((300)))
+        assert res.shape[0] == 384
+
+    def test_pad_qkv_weight(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_weight(torch.rand((300, 300)))
+        assert res.shape == (384, 300)
+
+    def test_pad_proj_weight(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_proj_weight(torch.rand((300, 300)))
+        assert res.shape == (300, 384)
+
+    def test_pad_qkv_weight_scale_offset(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_weight_scale_offset(torch.rand((300, 1)))
+        assert res.shape == (384, 1)
+
+    def test_pad_qkv_deq_scale_quant_bias(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_deq_scale_quant_bias(torch.rand((300)))
+        assert res.shape[0] == 384
+
     def test_forward(self, mocker: MockerFixture):
         vision_transformer = self.init_vision_transformer(mocker)
         mocker.patch("torch.nn.Module.__setattr__")
```
```diff
@@ -291,6 +291,40 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                 self.hidden_size, -1)
         return out_weight
 
+    def pad_qkv_weight_scale_offset(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head, 1)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head, :]
+        data2 = reshaped_data[:, :, self.
+                              half_origin_hidden_size_per_attention_head:, :]
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1, 1)
+        return res
+
+    def pad_qkv_deq_scale_quant_bias(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head]
+        data2 = reshaped_data[:, :,
+                              self.half_origin_hidden_size_per_attention_head:]
+
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, self.half_pad_hidden_size_per_attention_head))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, self.half_pad_hidden_size_per_attention_head))
+
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1)
+        return res
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
```
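For orientation, the two new helpers appear to follow the same split-pad-concat pattern as the existing `pad_qkv_bias`/`pad_qkv_weight` helpers: each head's channel is treated as two halves and each half is zero-padded separately, so the padded head keeps its two-half layout. A minimal standalone sketch, with sizes inferred from the unit tests above (per-head size 100 padded to 128, so a fused tensor of length 300 grows to 384):

```python
import torch
import torch.nn.functional as F

# Assumed sizes, inferred from the tests: per-head size 100, padded to 128.
origin_head = 100               # origin_hidden_size_per_attention_head
half_origin = origin_head // 2  # 50
half_pad = 64 - half_origin     # 14, so each half grows from 50 to 64

data = torch.rand(300)                # e.g. a fused qkv deq_scale / quant_bias vector
x = data.reshape(-1, 3, origin_head)  # (1, 3, 100): one group of q/k/v heads
first, second = x[..., :half_origin], x[..., half_origin:]
padded = torch.cat(
    [F.pad(first, (0, half_pad)),     # pad the first half of each head
     F.pad(second, (0, half_pad))],   # pad the second half of each head
    dim=-1)
print(padded.reshape(-1).shape)       # torch.Size([384])
```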
```diff
@@ -318,11 +352,23 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if ("attn.proj.weight" in name) and self.enable_pad:
+                if ("attn.proj.weight_scale" in name or
+                        "attn.proj.weight_offset" in name) and self.enable_pad:
+                    continue
+                elif ("attn.proj.deq_scale" in name
+                      or "attn.proj.quant_bias" in name) and self.enable_pad:
+                    continue
+                elif ("attn.qkv.weight_scale" in name
+                      or "attn.qkv.weight_offset" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_weight_scale_offset(param.data)
+                elif ("attn.qkv.deq_scale" in name
+                      or "attn.qkv.quant_bias" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
+                elif ("attn.proj.weight" in name) and self.enable_pad:
                     param.data = self.pad_proj_weight(param.data)
-                if ("attn.qkv.weight" in name) and self.enable_pad:
+                elif ("attn.qkv.weight" in name) and self.enable_pad:
                     param.data = self.pad_qkv_weight(param.data)
-                if ("attn.qkv.bias" in name) and self.enable_pad:
+                elif ("attn.qkv.bias" in name) and self.enable_pad:
                     param.data = self.pad_qkv_bias(param.data)
             loaded_params.add(name)
         return loaded_params
```
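One subtlety worth noting about the reordered branches (an editorial observation, not text from the PR): the plain `attn.qkv.weight` substring also occurs inside the quantization parameter names, so the chain has to match the more specific `weight_scale`/`weight_offset`/`deq_scale`/`quant_bias` names first and only fall through to the weight/bias padding branches via `elif`:

```python
# Illustrative check (not part of the PR): substring matching alone cannot
# tell the fused weight apart from its quantization parameters, so ordering
# the if/elif chain from most to least specific name keeps, for example,
# "weight_scale" out of the pad_qkv_weight branch.
name = "blocks.0.attn.qkv.weight_scale"
assert "attn.qkv.weight" in name        # would match the plain-weight branch too
assert "attn.qkv.weight_scale" in name  # matched first by the new elif chain
```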
```diff
@@ -445,6 +491,17 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
                                         dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class AscendQwen2_5_VLForConditionalGeneration(
         Qwen2_5_VLForConditionalGeneration):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
```
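The `packed_modules_mapping` added here follows the usual vLLM convention of mapping each fused module to the per-projection names used in checkpoints and quantization descriptions. A hedged sketch of the kind of lookup this enables (the layer prefix is illustrative, not taken from a real description):

```python
# Sketch only: expand a fused-layer prefix into its unfused shard names so a
# per-layer quantization description keyed on q_proj/k_proj/v_proj can still
# be resolved for the fused qkv_proj layer.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

prefix = "model.layers.0.self_attn.qkv_proj"   # hypothetical layer prefix
module = prefix.split(".")[-1]
shards = packed_modules_mapping.get(module, [module])
print([prefix.replace(module, shard) for shard in shards])
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']
```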
```diff
@@ -53,6 +53,7 @@ class AscendQuantConfig(QuantizationConfig):
     """
 
     def __init__(self, quant_config: Dict[str, Any]):
+        super().__init__()
         self.quant_description = quant_config
 
     def __repr__(self) -> str:
```
```diff
@@ -89,6 +90,8 @@ class AscendQuantConfig(QuantizationConfig):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention
+        if prefix.startswith("language_model"):
+            prefix = prefix.split('.', 1)[-1]
         if isinstance(layer, LinearBase):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
```
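For multimodal checkpoints the linear layers of the text backbone are prefixed with `language_model.`, while the quantization description is presumably keyed on the bare LLM module names, so the new lines normalise the prefix before the skip-layer lookup. A minimal sketch of the added normalisation (the concrete module path is an illustrative example):

```python
# Minimal sketch of the prefix handling added above; the module path is an
# illustrative example, not taken from a real quantization description.
prefix = "language_model.model.layers.0.self_attn.qkv_proj"
if prefix.startswith("language_model"):
    prefix = prefix.split('.', 1)[-1]
print(prefix)  # -> model.layers.0.self_attn.qkv_proj
```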