[ Misc ] Support Fp8 via llm-compressor (#6110)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
Robert Shaw
2024-07-07 16:42:11 -04:00
committed by GitHub
parent 333306a252
commit abfe705a02
17 changed files with 603 additions and 372 deletions

View File

@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.752
- name: "exact_match,flexible-extract"
value: 0.752
limit: 250
num_fewshot: 5

View File

@ -1,4 +1,4 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
tasks:
- name: "gsm8k"

View File

@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.728
- name: "exact_match,flexible-extract"
value: 0.728
limit: 250
num_fewshot: 5

View File

@ -1,2 +1,4 @@
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml

View File

@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done
lm_eval --model vllm \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
--batch_size $BATCH_SIZE

View File

@ -24,7 +24,8 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)
def launch_lm_eval(eval_config):
model_args = f"pretrained={eval_config['model_name']}," \
f"tensor_parallel_size={TP_SIZE}"
f"tensor_parallel_size={TP_SIZE}," \
f"add_bos_token=true"
results = lm_eval.simple_evaluate(
model="vllm",

View File

@ -9,7 +9,8 @@ import torch
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8, CompressedTensorsWNA16)
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)
@ -37,12 +38,11 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.is_static_input_scheme
expected_type = (torch.int8 if quant_type == QuantizationType.INT else
torch.float8_e4m3fn)
expected_type = torch.int8
assert qkv_proj.weight.dtype is expected_type
assert o_proj.weight.dtype is expected_type
@ -79,7 +79,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
@ -123,3 +123,25 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output
def test_compressed_tensors_fp8(vllm_runner):
model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.input_scale.dtype is torch.float32
assert qkv_proj.weight_scale.dtype is torch.float32
# should be scalars after processing
assert len(qkv_proj.input_scale.shape) == 0
assert len(qkv_proj.weight_scale.shape) == 0
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output

View File

@ -9,10 +9,11 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8, CompressedTensorsWNA16)
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match)
QuantizationType, find_first_name_or_class_match)
from vllm.platforms import current_platform
@ -117,6 +118,40 @@ class CompressedTensorsConfig(QuantizationConfig):
return is_8_bits and is_token and is_symmetric and is_dynamic
def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights and activations quantized.
if weight_quant is None or input_quant is None:
return False
# Confirm we have floating points.
if not (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT):
return False
# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_weight = (
weight_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_weight and is_static_weight
and is_per_tensor_weight):
return False
# Dynamic quantization is always supported if weights supported.
if input_quant.dynamic:
return True
# Confirm activation scheme is supported.
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
if not (is_symmetric_activation and is_per_tensor_activation):
return False
# All conditions satisfied.
return True
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
@ -147,13 +182,20 @@ class CompressedTensorsConfig(QuantizationConfig):
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)
if self.quant_format == CompressionFormat.int_quantized.value:
if (self.quant_format == CompressionFormat.int_quantized.value or
self.quant_format == CompressionFormat.float_quantized.value):
if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8(
input_dynamic=input_quant.dynamic)
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=True)
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8(strategy=weight_quant.strategy,
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False)
raise NotImplementedError(
@ -187,7 +229,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return layer.scheme.process_weights_after_loading(layer)
layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,

View File

@ -1,8 +1,19 @@
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w4a16_24 import ( # noqa: F401
W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8 import CompressedTensorsW8A8 # noqa: F401
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS # noqa: F401
from .compressed_tensors_wNa16 import CompressedTensorsWNA16 # noqa: F401
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_unquantized import CompressedTensorsUnquantized
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsUnquantized",
"CompressedTensorsWNA16",
"CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS",
"W4A16SPARSE24_SUPPORTED_BITS",
]

View File

@ -1,109 +0,0 @@
from typing import Callable, List, Tuple, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs
class CompressedTensorsW8A8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
# Cutlass kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), we convert to the per-channel case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if (self.strategy == QuantizationStrategy.TENSOR
and len(self.logical_widths) > 1):
# Load the N per-tensor scales into the channelwise buffer.
weight_scale_channel = torch.empty(
(sum(self.logical_widths), 1),
dtype=torch.float32,
device=layer.weight_scale.device)
start = 0
for idx, logical_width in enumerate(self.logical_widths):
end = start + logical_width
weight_scale_channel[start:end, :] = layer.weight_scale[idx]
start = end
layer.weight_scale = Parameter(weight_scale_channel,
requires_grad=False)
# transpose weights for cutlass.
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
# WEIGHT SCALE
shape: Union[Tuple[int], Tuple[int, int]]
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (sum(self.logical_widths), 1)
else:
shape = (len(self.logical_widths), )
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("weight_scale", weight_scale)
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"weight_loader": weight_loader,
"output_dim": 0,
})
else:
set_weight_attrs(weight_scale, {
"weight_loader": weight_loader,
"needs_scalar_to_array": True,
})
# WEIGHT
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# INPUT SCALE
# Static quantization: load from disk.
if self.is_static_input_scheme:
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
# Dynamic quantization: set to None.
else:
layer.input_scale = None
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype)

