mirror of https://github.com/vllm-project/vllm.git
[Misc] Fix get_min_capability (#5971)
@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig):
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]

-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         # The AWQ kernel only supports Turing or newer GPUs.
         return 75

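For context, the integer returned by get_min_capability encodes a CUDA compute capability as major * 10 + minor, so 75 corresponds to Turing (7.5). A minimal sketch of a caller-side check, assuming the vllm import path below; after this change no AWQConfig instance is needed:

import torch

from vllm.model_executor.layers.quantization.awq import AWQConfig

# Encode the device capability the same way vLLM does: (7, 5) -> 75.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor

# get_min_capability is now a classmethod, so no config instance is required.
if device_capability < AWQConfig.get_min_capability():
    raise ValueError("AWQ needs Turing (compute capability 7.5) or newer.")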
@@ -44,8 +44,9 @@ class QuantizationConfig(ABC):
         """List of supported activation dtypes."""
         raise NotImplementedError

+    @classmethod
     @abstractmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         """Minimum GPU capability to support the quantization method.

         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
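The decorator order above matters: @classmethod must sit on top of @abstractmethod for the combination to work (supported since Python 3.3). A self-contained sketch of the pattern, with stand-in class names rather than vLLM's own:

from abc import ABC, abstractmethod


class QuantBase(ABC):  # stand-in for QuantizationConfig
    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability, e.g. 70 for Volta, 75 for Turing."""
        raise NotImplementedError


class TuringOnlyConfig(QuantBase):  # stand-in for a concrete config
    @classmethod
    def get_min_capability(cls) -> int:
        return 75


# The point of the change: callable on the class itself, no instance needed.
assert TuringOnlyConfig.get_min_capability() == 75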
@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         return [torch.float32, torch.float16, torch.bfloat16]

     @classmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         return 70

     @staticmethod
@@ -33,10 +33,9 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

-    # Need to figure it out
     @classmethod
     def get_min_capability(cls) -> int:
-        return 60
+        return 75

     def get_name(self) -> str:
         return "compressed_tensors"
@@ -84,6 +83,14 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return []

+    def _check_gptq_and_marlin_can_run(self):
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < 80:
+            raise RuntimeError("The quantization config is not supported for ",
+                               "the current GPU. Minimum capability: 80. ",
+                               f"Current capability: {capability}.")
+
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
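One caveat in the hunk above (an observation about the code, not a change the commit makes): RuntimeError is called with three comma-separated arguments, so the exception carries a tuple of strings rather than one concatenated message. A sketch of an equivalent check with a single joined message, keeping the same threshold:

import torch


def check_gptq_and_marlin_can_run() -> None:
    # Same encoding as the diff: (major, minor) -> major * 10 + minor.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < 80:
        # One f-string instead of comma-separated arguments, so the
        # error reads as a single sentence.
        raise RuntimeError(
            "The quantization config is not supported for the current "
            f"GPU. Minimum capability: 80. Current capability: {capability}.")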
@@ -126,6 +133,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                     input_quant: BaseModel) -> "CompressedTensorsScheme":

         if self._is_wNa16_group_channel(weight_quant, input_quant):
+            self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]

-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 70

     @staticmethod