[Misc] Fix get_min_capability (#5971)
@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig):
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]
 
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         # The AWQ kernel only supports Turing or newer GPUs.
         return 75
 
@@ -44,8 +44,9 @@ class QuantizationConfig(ABC):
         """List of supported activation dtypes."""
         raise NotImplementedError
 
+    @classmethod
     @abstractmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         """Minimum GPU capability to support the quantization method.
 
         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
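
With @classmethod stacked on top of @abstractmethod, every subclass now reports its minimum capability on the class itself, so no config instance is needed to query it. A minimal sketch of the pattern, using illustrative names that are not part of the diff:

    from abc import ABC, abstractmethod

    class QuantizationConfigSketch(ABC):
        """Illustrative stand-in for QuantizationConfig."""

        @classmethod
        @abstractmethod
        def get_min_capability(cls) -> int:
            """Minimum GPU capability, e.g. 70 for Volta, 75 for Turing."""
            raise NotImplementedError

    class AWQConfigSketch(QuantizationConfigSketch):
        @classmethod
        def get_min_capability(cls) -> int:
            # The AWQ kernel only supports Turing or newer GPUs.
            return 75

    # The minimum can now be read from the class without instantiation:
    assert AWQConfigSketch.get_min_capability() == 75
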
@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         return [torch.float32, torch.float16, torch.bfloat16]
 
     @classmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         return 70
 
     @staticmethod
@@ -33,10 +33,9 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]
 
-    # Need to figure it out
     @classmethod
     def get_min_capability(cls) -> int:
-        return 60
+        return 75
 
     def get_name(self) -> str:
         return "compressed_tensors"
@@ -84,6 +83,14 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return []
 
+    def _check_gptq_and_marlin_can_run(self):
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < 80:
+            raise RuntimeError("The quantization config is not supported for ",
+                               "the current GPU. Minimum capability: 80. ",
+                               f"Current capability: {capability}.")
+
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
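
The new guard encodes the (major, minor) tuple from torch.cuda.get_device_capability() as major * 10 + minor, so an (8, 0) device becomes 80. Note that the RuntimeError above is given three comma-separated arguments rather than one concatenated string, so the message renders as a tuple; a sketch of the same check with a single message (the helper name below is illustrative, not part of the commit):

    import torch

    def check_marlin_capability(minimum: int = 80) -> None:
        # Encode (major, minor) the same way the diff does: (8, 6) -> 86.
        major, minor = torch.cuda.get_device_capability()
        capability = major * 10 + minor
        if capability < minimum:
            raise RuntimeError(
                "The quantization config is not supported for the current GPU. "
                f"Minimum capability: {minimum}. Current capability: {capability}.")
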
@@ -126,6 +133,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                              input_quant: BaseModel) -> "CompressedTensorsScheme":
 
         if self._is_wNa16_group_channel(weight_quant, input_quant):
+            self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]
 
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 70
 
     @staticmethod
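
Taken together, every quantization config now exposes get_min_capability as a classmethod, so hardware support can be checked before any config object is built. An end-to-end sketch under that assumption (the helper below is illustrative, not vLLM API):

    import torch

    def supports_quantization(config_cls, device: int = 0) -> bool:
        # Compare the class-level minimum against the device's compute capability.
        major, minor = torch.cuda.get_device_capability(device)
        return major * 10 + minor >= config_cls.get_min_capability()

    # e.g. supports_quantization(AWQConfig) is True on Turing (capability 75) and newer GPUs.
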