View File

@ -0,0 +1,87 @@
from typing import Callable, List, Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW8A8Fp8"]
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, input_dynamic: bool):
self.input_dynamic = input_dynamic
self.cutlass_fp8_supported = cutlass_fp8_supported()
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), we requantize with a single scale.
def process_weights_after_loading(self, layer) -> None:
# Dequant -> Quant with max scale.
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# Update layer with new values.
layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(max_w_scale,
requires_grad=False)
if self.input_dynamic:
layer.input_scale = None
else:
layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
requires_grad=False)
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
del params_dtype
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if not self.input_dynamic:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)

View File

@ -0,0 +1,85 @@
from typing import Callable, List
import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param,
create_per_tensor_scale_param)
from vllm.model_executor.utils import set_weight_attrs
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(self.logical_widths) > 1
if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
self.logical_widths)
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
# WEIGHT
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
scale = create_per_channel_scale_param(output_partition_sizes,
**layer_kwargs)
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = create_per_tensor_scale_param(output_partition_sizes,
**layer_kwargs)
layer.register_parameter("weight_scale", scale)
# INPUT SCALE
if self.is_static_input_scheme:
scale = create_per_tensor_scale_param(output_partition_sizes,
**layer_kwargs)
layer.register_parameter("input_scale", scale)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)

View File

@ -9,6 +9,7 @@ from torch.nn import Module
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
float_quantized = "float-quantized"
int_quantized = "int-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
@ -11,11 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
pack_fp8_to_int32)
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
@ -25,13 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
def cutlass_fp8_supported() -> bool:
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
return ops.cutlass_scaled_mm_supports_fp8(capability)
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@ -117,23 +110,6 @@ class Fp8LinearMethod(LinearMethodBase):
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
def _create_scale_param(
self,
scale_name: str,
layer: torch.nn.Module,
output_partition_sizes: List[int],
**extra_weight_attrs,
) -> None:
scale = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float8_e4m3fn).min
layer.register_parameter(scale_name, scale)
set_weight_attrs(scale, {
**extra_weight_attrs,
"needs_scalar_to_array": True,
})
def create_weights(
self,
layer: torch.nn.Module,
@ -147,7 +123,6 @@ class Fp8LinearMethod(LinearMethodBase):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
layer.process_after_load = True
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
@ -173,144 +148,50 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
self._create_scale_param(
scale_name="weight_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
layer.register_parameter("weight_scale", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
self._create_scale_param(
scale_name="input_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
# For GPUs without FP8 hardware support, we use Marlin for fast
# fused dequantization
if self.use_marlin:
layer.marlin_state = GPTQMarlinState.REPACK
def prepare_layer_for_marlin(self, layer: Module) -> None:
print_warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
assert layer.marlin_state == GPTQMarlinState.REPACK
layer.marlin_state = GPTQMarlinState.READY
device = layer.weight.device
# WEIGHTS
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=part_size_k,
size_n=part_size_n,
num_bits=8,
)
layer.weight = Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = layer.weight_scale.repeat(1, part_size_n).to(
layer.orig_dtype).to(device)
# Permute scales
marlin_scales = marlin_permute_scales(
s=scales,
size_k=part_size_k,
size_n=part_size_n,
group_size=-1,
num_bits=8,
)
layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
# Allocate marlin workspace
max_workspace_size = (
part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = workspace
layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.input_scale = None
if self.use_marlin:
self.prepare_layer_for_marlin(layer)
return
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
else:
# WEIGHT_SCALE / WEIGHT
# Loop over logical weights, requantizing with single scale.
max_w_scale = layer.weight_scale.max()
# Dequant -> Quant with max scale.
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# QKV / MLP is fused in the on disk checkpoint if any of the
# weight scales are still set to the default since we initialize
# N weight scales for N shards but we only load 1 weight scale
# from disk in this case. As a result, we skip dequant -> requant
# since we already have quantized QKV together.
# Sample Model with fused checkpoint:
# * nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint = (
layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
if unfused_module_in_checkpoint:
start = 0
for idx, logical_width in enumerate(layer.logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
layer.weight[start:end, :], layer.weight_scale[idx])
layer.weight[start:end, :] = per_tensor_quantize(
weight_dq, layer.weight_scale.max())
start = end
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight = layer.weight
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
# INPUT ACTIVATION SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the input_scales (since they are equal).
if self.quant_config.activation_scheme == "dynamic":
layer.input_scale = None
elif self.quant_config.activation_scheme == "static":
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")
layer.input_scale = None
if self.use_marlin:
self.prepare_layer_for_marlin(layer)
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
def apply(self,
layer: torch.nn.Module,
@ -318,65 +199,22 @@ class Fp8LinearMethod(LinearMethodBase):
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_marlin:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
output = ops.fp8_marlin_gemm(
a=reshaped_x,
b_q_weight=layer.weight,
b_scales=layer.weight_scale,
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
num_bits=8,
size_m=reshaped_x.shape[0],
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
)
bias=bias)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
else:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x
# If static, layer.input_scale is scalar and x_scale is input_scale
if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.input_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
return torch.narrow(output, 0, 0, x.shape[0])
cutlass_fp8_supported=self.cutlass_fp8_supported)
class Fp8MoEMethod(FusedMoEMethodBase):
@ -399,8 +237,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
intermediate_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
layer.process_after_load = True
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
@ -465,9 +301,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.a2_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
@ -531,7 +364,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
shard_size, :],
layer.w13_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :] = per_tensor_quantize(
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
@ -596,23 +429,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale
def per_tensor_quantize(tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float,
torch.Tensor]) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

View File

@ -11,20 +11,16 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM,
GPTQ_MARLIN_TILE)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):

