[Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (#15734)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Author: rasmith
Date: 2025-04-25 02:45:02 -05:00 (committed by GitHub)
Parent: 6aae216b4e
Commit: a41351f363
8 changed files with 105 additions and 20 deletions


@@ -237,6 +237,7 @@ class AttentionLayer(Protocol):
_v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
_prob_scale: torch.Tensor
def forward(
self,


@@ -766,6 +766,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype,
seq_lens,
make_attn_mask=causal_mask) # type: ignore
use_fp8_scales = (layer._q_scale and layer._k_scale
and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8")
full_scales = (
layer._q_scale, layer._k_scale, layer._v_scale,
layer._prob_scale) if use_fp8_scales else None
self.triton_attn_func(
query,
key,
@@ -779,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
full_scales,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
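As context for the use_fp8_scales gate above: the per-layer scales are one-element tensors, and chaining them with "and" relies on tensor truthiness, which is False only when the single value is exactly zero. A minimal standalone sketch with made-up scale values (not taken from any real checkpoint):

import torch

# Stand-ins for layer._q_scale, _k_scale, _v_scale and _prob_scale:
# one-element float32 tensors with illustrative values.
q_scale = torch.tensor(0.0125, dtype=torch.float32)
k_scale = torch.tensor(0.0172, dtype=torch.float32)
v_scale = torch.tensor(0.0215, dtype=torch.float32)
prob_scale = torch.tensor(0.0, dtype=torch.float32)  # zero -> falsy
kv_cache_dtype = "fp8"

# A zero scale (or a non-"fp8" kv cache dtype) leaves full_scales as None,
# so the attention kernel would be called without the FP8 scale tuple.
use_fp8_scales = (q_scale and k_scale and v_scale and prob_scale
                  and kv_cache_dtype == "fp8")
full_scales = (q_scale, k_scale, v_scale,
               prob_scale) if use_fp8_scales else None
print(bool(use_fp8_scales), full_scales)  # False None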


@@ -90,6 +90,7 @@ class Attention(nn.Module):
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)


@@ -3767,6 +3767,17 @@ class VllmConfig:
return quant_config
return None
@staticmethod
def get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
import copy
# _get_quantization_config mutates the model_config it is given,
# so pass a deepcopy to leave the caller's config unchanged.
return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
load_config)
def with_hf_config(
self,
hf_config: PretrainedConfig,


@@ -1368,6 +1368,23 @@ class EngineArgs:
recommend_to_remove=False)
return False
if current_platform.is_rocm():
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
load_config = self.create_load_config()
quantization_config = VllmConfig.get_quantization_config(
model_config, load_config)
if isinstance(quantization_config, Fp8Config):
_raise_or_fallback(feature_name="fp8 for ROCm",
recommend_to_remove=False)
return False
from vllm.model_executor.layers.quantization.quark.quark import (
QuarkConfig)
if isinstance(quantization_config, QuarkConfig
) and quantization_config.has_fp8_layer_weights():
_raise_or_fallback(feature_name="Quark fp8 for ROCm",
recommend_to_remove=False)
# No Fp8 KV cache so far.
if self.kv_cache_dtype != "auto":
fp8_attention = self.kv_cache_dtype.startswith("fp8")


@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig):
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None
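As a worked example of the name mapping above, using hypothetical Llama-style checkpoint parameter names; the helper simply restates the q_proj and prob_output_scale branches from the hunk:

from typing import Optional

def remap_scale_name(name: str) -> Optional[str]:
    # Same rules as the q_proj / prob_output_scale branches above
    # (the k_proj / v_proj branches are elided here).
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None

print(remap_scale_name("model.layers.0.self_attn.q_proj.output_scale"))
# -> model.layers.0.self_attn.attn.q_scale
print(remap_scale_name("model.layers.0.self_attn.prob_output_scale"))
# -> model.layers.0.self_attn.attn.prob_scale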


@@ -38,6 +38,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
# Initialize P = softmax(QK^T) scales
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(
@@ -97,5 +100,38 @@
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")
if layer.q_scale > 0.0:
q_scale = layer.q_scale
if current_platform.is_fp8_fnuz():
q_scale *= 2
layer.calculate_kv_scales = False
else:
q_scale = 1.0
if layer.prob_scale > 0.0:
prob_scale = layer.prob_scale
if current_platform.is_fp8_fnuz():
prob_scale *= 2
else:
prob_scale = 1.0
is_singleton_float = lambda x: isinstance(x, float) or isinstance(
x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
if not is_singleton_float(q_scale) or not is_singleton_float(
prob_scale):
raise ValueError("Only support per-tensor scaling factor"
"for fp8-quantized Q/prob")
# These are used in the final Attention.forward()
layer._q_scale.copy_(q_scale)
layer._prob_scale.copy_(prob_scale)
if q_scale == 1.0 or prob_scale == 1.0:
logger.warning_once(
f"Using Q scale {q_scale} and prob scale {prob_scale} "
"with fp8 attention. This may cause accuracy issues. "
"Please make sure Q/prob scaling factors are "
"available in the fp8 checkpoint.")
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale
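A condensed sketch of the scale handling above, with a plain boolean standing in for current_platform.is_fp8_fnuz() and made-up checkpoint values; -1.0 marks a scale that was not loaded, matching the parameter defaults:

import torch

ckpt_q_scale = 0.25     # hypothetical value read from the checkpoint
ckpt_prob_scale = -1.0  # not present in this hypothetical checkpoint
is_fp8_fnuz = True      # stand-in for current_platform.is_fp8_fnuz()

def resolve(scale: float) -> float:
    # Positive scales come from the checkpoint; fnuz platforms double them,
    # since the same fp8 bit pattern encodes half the value of OCP e4m3.
    # Missing scales fall back to 1.0, which triggers the warning above.
    if scale > 0.0:
        return scale * 2 if is_fp8_fnuz else scale
    return 1.0

_q_scale = torch.tensor(resolve(ckpt_q_scale), dtype=torch.float32)       # 0.5
_prob_scale = torch.tensor(resolve(ckpt_prob_scale), dtype=torch.float32) # 1.0
print(_q_scale.item(), _prob_scale.item())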


@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import fnmatch
-import re
from typing import Any, Dict, List, Optional, cast
import torch
@@ -125,6 +124,13 @@ class QuarkConfig(QuantizationConfig):
for q_config in q_configs:
q_config["output_tensors"] = None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(Dict[str, Any],
layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
return cls(quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
@@ -289,29 +295,30 @@
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
-if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-return None
-kv_proj_names = [
-re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-]
-if name.endswith(".output_scale"):
-if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-return name.replace(kv_output_scale_name, ".attn.k_scale")
-elif len(kv_proj_names) == 2:
-for kv_proj_name in kv_proj_names:
-if kv_proj_name in name and kv_proj_name == "k_proj":
-return name.replace(".k_proj.output_scale",
-".attn.k_scale")
-elif kv_proj_name in name and kv_proj_name == "v_proj":
-return name.replace(".v_proj.output_scale",
-".attn.v_scale")
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None
def has_fp8_layer_weights(self):
layer_quant_config = self.quant_config.get("layer_quant_config")
to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
return any([
'fp8' in cast(
str,
to_dict(
to_dict(to_dict(layer_quant_config).get(layer_name)).get(
"weight")).get("dtype"))
for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
])
class QuarkLinearMethod(LinearMethodBase):
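A small sketch of the has_fp8_layer_weights() check above, run against a hypothetical Quark-style quant_config dict; the nesting mirrors the code, while the dtype strings are assumptions about what a Quark export might contain:

from typing import Any, Dict, cast

quant_config: Dict[str, Any] = {
    "layer_quant_config": {
        "*q_proj": {"weight": {"dtype": "fp8_e4m3"}},
        "*k_proj": {"weight": {"dtype": "fp8_e4m3"}},
        "*v_proj": {"weight": {"dtype": "fp8_e4m3"}},
    }
}

def has_fp8_layer_weights(quant_config: Dict[str, Any]) -> bool:
    # Same traversal as the method above: treat any missing level as {} and
    # report whether a q/k/v projection weight dtype mentions fp8 (str() is
    # used here so a missing dtype yields "None" instead of raising).
    layer_quant_config = quant_config.get("layer_quant_config")
    to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
    return any(
        "fp8" in str(
            to_dict(
                to_dict(to_dict(layer_quant_config).get(layer_name)).get(
                    "weight")).get("dtype"))
        for layer_name in ["*v_proj", "*k_proj", "*q_proj"])

print(has_fp8_layer_weights(quant_config))  # True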