mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc] Support FP8 kv cache scales from compressed-tensors (#6528)
This commit is contained in:
@ -150,3 +150,10 @@ def test_compressed_tensors_fp8(vllm_runner):
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
assert output
|
||||
|
||||
|
||||
def test_compressed_tensors_kv_cache(vllm_runner):
|
||||
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
|
||||
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
|
||||
output = llm.generate_greedy("Hello world!", max_tokens=20)
|
||||
assert output
|
||||
|
@ -9,7 +9,7 @@ from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -59,19 +59,18 @@ class Attention(nn.Module):
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None:
|
||||
assert isinstance(quant_method, Fp8KVCacheMethod)
|
||||
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
if "fp8" in self.kv_cache_dtype:
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# When FP8 quantization is enabled, we make a parameter
|
||||
# "kv_scale" so that it can be loaded from FP8 checkpoint.
|
||||
# The k/v_scale will then be converted back to
|
||||
# self._kv_scale in a native float32 value after weight loading
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# If quantization is enabled, we make "k_scale" and "v_scale"
|
||||
# parameters so that it can be loaded from the model checkpoint.
|
||||
# The k/v_scale will then be converted back to native float32
|
||||
# values after weight loading.
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig)
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsScheme, CompressedTensorsUnquantized,
|
||||
@ -15,18 +15,23 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
CompressionFormat, QuantizationArgs, QuantizationStrategy,
|
||||
QuantizationType, find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
|
||||
quant_format: str):
|
||||
def __init__(self,
|
||||
target_scheme_map: Dict[str, Any],
|
||||
ignore: List[str],
|
||||
quant_format: str,
|
||||
kv_cache_scheme: Optional[Dict[str, Any]] = None):
|
||||
|
||||
self.ignore = ignore
|
||||
self.quant_format = quant_format
|
||||
# Map from [target -> scheme]
|
||||
self.target_scheme_map = target_scheme_map
|
||||
self.kv_cache_scheme = kv_cache_scheme
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
@ -50,9 +55,12 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["CompressedTensorsLinearMethod"]:
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
if isinstance(layer, LinearBase):
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
if isinstance(layer, Attention):
|
||||
return CompressedTensorsKVCacheMethod(self)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@ -85,7 +93,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
return cls(target_scheme_map=target_scheme_map,
|
||||
ignore=ignore,
|
||||
quant_format=quant_format)
|
||||
quant_format=quant_format,
|
||||
kv_cache_scheme=config.get("kv_cache_scheme"))
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@ -309,3 +318,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
|
||||
if scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from compressed-tensors
|
||||
checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: CompressedTensorsConfig):
|
||||
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache scheme. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_scheme: the compressed-tensors kv cache scheme
|
||||
"""
|
||||
if kv_cache_scheme is None:
|
||||
return
|
||||
|
||||
type_ = kv_cache_scheme.get("type")
|
||||
num_bits = kv_cache_scheme.get("num_bits")
|
||||
|
||||
if type_ != "float" and num_bits != 8:
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
"num_bits=8, type=float, however "
|
||||
f"received num_bits={num_bits}, type={type_}")
|
||||
|
||||
strategy = kv_cache_scheme.get("strategy")
|
||||
if strategy != "tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"Expected strategy: tensor, found strategy: {strategy}")
|
||||
|
||||
is_symmetric = kv_cache_scheme.get("symmetric")
|
||||
if not is_symmetric:
|
||||
raise NotImplementedError(
|
||||
"Only support symmetric scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"However found symmetric: {is_symmetric}")
|
||||
|
@ -209,6 +209,23 @@ def _find_first_match(value: str,
|
||||
return None
|
||||
|
||||
|
||||
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(value: str,
|
||||
target: str,
|
||||
check_contains: bool = False) -> bool:
|
||||
|
@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
@ -400,64 +401,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_group=topk_group)
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(QuantizeMethodBase):
|
||||
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""Create "weight" (aka k_scale and v_scale) for an attention layer.
|
||||
|
||||
Args:
|
||||
layer: The layer that is using the QuantizeMethodBase factory.
|
||||
"""
|
||||
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
||||
# If the k/v_scale appears in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
if layer.kv_cache_dtype != "auto":
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
v_scale = layer.v_scale.to("cpu").tolist()
|
||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||
# If no scales were loaded (both scales are invalid negative
|
||||
# values), use the default value of 1.0
|
||||
k_scale = Parameter(torch.tensor(1.0), requires_grad=False)
|
||||
v_scale = Parameter(torch.tensor(1.0), requires_grad=False)
|
||||
else:
|
||||
# If we find a single kv_scale in the checkpoint, we remap
|
||||
# kv_scale to k_scale during weight loading, and duplicate
|
||||
# k_scale to v_scale here
|
||||
assert layer.k_scale > 0.0
|
||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
|
||||
if not isinstance(k_scale, float) or not isinstance(
|
||||
v_scale, float):
|
||||
raise ValueError("Only support per-tensor scaling factor "
|
||||
"for fp8 KV cache")
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._k_scale = k_scale
|
||||
layer._v_scale = v_scale
|
||||
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
|
||||
and "e5m2" not in layer.kv_cache_dtype):
|
||||
print_warning_once(
|
||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
||||
"may cause accuracy issues. Please make sure k/v_scale "
|
||||
"scaling factors are available in the fp8 checkpoint.")
|
||||
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
||||
super().__init__(quant_config)
|
||||
|
78
vllm/model_executor/layers/quantization/kv_cache.py
Normal file
78
vllm/model_executor/layers/quantization/kv_cache.py
Normal file
@ -0,0 +1,78 @@
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
|
||||
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
"""
|
||||
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
||||
Attention layer to support loading those scaling factors from checkpoints.
|
||||
The k/v_scale will be used to:
|
||||
- quantize k/v_cache entries before saving them to the cache
|
||||
- dequantize k/v_cache entries before fetching them from the cache
|
||||
|
||||
:param quant_config: the appropriate QuantizationConfig
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuantizationConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Create "weight" (aka k_scale and v_scale) for an attention layer.
|
||||
"""
|
||||
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
||||
# If the k/v_scale appears in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
if layer.kv_cache_dtype != "auto":
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
v_scale = layer.v_scale.to("cpu").tolist()
|
||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||
# If no scales were loaded (both scales are invalid negative
|
||||
# values), use the default value of 1.0
|
||||
k_scale = torch.nn.Parameter(torch.tensor(1.0),
|
||||
requires_grad=False)
|
||||
v_scale = torch.nn.Parameter(torch.tensor(1.0),
|
||||
requires_grad=False)
|
||||
else:
|
||||
# If we find a single kv_scale in the checkpoint, we remap
|
||||
# kv_scale to k_scale during weight loading, and duplicate
|
||||
# k_scale to v_scale here
|
||||
assert layer.k_scale > 0.0
|
||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
|
||||
if not isinstance(k_scale, float) or not isinstance(
|
||||
v_scale, float):
|
||||
raise ValueError("Only support per-tensor scaling factor "
|
||||
"for fp8 KV cache")
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._k_scale = k_scale
|
||||
layer._v_scale = v_scale
|
||||
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
|
||||
and "e5m2" not in layer.kv_cache_dtype):
|
||||
print_warning_once(
|
||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
||||
"may cause accuracy issues. Please make sure k/v_scale "
|
||||
"scaling factors are available in the fp8 checkpoint.")
|
||||
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
@ -39,6 +39,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -467,6 +469,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
Reference in New Issue
Block a user