[Misc] Fix get_min_capability (#5971)

Author: Dipika Sikka
Date: 2024-06-30 16:15:16 -04:00
Committed by: GitHub
Parent: deacb7ec44
Commit: 7836fdcc11
5 changed files with 17 additions and 6 deletions
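
This change converts get_min_capability from an instance method into a classmethod on QuantizationConfig and all of its overrides, so the minimum compute capability can be queried from the class without constructing a config instance. It also raises the CompressedTensorsConfig minimum from 60 to 75 and adds an explicit capability check before dispatching to the GPTQ/Marlin scheme.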

vllm/model_executor/layers/quantization/awq.py

@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig):
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]
 
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         # The AWQ kernel only supports Turing or newer GPUs.
         return 75

vllm/model_executor/layers/quantization/base_config.py

@@ -44,8 +44,9 @@ class QuantizationConfig(ABC):
         """List of supported activation dtypes."""
         raise NotImplementedError
 
+    @classmethod
     @abstractmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         """Minimum GPU capability to support the quantization method.
 
         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
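
The decorator order in this hunk matters: per the abc documentation, when @abstractmethod is combined with @classmethod, @classmethod must be the outermost decorator, which is exactly how the hunk adds it. A minimal self-contained sketch of the resulting pattern (Base and TuringOnlyConfig are illustrative names, not from the diff):

from abc import ABC, abstractmethod

class Base(ABC):
    @classmethod          # outermost, matching the order added above
    @abstractmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError

class TuringOnlyConfig(Base):
    @classmethod
    def get_min_capability(cls) -> int:
        return 75

# The override is callable on the class itself; no instance is needed.
assert TuringOnlyConfig.get_min_capability() == 75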

vllm/model_executor/layers/quantization/bitsandbytes.py

@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         return [torch.float32, torch.float16, torch.bfloat16]
 
     @classmethod
-    def get_min_capability(self) -> int:
+    def get_min_capability(cls) -> int:
         return 70
 
     @staticmethod

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -33,10 +33,9 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]
 
-    # Need to figure it out
     @classmethod
     def get_min_capability(cls) -> int:
-        return 60
+        return 75
 
     def get_name(self) -> str:
         return "compressed_tensors"
@@ -84,6 +83,14 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return []
 
+    def _check_gptq_and_marlin_can_run(self):
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < 80:
+            raise RuntimeError("The quantization config is not supported for ",
+                               "the current GPU. Minimum capability: 80. ",
+                               f"Current capability: {capability}.")
+
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
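
For reference, torch.cuda.get_device_capability() returns a (major, minor) tuple that the new helper folds into a single integer. Note also that the comma-separated arguments to RuntimeError above make the exception carry a tuple of strings; adjacent string literals without the commas would concatenate into one message. A quick sketch of the capability arithmetic, assuming a CUDA device is present:

import torch

# (major, minor) tuple folded into one integer, as in the helper above.
major, minor = torch.cuda.get_device_capability()  # e.g. (7, 5) on a T4
capability = major * 10 + minor                    # -> 75
print(capability >= 80)  # False on Turing: the Marlin path needs Ampere or newer
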
@@ -126,6 +133,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
         if self._is_wNa16_group_channel(weight_quant, input_quant):
+            self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(

vllm/model_executor/layers/quantization/squeezellm.py

@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]
 
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 70
 
     @staticmethod
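
With every override converted, a capability check can be written against the config classes directly. A minimal usage sketch, assuming vLLM at this commit and an available CUDA device:

import torch

from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor
for config_cls in (AWQConfig, SqueezeLLMConfig):
    # Classmethod access: no config instance has to be constructed.
    minimum = config_cls.get_min_capability()
    print(f"{config_cls.__name__}: min={minimum} supported={capability >= minimum}")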