mirror of https://github.com/vllm-project/vllm.git
[Bugfix] Automatically Detect SparseML models (#5119)
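
SparseML checkpoints store their quantization settings under a nested
"compression_config" key instead of the top-level "quantization_config"
attribute that ModelConfig previously read, so such models were not
detected as quantized. The fix factors the HF-config lookup into a
_parse_quant_hf_config helper that falls back to the nested location and
switches _verify_quantization over to it.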
@@ -156,6 +156,17 @@ class ModelConfig:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # SparseML uses a "compression_config" with a "quantization_config".
+            compression_cfg = getattr(self.hf_config, "compression_config",
+                                      None)
+            if compression_cfg is not None:
+                quant_cfg = compression_cfg.get("quantization_config", None)
+
+        return quant_cfg
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
@@ -163,12 +174,13 @@ class ModelConfig:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        quant_cfg = self._parse_quant_hf_config()
+
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name, method in QUANTIZATION_METHODS.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
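
To make the helper's behavior concrete, below is a minimal standalone
sketch of the same fallback logic. The attribute and key names
("quantization_config", "compression_config", "quant_method") come from
the diff above; the SimpleNamespace stand-in for the HF config object and
the sample payload are illustrative, not part of the commit.

    from types import SimpleNamespace

    def parse_quant_hf_config(hf_config):
        # Prefer the standard top-level HF attribute.
        quant_cfg = getattr(hf_config, "quantization_config", None)
        if quant_cfg is None:
            # SparseML checkpoints nest the quantization settings under
            # "compression_config", so fall back to that location.
            compression_cfg = getattr(hf_config, "compression_config", None)
            if compression_cfg is not None:
                quant_cfg = compression_cfg.get("quantization_config", None)
        return quant_cfg

    # SparseML-style config: no top-level quantization_config, but a
    # compression_config that wraps one (payload illustrative).
    sparseml_cfg = SimpleNamespace(
        compression_config={"quantization_config": {"quant_method": "gptq"}})
    print(parse_quant_hf_config(sparseml_cfg))  # {'quant_method': 'gptq'}

    # A config with neither attribute still yields None, as before.
    print(parse_quant_hf_config(SimpleNamespace()))  # None

Because the helper returns None when neither attribute is present,
unquantized models flow through _verify_quantization unchanged.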
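
The second hunk only drops the unused name loop variable; the detection
loop's behavior is unchanged. For readers unfamiliar with that loop, the
sketch below mocks the dispatch it relies on. Only the
override_quantization_method(quant_cfg, self.quantization) call signature
and the truthiness check appear in the hunk; the mock class, the registry
dict, and the config payload here are hypothetical.

    # Hypothetical stand-in for one entry in QUANTIZATION_METHODS.
    class MockMethod:

        @staticmethod
        def override_quantization_method(quant_cfg, user_quantization):
            # Claim the checkpoint when the parsed config matches this
            # method's format; a truthy return value signals an override.
            if quant_cfg.get("quant_method") == "mock":
                return "mock"
            return None

    MOCK_QUANTIZATION_METHODS = {"mock": MockMethod}

    quant_cfg = {"quant_method": "mock"}  # e.g. from parse_quant_hf_config
    for _, method in MOCK_QUANTIZATION_METHODS.items():
        quantization_override = method.override_quantization_method(
            quant_cfg, None)
        if quantization_override:
            print(f"detected checkpoint format: {quantization_override}")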