Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[ Misc ] Clean Up CompressedTensorsW8A8 (#6113)
@@ -9,8 +9,7 @@ import torch
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
-    CompressedTensorsWNA16)
+    CompressedTensorsW8A8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationType)
 
@@ -38,9 +37,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
                           CompressedTensorsLinearMethod)
         assert isinstance(down_proj.quant_method,
                           CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
 
         assert qkv_proj.scheme.strategy == strategy
+        assert qkv_proj.scheme.is_static_input_scheme
         expected_type = (torch.int8 if quant_type == QuantizationType.INT else
                          torch.float8_e4m3fn)
 
@@ -79,7 +79,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
         qkv_proj = layer.self_attn.qkv_proj
 
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
+        assert not qkv_proj.scheme.is_static_input_scheme
         assert qkv_proj.scheme.strategy == strategy
         assert qkv_proj.weight.dtype is torch.int8
 
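
Note: the scheme objects asserted on in the tests above are attached automatically when a compressed-tensors checkpoint is loaded through vLLM. A minimal usage sketch for context (the model id below is a placeholder, not taken from this commit):

from vllm import LLM, SamplingParams

# Placeholder model id: any compressed-tensors W8A8 checkpoint would do here.
llm = LLM(model="<org>/<w8a8-compressed-tensors-model>")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
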
@@ -9,8 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
-    CompressedTensorsWNA16)
+    CompressedTensorsW8A8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
@@ -150,12 +149,12 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         if self.quant_format == CompressionFormat.int_quantized.value:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8StaticTensor(
-                    strategy=weight_quant.strategy)
+                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
+                                             is_static_input_scheme=True)
 
             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8DynamicToken(
-                    strategy=weight_quant.strategy)
+                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
+                                             is_static_input_scheme=False)
 
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
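
Note: with this change both W8A8 activation-quantization variants are instances of the single CompressedTensorsW8A8 class, differing only in the is_static_input_scheme flag. A minimal construction sketch (the "tensor" strategy value is illustrative; in the code above the strategy comes from weight_quant.strategy):

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A8)

# Formerly CompressedTensorsW8A8StaticTensor: input scale loaded from the checkpoint.
static_scheme = CompressedTensorsW8A8(strategy="tensor",
                                      is_static_input_scheme=True)

# Formerly CompressedTensorsW8A8DynamicToken: input scales computed per token at runtime.
dynamic_scheme = CompressedTensorsW8A8(strategy="tensor",
                                       is_static_input_scheme=False)
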
@@ -3,9 +3,6 @@ from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
 from .compressed_tensors_w4a16_24 import (  # noqa: F401
     W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
-from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
-    CompressedTensorsW8A8DynamicToken)
-from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
-    CompressedTensorsW8A8StaticTensor)
+from .compressed_tensors_w8a8 import CompressedTensorsW8A8  # noqa: F401
 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS  # noqa: F401
 from .compressed_tensors_wNa16 import CompressedTensorsWNA16  # noqa: F401
@@ -3,6 +3,7 @@ from typing import Callable, List, Tuple, Union
 import torch
 from torch.nn import Parameter
 
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
@@ -12,8 +13,9 @@ from vllm.model_executor.utils import set_weight_attrs
 
 class CompressedTensorsW8A8(CompressedTensorsScheme):
 
-    def __init__(self, strategy: str):
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
 
     # Cutlass kernels support only per-tensor and per-channel cases.
     # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
@@ -36,6 +38,10 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
         layer.weight_scale = Parameter(weight_scale_channel,
                                        requires_grad=False)
 
+        # transpose weights for cutlass.
+        weight = layer.weight
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
@@ -75,3 +81,29 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
             "output_dim": 0,
             "weight_loader": weight_loader,
         })
+
+        # INPUT SCALE
+        # Static quantization: load from disk.
+        if self.is_static_input_scheme:
+            input_scale = Parameter(torch.empty(1, dtype=torch.float32),
+                                    requires_grad=False)
+            layer.register_parameter("input_scale", input_scale)
+            set_weight_attrs(input_scale, {
+                "weight_loader": weight_loader,
+                "ignore_warning": True,
+            })
+        # Dynamic quantization: set to None.
+        else:
+            layer.input_scale = None
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        # ops.scaled_int8_quant supports both dynamic and static quant.
+        # * dynamic, layer.input_scale is None and x_scale computed from x.
+        # * static, layer.input_scale is scalar and x_scale is input_scale.
+        x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)
+
+        return ops.cutlass_scaled_mm(x_q,
+                                     layer.weight,
+                                     scale_a=x_scale,
+                                     scale_b=layer.weight_scale,
+                                     out_dtype=x.dtype)
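
Note: ops.scaled_int8_quant and ops.cutlass_scaled_mm are custom CUDA ops; the sketch below is only a rough pure-PyTorch illustration of the quantization behavior described in the comments above, and details such as rounding mode and clamping range may differ from the real kernel:

from typing import Optional, Tuple

import torch


def scaled_int8_quant_reference(
        x: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scale is None:
        # Dynamic quantization: one scale per token (row), from the row abs-max.
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 127.0
    # Static quantization: `scale` is the precomputed scalar input_scale.
    # Either way, quantize symmetrically to int8.
    x_q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return x_q, scale

The scaled matmul that follows (cutlass_scaled_mm above) then multiplies the int8 operands and applies scale_a and scale_b to return the result in x.dtype.
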
@@ -1,33 +0,0 @@
-from typing import Callable, List
-
-import torch
-
-from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
-    CompressedTensorsW8A8)
-
-__all__ = ["CompressedTensorsW8A8DynamicToken"]
-
-
-class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):
-
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-
-        super().create_weights(
-            layer=layer,
-            output_partition_sizes=output_partition_sizes,
-            input_size_per_partition=input_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=weight_loader)
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        weight = layer.weight
-        weight_scale = layer.weight_scale
-
-        x_q, input_scales = custom_ops.scaled_int8_quant(x)
-        return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
-                                            weight_scale, x.dtype)
@@ -1,47 +0,0 @@
-from typing import Callable, List
-
-import torch
-from torch.nn import Parameter
-
-from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
-    CompressedTensorsW8A8)
-from vllm.model_executor.utils import set_weight_attrs
-
-__all__ = ["CompressedTensorsW8A8StaticTensor"]
-
-
-class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):
-
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-
-        super().create_weights(
-            layer=layer,
-            output_partition_sizes=output_partition_sizes,
-            input_size_per_partition=input_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=weight_loader)
-
-        input_scale = Parameter(torch.empty(1, dtype=torch.float32),
-                                requires_grad=False)
-
-        layer.register_parameter("input_scale", input_scale)
-        set_weight_attrs(input_scale, {
-            "weight_loader": weight_loader,
-            "ignore_warning": True,
-        })
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        weight = layer.weight
-        weight_scale = layer.weight_scale
-        act_scale = layer.input_scale
-
-        # Input quantize
-        x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
-
-        return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
-                                            weight_scale, x.dtype)