View File

@ -1,9 +1,11 @@
"""This file is used for /tests and /benchmarks"""
import random
from typing import Optional
import numpy
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.format_24 import (
mask_creator, sparse_semi_structured_from_dense_cutlass)
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
@ -13,8 +15,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor, quantize_weights, sort_weights)
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
MARLIN_TILE = 16
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
def is_marlin_supported():
@ -22,7 +32,92 @@ def is_marlin_supported():
return capability[0] >= 8
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
def apply_fp8_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
workspace: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n, )
output = ops.fp8_marlin_gemm(
a=reshaped_x,
b_q_weight=weight,
b_scales=weight_scale,
workspace=workspace,
num_bits=8,
size_m=reshaped_x.shape[0],
size_n=size_n,
size_k=size_k,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
print_warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
device = layer.weight.device
# WEIGHTS
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=part_size_k,
size_n=part_size_n,
num_bits=8,
)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = layer.weight_scale.repeat(1, part_size_n).to(
layer.orig_dtype).to(device)
# Permute scales
num_bits = 8
marlin_scales = marlin_permute_scales(
s=scales,
size_k=part_size_k,
size_n=part_size_n,
group_size=-1,
scale_perm=marlin_scale_perm[num_bits],
scale_perm_single=marlin_scale_perm_single[num_bits])
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
# Allocate marlin workspace
max_workspace_size = (part_size_n //
GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = workspace
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"

View File

@ -0,0 +1,163 @@
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
def cutlass_fp8_supported() -> bool:
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
return ops.cutlass_scaled_mm_supports_fp8(capability)
def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float,
torch.Tensor]) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
def create_per_tensor_scale_param(
output_partition_sizes: List[int],
**extra_weight_attrs,
) -> Parameter:
scale = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**extra_weight_attrs
})
return scale
def create_per_channel_scale_param(output_partition_sizes: List[int],
**extra_weight_attrs) -> Parameter:
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
return scale
def convert_to_channelwise(
weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
# Create channelwise buffer
weight_scale_channel = torch.empty((sum(logical_widths), 1),
dtype=torch.float32,
device=weight_scale.device)
# Expand each scale to match the size of each logical matrix.
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_scale_channel[start:end, :] = weight_scale[idx]
start = end
return weight_scale_channel
def requantize_with_max_scale(
weight: torch.Tensor, weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max()
# QKV / MLP is fused in the on disk checkpoint if any of the
# weight scales are still set to the default since we initialize
# N weight scales for N shards but we only load 1 weight scale
# from disk in this case. Skip requantization in this case (since)
# we already are quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
torch.float8_e4m3fn).min)
# If unfused checkpoint, need requanize with the single scale.
if unfused_module_in_checkpoint:
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :],
weight_scale[idx])
weight[start:end, :], _ = ops.scaled_fp8_quant(
weight_dq, max_w_scale)
start = end
return max_w_scale, weight
def apply_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cutlass_fp8_supported: bool = True,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
if bias is None and cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale)
else:
qinput, x_scale = ops.scaled_fp8_quant(input,
input_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
return torch.narrow(output, 0, 0, input.shape[0])
def apply_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
if bias is not None:
raise NotImplementedError("W8A8 with int8 does not yet support bias.")
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
return ops.cutlass_scaled_mm(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype)