Mirror of https://github.com/vllm-project/vllm.git
[Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (#15734)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
@@ -237,6 +237,7 @@ class AttentionLayer(Protocol):
     _v_scale: torch.Tensor
     _k_scale_float: float
     _v_scale_float: float
+    _prob_scale: torch.Tensor
 
     def forward(
         self,
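The hunk above adds a softmax-probability scale to the AttentionLayer protocol next to the existing Q/K/V scales. As a rough illustration of what a conforming layer now carries (attribute names come from the diff; the class itself is hypothetical, not vLLM code):

```python
import torch


# Illustrative object satisfying the extended protocol; only the attribute
# names are taken from the diff, everything else is made up.
class DummyAttentionLayer:
    def __init__(self) -> None:
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
```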
@@ -766,6 +766,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     query.dtype,
                     seq_lens,
                     make_attn_mask=causal_mask)  # type: ignore
+                use_fp8_scales = (layer._q_scale and layer._k_scale
+                                  and layer._v_scale and layer._prob_scale
+                                  and self.kv_cache_dtype == "fp8")
+                full_scales = (
+                    layer._q_scale, layer._k_scale, layer._v_scale,
+                    layer._prob_scale) if use_fp8_scales else None
                 self.triton_attn_func(
                     query,
                     key,
@@ -779,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.scale,
                     attn_masks[0][None]
                     if attn_masks is not None else None,
+                    full_scales,
                 )
             elif self.use_naive_attn:
                 if self.num_kv_heads != self.num_heads:
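The added gating can be read as a single predicate: the Triton kernel receives the (q, k, v, prob) scale tuple only when all four scales are set on the layer and the KV cache dtype is fp8; otherwise it gets None and runs without full-fp8 scaling. A hedged standalone sketch of that logic (the helper name and signature are invented):

```python
from typing import Optional, Tuple

import torch


# Hypothetical helper mirroring the use_fp8_scales / full_scales logic above;
# `layer` is any object exposing _q_scale/_k_scale/_v_scale/_prob_scale as
# scalar float tensors (a 0-d tensor is truthy when its value is non-zero).
def maybe_full_scales(layer, kv_cache_dtype: str
                      ) -> Optional[Tuple[torch.Tensor, ...]]:
    use_fp8_scales = (layer._q_scale and layer._k_scale and layer._v_scale
                      and layer._prob_scale and kv_cache_dtype == "fp8")
    return ((layer._q_scale, layer._k_scale, layer._v_scale,
             layer._prob_scale) if use_fp8_scales else None)
```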
@@ -90,6 +90,7 @@ class Attention(nn.Module):
         # FlashAttn doesn't support quantizing the kv-cache only
         # but requires q to be quantized as well.
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
 
         # We also keep the float32 versions of k/v_scale for attention
         # backends that don't support tensors (Flashinfer)
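The 1.0 defaults let a checkpoint without explicit Q/prob scales still run the fp8 attention path: scaling by 1.0 is the identity, so quantization degenerates to a plain cast. A minimal sketch, assuming a PyTorch build that exposes the float8_e4m3fn dtype (illustrative, not vLLM code):

```python
import torch

# With the default per-tensor scale of 1.0, quantization is just a cast and
# dequantization recovers the same fp8-rounded values.
q = torch.randn(4, 8)
q_scale = torch.tensor(1.0, dtype=torch.float32)

q_fp8 = (q / q_scale).to(torch.float8_e4m3fn)   # quantize
q_dq = q_fp8.to(torch.float32) * q_scale        # dequantize (identity scale)
```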
@@ -3767,6 +3767,17 @@ class VllmConfig:
             return quant_config
         return None
 
+    @staticmethod
+    def get_quantization_config(
+            model_config: ModelConfig,
+            load_config: LoadConfig) -> Optional[QuantizationConfig]:
+        import copy
+
+        # For some reason, the _ version of this modifies the model_config
+        # object, so using deepcopy to avoid this problem.
+        return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
+                                                   load_config)
+
     def with_hf_config(
         self,
         hf_config: PretrainedConfig,
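The comment above is the whole rationale for the new wrapper: the private helper mutates the ModelConfig it is given, so the public method probes a deep copy and leaves the caller's config untouched. A tiny, generic illustration of the pattern (a dict stands in for ModelConfig; nothing here is vLLM code):

```python
import copy


# Stand-in for a helper that (undesirably) mutates its argument.
def probe_quantization(cfg: dict) -> str:
    cfg["quantization"] = "fp8"   # the unwanted side effect
    return cfg["quantization"]


model_cfg = {"quantization": None}
probe_quantization(copy.deepcopy(model_cfg))   # probe a copy instead
assert model_cfg["quantization"] is None       # caller's config unchanged
```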
@@ -1368,6 +1368,23 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
+        if current_platform.is_rocm():
+            from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+            load_config = self.create_load_config()
+            quantization_config = VllmConfig.get_quantization_config(
+                model_config, load_config)
+            if isinstance(quantization_config, Fp8Config):
+                _raise_or_fallback(feature_name="fp8 for ROCm",
+                                   recommend_to_remove=False)
+                return False
+            from vllm.model_executor.layers.quantization.quark.quark import (
+                QuarkConfig)
+
+            if isinstance(quantization_config, QuarkConfig
+                          ) and quantization_config.has_fp8_layer_weights():
+                _raise_or_fallback(feature_name="Quark fp8 for ROCm",
+                                   recommend_to_remove=False)
+
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
             fp8_attention = self.kv_cache_dtype.startswith("fp8")
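Taken together, the ROCm branch above is one predicate: fall back from the V1 engine when the resolved quantization config is plain FP8, or a Quark config carrying fp8 layer weights. A hedged distillation (the function below is illustrative, not part of vLLM):

```python
# Hypothetical predicate condensing the ROCm checks above.
def rocm_quant_needs_fallback(quantization_config) -> bool:
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config
    from vllm.model_executor.layers.quantization.quark.quark import (
        QuarkConfig)
    if isinstance(quantization_config, Fp8Config):
        return True
    return (isinstance(quantization_config, QuarkConfig)
            and quantization_config.has_fp8_layer_weights())
```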
@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig):
             return name.replace(".k_proj.output_scale", ".attn.k_scale")
         if name.endswith(".output_scale") and ".v_proj" in name:
             return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
         # If no matches, return None
         return None
 
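Concretely, the two new branches extend the existing k/v remapping so that Q and softmax-probability output scales from an FP8 checkpoint land on the attention layer. Worked examples of the mapping (the layer prefix is illustrative):

```python
# Example inputs/outputs of the remapping above; "model.layers.0" is just an
# example checkpoint layer prefix.
examples = {
    "model.layers.0.self_attn.k_proj.output_scale":
        "model.layers.0.self_attn.attn.k_scale",
    "model.layers.0.self_attn.v_proj.output_scale":
        "model.layers.0.self_attn.attn.v_scale",
    "model.layers.0.self_attn.q_proj.output_scale":
        "model.layers.0.self_attn.attn.q_scale",
    "model.layers.0.self_attn.prob_output_scale":
        "model.layers.0.self_attn.attn.prob_scale",
}
```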
@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                                            requires_grad=False)
         layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                            requires_grad=False)
+        # Initialize P = softmax(QK^T) scales
+        layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                              requires_grad=False)
 
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(
@@ -97,5 +100,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 "may cause accuracy issues. Please make sure k/v_scale "
                 "scaling factors are available in the fp8 checkpoint.")
 
+        if layer.q_scale > 0.0:
+            q_scale = layer.q_scale
+            if current_platform.is_fp8_fnuz():
+                q_scale *= 2
+            layer.calculate_kv_scales = False
+        else:
+            q_scale = 1.0
+        if layer.prob_scale > 0.0:
+            prob_scale = layer.prob_scale
+            if current_platform.is_fp8_fnuz():
+                prob_scale *= 2
+        else:
+            prob_scale = 1.0
+
+        is_singleton_float = lambda x: isinstance(x, float) or isinstance(
+            x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
+        if not is_singleton_float(q_scale) or not is_singleton_float(
+                prob_scale):
+            raise ValueError("Only support per-tensor scaling factor"
+                             "for fp8-quantized Q/prob")
+
+        # These are used in the final Attention.forward()
+        layer._q_scale.copy_(q_scale)
+        layer._prob_scale.copy_(prob_scale)
+        if q_scale == 1.0 or prob_scale == 1.0:
+            logger.warning_once(
+                f"Using Q scale {q_scale} and prob scale {prob_scale} "
+                "with fp8 attention. This may cause accuracy issues. "
+                "Please make sure Q/prob scaling factors are "
+                "available in the fp8 checkpoint.")
+
         del layer.k_scale
         del layer.v_scale
+        del layer.q_scale
+        del layer.prob_scale
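Two details in the block above are worth spelling out: on fp8-fnuz hardware (ROCm's float8_e4m3fnuz uses an exponent bias of 8 instead of the 7 used by OCP float8_e4m3fn, so checkpoint scales computed for the OCP format are doubled), and only per-tensor "singleton float" scales are accepted for Q and the attention probabilities. A hedged sketch of both checks in isolation (values and the platform flag are stand-ins):

```python
import torch

# fnuz compensation: a scale computed for OCP fp8e4m3 is doubled on fnuz
# hardware (the value below is illustrative).
checkpoint_q_scale = torch.tensor(0.02, dtype=torch.float32)
is_fp8_fnuz = True   # stand-in for current_platform.is_fp8_fnuz()
q_scale = checkpoint_q_scale * 2 if is_fp8_fnuz else checkpoint_q_scale


# Per-tensor check: plain floats or single-element float tensors pass,
# per-channel tensors are rejected.
def is_singleton_float(x) -> bool:
    return isinstance(x, float) or (isinstance(x, torch.Tensor)
                                    and x.numel() == 1
                                    and x.is_floating_point())


assert is_singleton_float(q_scale)
assert not is_singleton_float(torch.ones(16))
```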
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast
 
 import torch
@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig):
             for q_config in q_configs:
                 q_config["output_tensors"] = None
 
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            if q_proj_q_config is not None:
+                q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,
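This mirrors what is already done for the k/v projections above it: if the Quark export also attached an output-tensor quantization config to q_proj, it is cleared so the q/k/v projections stay consistent and the Q scale is instead consumed through the attention layer (see the name remapping in the next hunk). A minimal illustration with an assumed config shape (only the keys this code touches are shown; the real Quark export carries more fields):

```python
from typing import Any, Dict, cast

# Assumed shape of a Quark layer_quant_config entry (illustrative).
layer_quant_config: Dict[str, Any] = {
    "*q_proj": {
        "weight": {"dtype": "fp8_e4m3"},
        "output_tensors": {"dtype": "fp8_e4m3"},
    },
}

q_proj_q_config = cast(Dict[str, Any], layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
    q_proj_q_config["output_tensors"] = None   # Q output scale handled by attn
```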
@@ -289,29 +295,30 @@ class QuarkConfig(QuantizationConfig):
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
 
+    def has_fp8_layer_weights(self):
+        layer_quant_config = self.quant_config.get("layer_quant_config")
+        to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
+        return any([
+            'fp8' in cast(
+                str,
+                to_dict(
+                    to_dict(to_dict(layer_quant_config).get(layer_name)).get(
+                        "weight")).get("dtype"))
+            for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
+        ])
+
 
 class QuarkLinearMethod(LinearMethodBase):
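The new has_fp8_layer_weights() helper (used by the EngineArgs check earlier) only looks at the weight dtype strings of the q/k/v projection entries. A self-contained restatement on an assumed config (the dtype string and layer keys mirror what the method reads; the rest of the Quark format is not modeled):

```python
from typing import Any, Dict, cast

# Assumed minimal layer_quant_config; only the keys read by
# has_fp8_layer_weights() are present.
layer_quant_config = {
    "*q_proj": {"weight": {"dtype": "fp8_e4m3"}},
    "*k_proj": {"weight": {"dtype": "fp8_e4m3"}},
    "*v_proj": {"weight": {"dtype": "fp8_e4m3"}},
}

to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
has_fp8 = any(
    'fp8' in str(
        to_dict(to_dict(layer_quant_config.get(name)).get("weight"))
        .get("dtype"))
    for name in ["*v_proj", "*k_proj", "*q_proj"])
assert has_fp8
```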