mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[ Misc ] Support Fp8 via llm-compressor
(#6110)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.752
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.752
|
||||
limit: 250
|
||||
num_fewshot: 5
|
@ -1,4 +1,4 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
|
||||
model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
|
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.728
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.728
|
||||
limit: 250
|
||||
num_fewshot: 5
|
@ -1,2 +1,4 @@
|
||||
Meta-Llama-3-8B-Instruct.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||
|
@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
|
||||
done
|
||||
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \
|
||||
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
|
||||
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
|
||||
--batch_size $BATCH_SIZE
|
||||
|
@ -24,7 +24,8 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)
|
||||
|
||||
def launch_lm_eval(eval_config):
|
||||
model_args = f"pretrained={eval_config['model_name']}," \
|
||||
f"tensor_parallel_size={TP_SIZE}"
|
||||
f"tensor_parallel_size={TP_SIZE}," \
|
||||
f"add_bos_token=true"
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
|
@ -9,7 +9,8 @@ import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8, CompressedTensorsWNA16)
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
QuantizationType)
|
||||
|
||||
@ -37,12 +38,11 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
|
||||
CompressedTensorsLinearMethod)
|
||||
assert isinstance(down_proj.quant_method,
|
||||
CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
|
||||
|
||||
assert qkv_proj.scheme.strategy == strategy
|
||||
assert qkv_proj.scheme.is_static_input_scheme
|
||||
expected_type = (torch.int8 if quant_type == QuantizationType.INT else
|
||||
torch.float8_e4m3fn)
|
||||
expected_type = torch.int8
|
||||
|
||||
assert qkv_proj.weight.dtype is expected_type
|
||||
assert o_proj.weight.dtype is expected_type
|
||||
@ -79,7 +79,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
|
||||
assert not qkv_proj.scheme.is_static_input_scheme
|
||||
assert qkv_proj.scheme.strategy == strategy
|
||||
assert qkv_proj.weight.dtype is torch.int8
|
||||
@ -123,3 +123,25 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
|
||||
sampling_params = SamplingParams()
|
||||
output = llm.generate("Hello world!", sampling_params=sampling_params)
|
||||
assert output
|
||||
|
||||
|
||||
def test_compressed_tensors_fp8(vllm_runner):
|
||||
model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
|
||||
with vllm_runner(model_path) as llm:
|
||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
|
||||
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
|
||||
assert qkv_proj.input_scale.dtype is torch.float32
|
||||
assert qkv_proj.weight_scale.dtype is torch.float32
|
||||
# should be scalars after processing
|
||||
assert len(qkv_proj.input_scale.shape) == 0
|
||||
assert len(qkv_proj.weight_scale.shape) == 0
|
||||
|
||||
sampling_params = SamplingParams()
|
||||
output = llm.generate("Hello world!", sampling_params=sampling_params)
|
||||
assert output
|
||||
|
@ -9,10 +9,11 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8, CompressedTensorsWNA16)
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
CompressionFormat, QuantizationArgs, QuantizationStrategy,
|
||||
find_first_name_or_class_match)
|
||||
QuantizationType, find_first_name_or_class_match)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -117,6 +118,40 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
return is_8_bits and is_token and is_symmetric and is_dynamic
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
# Confirm weights and activations quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm we have floating points.
|
||||
if not (weight_quant.type == QuantizationType.FLOAT
|
||||
and input_quant.type == QuantizationType.FLOAT):
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_per_tensor_weight = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR)
|
||||
if not (is_symmetric_weight and is_static_weight
|
||||
and is_per_tensor_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.dynamic:
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_symmetric_activation = input_quant.symmetric
|
||||
is_per_tensor_activation = (
|
||||
input_quant.strategy == QuantizationStrategy.TENSOR)
|
||||
if not (is_symmetric_activation and is_per_tensor_activation):
|
||||
return False
|
||||
|
||||
# All conditions satisfied.
|
||||
return True
|
||||
|
||||
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
input_quant_none = input_quant is None
|
||||
@ -147,13 +182,20 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
strategy=weight_quant.strategy,
|
||||
group_size=weight_quant.group_size)
|
||||
|
||||
if self.quant_format == CompressionFormat.int_quantized.value:
|
||||
if (self.quant_format == CompressionFormat.int_quantized.value or
|
||||
self.quant_format == CompressionFormat.float_quantized.value):
|
||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Fp8(
|
||||
input_dynamic=input_quant.dynamic)
|
||||
|
||||
if self._is_static_tensor_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=True)
|
||||
|
||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=False)
|
||||
|
||||
raise NotImplementedError(
|
||||
@ -187,7 +229,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
return layer.scheme.process_weights_after_loading(layer)
|
||||
layer.scheme.process_weights_after_loading(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
|
@ -1,8 +1,19 @@
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
|
||||
from .compressed_tensors_unquantized import ( # noqa: F401
|
||||
CompressedTensorsUnquantized)
|
||||
from .compressed_tensors_w4a16_24 import ( # noqa: F401
|
||||
W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
|
||||
from .compressed_tensors_w8a8 import CompressedTensorsW8A8 # noqa: F401
|
||||
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS # noqa: F401
|
||||
from .compressed_tensors_wNa16 import CompressedTensorsWNA16 # noqa: F401
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_unquantized import CompressedTensorsUnquantized
|
||||
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24)
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsWNA16)
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsUnquantized",
|
||||
"CompressedTensorsWNA16",
|
||||
"CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS",
|
||||
"W4A16SPARSE24_SUPPORTED_BITS",
|
||||
]
|
||||
|
@ -1,109 +0,0 @@
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class CompressedTensorsW8A8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
# Cutlass kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), we convert to the per-channel case.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if (self.strategy == QuantizationStrategy.TENSOR
|
||||
and len(self.logical_widths) > 1):
|
||||
|
||||
# Load the N per-tensor scales into the channelwise buffer.
|
||||
weight_scale_channel = torch.empty(
|
||||
(sum(self.logical_widths), 1),
|
||||
dtype=torch.float32,
|
||||
device=layer.weight_scale.device)
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(self.logical_widths):
|
||||
end = start + logical_width
|
||||
weight_scale_channel[start:end, :] = layer.weight_scale[idx]
|
||||
start = end
|
||||
|
||||
layer.weight_scale = Parameter(weight_scale_channel,
|
||||
requires_grad=False)
|
||||
|
||||
# transpose weights for cutlass.
|
||||
weight = layer.weight
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
self.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT SCALE
|
||||
shape: Union[Tuple[int], Tuple[int, int]]
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
shape = (sum(self.logical_widths), 1)
|
||||
else:
|
||||
shape = (len(self.logical_widths), )
|
||||
|
||||
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
set_weight_attrs(weight_scale, {
|
||||
"weight_loader": weight_loader,
|
||||
"output_dim": 0,
|
||||
})
|
||||
else:
|
||||
set_weight_attrs(weight_scale, {
|
||||
"weight_loader": weight_loader,
|
||||
"needs_scalar_to_array": True,
|
||||
})
|
||||
|
||||
# WEIGHT
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"weight_loader": weight_loader,
|
||||
})
|
||||
|
||||
# INPUT SCALE
|
||||
# Static quantization: load from disk.
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
set_weight_attrs(input_scale, {
|
||||
"weight_loader": weight_loader,
|
||||
"ignore_warning": True,
|
||||
})
|
||||
# Dynamic quantization: set to None.
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant.
|
||||
# * dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# * static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)
|
||||
|
||||
return ops.cutlass_scaled_mm(x_q,
|
||||
layer.weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype)
|
@ -0,0 +1,87 @@
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
|
||||
requantize_with_max_scale)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
__all__ = ["CompressedTensorsW8A8Fp8"]
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, input_dynamic: bool):
|
||||
self.input_dynamic = input_dynamic
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), we requantize with a single scale.
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# Dequant -> Quant with max scale.
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
# Update layer with new values.
|
||||
layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(max_w_scale,
|
||||
requires_grad=False)
|
||||
if self.input_dynamic:
|
||||
layer.input_scale = None
|
||||
else:
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
del params_dtype
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"weight_loader": weight_loader,
|
||||
})
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if not self.input_dynamic:
|
||||
input_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return apply_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported)
|
@ -0,0 +1,85 @@
|
||||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param,
|
||||
create_per_tensor_scale_param)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = layer.weight
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(self.logical_widths) > 1
|
||||
if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
|
||||
ws_channelwise = convert_to_channelwise(layer.weight_scale,
|
||||
self.logical_widths)
|
||||
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
self.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"weight_loader": weight_loader,
|
||||
})
|
||||
|
||||
# WEIGHT SCALE
|
||||
layer_kwargs = {"weight_loader": weight_loader}
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
scale = create_per_channel_scale_param(output_partition_sizes,
|
||||
**layer_kwargs)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
scale = create_per_tensor_scale_param(output_partition_sizes,
|
||||
**layer_kwargs)
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
scale = create_per_tensor_scale_param(output_partition_sizes,
|
||||
**layer_kwargs)
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
||||
return apply_int8_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale)
|
@ -9,6 +9,7 @@ from torch.nn import Module
|
||||
class CompressionFormat(Enum):
|
||||
dense = "dense"
|
||||
sparse_bitmask = "sparse-bitmask"
|
||||
float_quantized = "float-quantized"
|
||||
int_quantized = "int-quantized"
|
||||
pack_quantized = "pack-quantized"
|
||||
marlin_24 = "marlin-24"
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@ -11,11 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
|
||||
marlin_permute_scales)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
pack_fp8_to_int32)
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
|
||||
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
@ -25,13 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def cutlass_fp8_supported() -> bool:
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
return ops.cutlass_scaled_mm_supports_fp8(capability)
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
@ -117,23 +110,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
self.use_marlin = capability < 89
|
||||
|
||||
def _create_scale_param(
|
||||
self,
|
||||
scale_name: str,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
scale = Parameter(torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
scale[:] = torch.finfo(torch.float8_e4m3fn).min
|
||||
layer.register_parameter(scale_name, scale)
|
||||
set_weight_attrs(scale, {
|
||||
**extra_weight_attrs,
|
||||
"needs_scalar_to_array": True,
|
||||
})
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -147,7 +123,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
layer.process_after_load = True
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
@ -173,144 +148,50 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
self._create_scale_param(
|
||||
scale_name="weight_scale",
|
||||
layer=layer,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
scale = create_per_tensor_scale_param(output_partition_sizes,
|
||||
**extra_weight_attrs)
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
self._create_scale_param(
|
||||
scale_name="input_scale",
|
||||
layer=layer,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
scale = create_per_tensor_scale_param(output_partition_sizes,
|
||||
**extra_weight_attrs)
|
||||
|
||||
# For GPUs without FP8 hardware support, we use Marlin for fast
|
||||
# fused dequantization
|
||||
if self.use_marlin:
|
||||
layer.marlin_state = GPTQMarlinState.REPACK
|
||||
|
||||
def prepare_layer_for_marlin(self, layer: Module) -> None:
|
||||
print_warning_once(
|
||||
"Your GPU does not have native support for FP8 computation but "
|
||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
|
||||
assert layer.marlin_state == GPTQMarlinState.REPACK
|
||||
layer.marlin_state = GPTQMarlinState.READY
|
||||
|
||||
device = layer.weight.device
|
||||
|
||||
# WEIGHTS
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
|
||||
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight = Parameter(marlin_qweight, requires_grad=False)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = layer.weight_scale.repeat(1, part_size_n).to(
|
||||
layer.orig_dtype).to(device)
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(
|
||||
s=scales,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=-1,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
max_workspace_size = (
|
||||
part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
layer.workspace = workspace
|
||||
layer.register_parameter("input_scale", scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if (not hasattr(layer, "process_after_load")
|
||||
or not layer.process_after_load):
|
||||
return
|
||||
|
||||
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
|
||||
scale=None)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.logical_widths = None
|
||||
layer.input_scale = None
|
||||
if self.use_marlin:
|
||||
self.prepare_layer_for_marlin(layer)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, requantize the separately quantized logical
|
||||
# weights into a single fp8 weight with a single weight scale.
|
||||
else:
|
||||
# WEIGHT_SCALE / WEIGHT
|
||||
# Loop over logical weights, requantizing with single scale.
|
||||
max_w_scale = layer.weight_scale.max()
|
||||
# Dequant -> Quant with max scale.
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
# QKV / MLP is fused in the on disk checkpoint if any of the
|
||||
# weight scales are still set to the default since we initialize
|
||||
# N weight scales for N shards but we only load 1 weight scale
|
||||
# from disk in this case. As a result, we skip dequant -> requant
|
||||
# since we already have quantized QKV together.
|
||||
# Sample Model with fused checkpoint:
|
||||
# * nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||
unfused_module_in_checkpoint = (
|
||||
layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
|
||||
|
||||
if unfused_module_in_checkpoint:
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(layer.logical_widths):
|
||||
end = start + logical_width
|
||||
weight_dq = per_tensor_dequantize(
|
||||
layer.weight[start:end, :], layer.weight_scale[idx])
|
||||
|
||||
layer.weight[start:end, :] = per_tensor_quantize(
|
||||
weight_dq, layer.weight_scale.max())
|
||||
start = end
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# WEIGHT
|
||||
# Transpose weight for passing to torch._scaled_mm
|
||||
weight = layer.weight
|
||||
# Update layer with new values.
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
|
||||
# Static: set to max of the input_scales (since they are equal).
|
||||
if self.quant_config.activation_scheme == "dynamic":
|
||||
layer.input_scale = None
|
||||
elif self.quant_config.activation_scheme == "static":
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown scheme {self.quant_config.activation_scheme}")
|
||||
layer.input_scale = None
|
||||
|
||||
if self.use_marlin:
|
||||
self.prepare_layer_for_marlin(layer)
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
@ -318,65 +199,22 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.use_marlin:
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the
|
||||
# Marlin kernel for fast weight-only FP8 quantization
|
||||
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=reshaped_x,
|
||||
b_q_weight=layer.weight,
|
||||
b_scales=layer.weight_scale,
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
num_bits=8,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
)
|
||||
bias=bias)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
else:
|
||||
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale
|
||||
|
||||
if bias is None and self.cutlass_fp8_supported:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
)
|
||||
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x,
|
||||
layer.input_scale,
|
||||
batch_dim_padding=17)
|
||||
|
||||
# Fused GEMM_DQ -- note we padded the input above because
|
||||
# torch._scaled_mm is more performant for matrices with
|
||||
# batch dimension > 16. Note that this could change
|
||||
# in the future.
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
return apply_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return torch.narrow(output, 0, 0, x.shape[0])
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported)
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
@ -399,8 +237,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
|
||||
layer.process_after_load = True
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
@ -465,9 +301,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.a2_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if (not hasattr(layer, "process_after_load")
|
||||
or not layer.process_after_load):
|
||||
return
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
@ -531,7 +364,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
shard_size, :],
|
||||
layer.w13_scale[expert_id][shard_id])
|
||||
layer.w13_weight[expert_id][
|
||||
start:start + shard_size, :] = per_tensor_quantize(
|
||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
@ -596,23 +429,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
|
||||
"cause accuracy issues. Please make sure kv-cache scaling "
|
||||
"factor is available in the fp8 checkpoint.")
|
||||
del layer.kv_scale
|
||||
|
||||
|
||||
def per_tensor_quantize(tensor: torch.Tensor,
|
||||
inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return qweight.to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def per_tensor_dequantize(
|
||||
tensor: torch.Tensor, inv_scale: Union[float,
|
||||
torch.Tensor]) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(torch.float16)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
|
||||
def all_close_1d(x: torch.Tensor) -> bool:
|
||||
assert len(x.shape) == 1
|
||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||
|
@ -11,20 +11,16 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K,
|
||||
GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM,
|
||||
GPTQ_MARLIN_TILE)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
GPTQ_MARLIN_TILE = 16
|
||||
GPTQ_MARLIN_MIN_THREAD_N = 64
|
||||
GPTQ_MARLIN_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_MAX_PARALLEL = 16
|
||||
|
||||
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
GPTQ_MARLIN_SUPPORTED_SYM = [True]
|
||||
|
||||
|
||||
# Permutations for Marlin scale shuffling
|
||||
def get_scale_perms(num_bits: int):
|
||||
|
@ -1,9 +1,11 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.format_24 import (
|
||||
mask_creator, sparse_semi_structured_from_dense_cutlass)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
|
||||
@ -13,8 +15,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_pack_factor, quantize_weights, sort_weights)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
MARLIN_TILE = 16
|
||||
GPTQ_MARLIN_TILE = 16
|
||||
GPTQ_MARLIN_MIN_THREAD_N = 64
|
||||
GPTQ_MARLIN_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_MAX_PARALLEL = 16
|
||||
|
||||
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
GPTQ_MARLIN_SUPPORTED_SYM = [True]
|
||||
|
||||
|
||||
def is_marlin_supported():
|
||||
@ -22,7 +32,92 @@ def is_marlin_supported():
|
||||
return capability[0] >= 8
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
|
||||
def apply_fp8_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the
|
||||
# Marlin kernel for fast weight-only FP8 quantization
|
||||
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (size_n, )
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=reshaped_x,
|
||||
b_q_weight=weight,
|
||||
b_scales=weight_scale,
|
||||
workspace=workspace,
|
||||
num_bits=8,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=size_n,
|
||||
size_k=size_k,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
print_warning_once(
|
||||
"Your GPU does not have native support for FP8 computation but "
|
||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
|
||||
device = layer.weight.device
|
||||
|
||||
# WEIGHTS
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
|
||||
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = layer.weight_scale.repeat(1, part_size_n).to(
|
||||
layer.orig_dtype).to(device)
|
||||
# Permute scales
|
||||
num_bits = 8
|
||||
marlin_scales = marlin_permute_scales(
|
||||
s=scales,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=-1,
|
||||
scale_perm=marlin_scale_perm[num_bits],
|
||||
scale_perm_single=marlin_scale_perm_single[num_bits])
|
||||
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
max_workspace_size = (part_size_n //
|
||||
GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
layer.workspace = workspace
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
|
||||
assert q_w.shape == (size_k, size_n)
|
||||
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
|
||||
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
|
||||
|
163
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Normal file
163
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Normal file
@ -0,0 +1,163 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def cutlass_fp8_supported() -> bool:
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
return ops.cutlass_scaled_mm_supports_fp8(capability)
|
||||
|
||||
|
||||
def per_tensor_dequantize(
|
||||
tensor: torch.Tensor, inv_scale: Union[float,
|
||||
torch.Tensor]) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(torch.float16)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
|
||||
def all_close_1d(x: torch.Tensor) -> bool:
|
||||
assert len(x.shape) == 1
|
||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||
|
||||
|
||||
def create_per_tensor_scale_param(
|
||||
output_partition_sizes: List[int],
|
||||
**extra_weight_attrs,
|
||||
) -> Parameter:
|
||||
scale = Parameter(torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {
|
||||
"needs_scalar_to_array": True,
|
||||
**extra_weight_attrs
|
||||
})
|
||||
return scale
|
||||
|
||||
|
||||
def create_per_channel_scale_param(output_partition_sizes: List[int],
|
||||
**extra_weight_attrs) -> Parameter:
|
||||
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
|
||||
return scale
|
||||
|
||||
|
||||
def convert_to_channelwise(
|
||||
weight_scale: torch.Tensor,
|
||||
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Create channelwise buffer
|
||||
weight_scale_channel = torch.empty((sum(logical_widths), 1),
|
||||
dtype=torch.float32,
|
||||
device=weight_scale.device)
|
||||
|
||||
# Expand each scale to match the size of each logical matrix.
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(logical_widths):
|
||||
end = start + logical_width
|
||||
weight_scale_channel[start:end, :] = weight_scale[idx]
|
||||
start = end
|
||||
|
||||
return weight_scale_channel
|
||||
|
||||
|
||||
def requantize_with_max_scale(
|
||||
weight: torch.Tensor, weight_scale: torch.Tensor,
|
||||
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Max scale to be used for requanitzation.
|
||||
max_w_scale = weight_scale.max()
|
||||
|
||||
# QKV / MLP is fused in the on disk checkpoint if any of the
|
||||
# weight scales are still set to the default since we initialize
|
||||
# N weight scales for N shards but we only load 1 weight scale
|
||||
# from disk in this case. Skip requantization in this case (since)
|
||||
# we already are quantized with the single scale.
|
||||
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||
unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
|
||||
torch.float8_e4m3fn).min)
|
||||
|
||||
# If unfused checkpoint, need requanize with the single scale.
|
||||
if unfused_module_in_checkpoint:
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(logical_widths):
|
||||
end = start + logical_width
|
||||
weight_dq = per_tensor_dequantize(weight[start:end, :],
|
||||
weight_scale[idx])
|
||||
weight[start:end, :], _ = ops.scaled_fp8_quant(
|
||||
weight_dq, max_w_scale)
|
||||
start = end
|
||||
|
||||
return max_w_scale, weight
|
||||
|
||||
|
||||
def apply_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_fp8_supported: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
|
||||
if bias is None and cutlass_fp8_supported:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale)
|
||||
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(input,
|
||||
input_scale,
|
||||
batch_dim_padding=17)
|
||||
|
||||
# Fused GEMM_DQ -- note we padded the input above because
|
||||
# torch._scaled_mm is more performant for matrices with
|
||||
# batch dimension > 16. Note that this could change
|
||||
# in the future.
|
||||
output, _ = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias)
|
||||
|
||||
return torch.narrow(output, 0, 0, input.shape[0])
|
||||
|
||||
|
||||
def apply_int8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if bias is not None:
|
||||
raise NotImplementedError("W8A8 with int8 does not yet support bias.")
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant.
|
||||
# * dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# * static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
|
||||
|
||||
return ops.cutlass_scaled_mm(x_q,
|
||||
weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
out_dtype=input.dtype)
|
Reference in New Issue
Block a user