[ Misc ] Clean Up CompressedTensorsW8A8 (#6113)

This commit is contained in:
Robert Shaw
2024-07-03 18:50:08 -04:00
committed by GitHub
parent d9e98f42e4
commit 62963d129e
6 changed files with 44 additions and 95 deletions

View File

@ -9,8 +9,7 @@ import torch
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
CompressedTensorsWNA16)
CompressedTensorsW8A8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)
@ -38,9 +37,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.is_static_input_scheme
expected_type = (torch.int8 if quant_type == QuantizationType.INT else
torch.float8_e4m3fn)
@ -79,7 +79,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8

View File

@ -9,8 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
CompressedTensorsWNA16)
CompressedTensorsW8A8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match)
@ -150,12 +149,12 @@ class CompressedTensorsConfig(QuantizationConfig):
if self.quant_format == CompressionFormat.int_quantized.value:
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor(
strategy=weight_quant.strategy)
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
is_static_input_scheme=True)
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken(
strategy=weight_quant.strategy)
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
is_static_input_scheme=False)
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")

View File

@ -3,9 +3,6 @@ from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w4a16_24 import ( # noqa: F401
W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501
CompressedTensorsW8A8DynamicToken)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
CompressedTensorsW8A8StaticTensor)
from .compressed_tensors_w8a8 import CompressedTensorsW8A8 # noqa: F401
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS # noqa: F401
from .compressed_tensors_wNa16 import CompressedTensorsWNA16 # noqa: F401

View File

@ -3,6 +3,7 @@ from typing import Callable, List, Tuple, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
@ -12,8 +13,9 @@ from vllm.model_executor.utils import set_weight_attrs
class CompressedTensorsW8A8(CompressedTensorsScheme):
def __init__(self, strategy: str):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
# Cutlass kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
@ -36,6 +38,10 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
layer.weight_scale = Parameter(weight_scale_channel,
requires_grad=False)
# transpose weights for cutlass.
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
@ -75,3 +81,29 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
"output_dim": 0,
"weight_loader": weight_loader,
})
# INPUT SCALE
# Static quantization: load from disk.
if self.is_static_input_scheme:
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
# Dynamic quantization: set to None.
else:
layer.input_scale = None
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype)

View File

@ -1,33 +0,0 @@
from typing import Callable, List
import torch
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import ( # noqa: E501
CompressedTensorsW8A8)
__all__ = ["CompressedTensorsW8A8DynamicToken"]
class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
super().create_weights(
layer=layer,
output_partition_sizes=output_partition_sizes,
input_size_per_partition=input_size_per_partition,
params_dtype=params_dtype,
weight_loader=weight_loader)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
weight_scale = layer.weight_scale
x_q, input_scales = custom_ops.scaled_int8_quant(x)
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
weight_scale, x.dtype)

View File

@ -1,47 +0,0 @@
from typing import Callable, List
import torch
from torch.nn import Parameter
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import ( # noqa: E501
CompressedTensorsW8A8)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW8A8StaticTensor"]
class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
super().create_weights(
layer=layer,
output_partition_sizes=output_partition_sizes,
input_size_per_partition=input_size_per_partition,
params_dtype=params_dtype,
weight_loader=weight_loader)
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
weight_scale = layer.weight_scale
act_scale = layer.input_scale
# Input quantize
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
weight_scale, x.dtype)