diff --git a/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml new file mode 100644 index 0000000000..2928d75ce4 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2 +model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.6353 + - name: "exact_match,flexible-extract" + value: 0.637 +limit: null +num_fewshot: null diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 7e2e6f6ed5..0655f2b385 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -3,6 +3,7 @@ Run `pytest tests/quantization/test_compressed_tensors.py`. """ + from typing import Optional import pytest @@ -22,12 +23,30 @@ from vllm.platforms import current_platform @pytest.mark.parametrize( "model_args", - [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor", - QuantizationType.INT, 2560, True), - ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel", - QuantizationType.INT, 2560, True), - ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor", - QuantizationType.INT, 2560, False)]) + [ + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + "tensor", + QuantizationType.INT, + 2560, + True, + ), + ( + "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", + "channel", + QuantizationType.INT, + 2560, + True, + ), + ( + "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", + "tensor", + QuantizationType.INT, + 2560, + False, + ), + ], +) def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): model_path, strategy, quant_type, shape_0, is_symmetric = model_args with vllm_runner(model_path, enforce_eager=True) as llm: @@ -85,21 +104,31 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): assert output -@pytest.mark.parametrize("model_path", [ - "neuralmagic/Llama-3.2-1B-quantized.w8a8", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" -]) +@pytest.mark.parametrize( + "model_path", + [ + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner, - example_prompts, model_path, - max_tokens, num_logprobs): +def test_compressed_tensors_w8a8_logprobs( + hf_runner, + vllm_runner, + example_prompts, + model_path, + max_tokens, + num_logprobs, +): dtype = "bfloat16" # skip language translation prompt for the static per tensor asym model - if model_path == "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym": # noqa: E501 + if (model_path == + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym" + ): # noqa: E501 example_prompts = example_prompts[0:-1] with hf_runner(model_path, dtype=dtype) as hf_model: @@ -125,13 +154,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): assert output -@pytest.mark.parametrize("model_args", [ - ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"), - ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"), - ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"), - ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", - "channel"), -]) +@pytest.mark.parametrize( + "model_args", + [ + ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"), + ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"), + ( + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", + "channel", + ), + ( + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", + "channel", + ), + ], +) def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args): model_path, strategy = model_args with vllm_runner(model_path, dtype=torch.float16) as llm: @@ -156,9 +193,12 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args): @pytest.mark.parametrize( "wNa16_args", - [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8), - ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8), - ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)]) + [ + ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8), + ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8), + ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4), + ], +) def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor = wNa16_args with vllm_runner(model) as llm: @@ -218,7 +258,8 @@ def test_compressed_tensors_fp8(vllm_runner): CompressedTensorsLinearMethod) assert isinstance( qkv_proj.scheme, - (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8)) + (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8), + ) assert qkv_proj.input_scale.dtype is torch.float32 @@ -241,9 +282,14 @@ def test_compressed_tensors_kv_cache(vllm_runner): assert output -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse FP8 is not yet supported on this GPU type.") -def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy): +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse FP8 is not yet supported on this GPU type.", +) +def _test_2of4_quant_models(qkv_proj, + weight_strategy, + input_strategy, + format="dense"): assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) @@ -252,22 +298,39 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy): assert qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 - assert sparsity_map.get("Linear").format == "dense" + assert sparsity_map.get("Linear").format == format assert sparsity_map.get("Linear").sparsity_structure == "2:4" -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="Sparse FP8 is not yet supported on this GPU type.") -@pytest.mark.parametrize("args_2of4", [ - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel", - "token"), - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", - "channel", "tensor"), - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", "tensor", - "tensor"), - ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", - "tensor", "token"), -]) +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Sparse FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", + "channel", + "token", + ), + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing", + "channel", + "tensor", + ), + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing", + "tensor", + "tensor", + ), + ( + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing", + "tensor", + "token", + ), + ], +) def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 with vllm_runner(model) as llm: @@ -286,16 +349,134 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4): assert output -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse FP8 is not yet supported on this GPU type.") -@pytest.mark.parametrize("args_2of4", [ - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", - "channel", "token"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", "tensor", - "tensor"), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", - "tensor", "token"), -]) +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Sparse FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM", + "channel", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM", + "channel", + "tensor", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM", + "tensor", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM", + "tensor", + "tensor", + ), + ], +) +def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4): + model, weight_strategy, input_strategy = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn + _test_2of4_quant_models( + qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask", + ) + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="cutlass is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM", + "channel", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM", + "channel", + "tensor", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM", + "tensor", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM", + "tensor", + "tensor", + ), + ], +) +def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4): + model, weight_strategy, input_strategy = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert qkv_proj.scheme.weights_dtype == torch.int8 + _test_2of4_quant_models( + qkv_proj, + weight_strategy, + input_strategy, + format="sparse-24-bitmask", + ) + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output + + +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse FP8 is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", + [ + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing", + "channel", + "token", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing", + "tensor", + "tensor", + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing", + "tensor", + "token", + ), + ], +) def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 with vllm_runner(model) as llm: @@ -317,10 +498,12 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4): @pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.") @pytest.mark.skipif( not sparse_cutlass_supported(), - reason="2of4 Sparse is not yet supported on this GPU type.") + reason="2of4 Sparse is not yet supported on this GPU type.", +) @pytest.mark.parametrize( "args_2of4", - [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")]) + [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")], +) def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4): model = args_2of4 with vllm_runner(model) as llm: @@ -337,7 +520,9 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4): assert qkv_proj.scheme.input_quant is None assert not qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map - sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 + sparsity_map = ( + qkv_proj.quant_method.quantization_config.sparsity_scheme_map + ) # noqa: E501 assert sparsity_map.get("Linear").format == "dense" assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -346,3 +531,38 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Cutlass is not yet supported on this GPU type.", +) +@pytest.mark.parametrize( + "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]) +def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): + model = args_2of4 + with vllm_runner(model) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensors24) + + assert qkv_proj.scheme.weight_quant is None + assert qkv_proj.scheme.input_quant is None + assert not qkv_proj.scheme.quantized + assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map + sparsity_map = ( + qkv_proj.quant_method.quantization_config.sparsity_scheme_map + ) # noqa: E501 + assert sparsity_map.get("Linear").format == "sparse-24-bitmask" + assert sparsity_map.get("Linear").sparsity_structure == "2:4" + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + print(output) + assert output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0e3258e4af..6ee3e9362f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -417,15 +417,22 @@ class CompressedTensorsConfig(QuantizationConfig): return None # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel - scheme = CompressedTensors24(quantized=weight_quant is not None - or input_quant is not None, - weight_quant=weight_quant, - input_quant=input_quant) + model_compression_config = (None if sparsity_scheme is None + or sparsity_scheme.format == "dense" + else self.config) + + scheme = CompressedTensors24( + quantized=weight_quant is not None or input_quant is not None, + weight_quant=weight_quant, + input_quant=input_quant, + model_compression_config=model_compression_config, + ) elif weight_quant is None: logger.warning_once("Acceleration for non-quantized schemes is " "not supported by Compressed Tensors. " "Falling back to UnquantizedLinearMethod") return None + else: # Find the quant_scheme scheme = self._get_scheme_from_parts( # type: ignore @@ -475,10 +482,21 @@ class CompressedTensorsConfig(QuantizationConfig): :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise """ - is_valid_sparsity = (sparsity_scheme is not None - and sparsity_scheme.sparsity_structure - == SparsityStructure.TWO_FOUR.value - and sparsity_scheme.format == "dense") + if sparsity_scheme is None: + return False + + is_valid_sparsity_structure: bool = ( + sparsity_scheme.sparsity_structure == + SparsityStructure.TWO_FOUR.value) + + valid_compressors = { + CompressionFormat.dense.value, + CompressionFormat.sparse_24_bitmask.value + } + + is_valid_sparsity = (is_valid_sparsity_structure + and sparsity_scheme.format in valid_compressors) + if not is_valid_sparsity: return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 84f924b236..0fb8dfa96a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch +from compressed_tensors import CompressionFormat, ModelCompressor from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) +from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -22,26 +26,39 @@ __all__ = ["CompressedTensors24"] class CompressedTensors24(CompressedTensorsScheme): - def __init__(self, - quantized: bool = False, - weight_quant: Optional[QuantizationArgs] = None, - input_quant: Optional[QuantizationArgs] = None): - + def __init__( + self, + quantized: bool = False, + weight_quant: Optional[QuantizationArgs] = None, + input_quant: Optional[QuantizationArgs] = None, + model_compression_config: Optional[Dict[str, Any]] = None, + ): self.quantized = quantized self.weight_quant = weight_quant self.input_quant = input_quant + self.model_compressor = ( + ModelCompressor.from_compression_config(model_compression_config) + if model_compression_config is not None else None) + self.do_sparse_decompress = ( + self.model_compressor is not None + and self.model_compressor.sparsity_config.format + == CompressionFormat.sparse_24_bitmask.value) @classmethod def get_min_capability(cls) -> int: # Only cutlass 3.x kernels are implemented so far return 90 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): if not sparse_cutlass_supported(): raise ValueError( "Sparse CUTLASS not supported. vLLM must be built with " @@ -49,16 +66,56 @@ class CompressedTensors24(CompressedTensorsScheme): self.output_dtype = params_dtype layer.logical_widths = output_partition_sizes + layer.input_size = input_size + layer.input_size_per_partition = input_size_per_partition self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) # parameter to store uncompressed weight - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=self.weights_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=self.weights_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + if self.do_sparse_decompress: + assert all(partition_size % 8 == 0 + for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for " + "2:4 sparse compressed models" + + shape = BasevLLMParameter( + data=torch.empty(2, 1, dtype=torch.int64), + weight_loader=weight_loader, + ) + compressed_weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + bitmask = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 8, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("shape", shape) + layer.register_parameter("compressed", compressed_weight) + layer.register_parameter("bitmask", bitmask) # Check if quantized, not just 2:4 Sparse if self.quantized: @@ -68,14 +125,16 @@ class CompressedTensors24(CompressedTensorsScheme): data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert (self.weight_quant and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value) weight_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) @@ -84,9 +143,10 @@ class CompressedTensors24(CompressedTensorsScheme): # register input quant scale assert (self.input_quant.strategy == QuantizationStrategy.TENSOR.value) - input_scale = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.float32), - weight_loader=weight_loader) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_scale", input_scale) @@ -107,13 +167,25 @@ class CompressedTensors24(CompressedTensorsScheme): """ Compress weights after loading. Store compressed weight and meta tensor - + :post-condition: layer.w_compressed and layer.meta are set to the compressed weight and meta tensor in the format expected by the Cutlass kernels :param layer: The layer with the weights to be processed - + """ + if self.do_sparse_decompress: + layer.weight.data = self._decompress_bitmask_compressed_weight( + compressed=layer.compressed, + bitmask=layer.bitmask, + layer=layer, + ) + + # compressed and bitmask tensors + # are no longer needed after decompression + del layer.compressed + del layer.bitmask + # torch.compile workaround if hasattr(layer, "input_scale"): layer.input_scale = torch.nn.Parameter(layer.input_scale.data, @@ -121,10 +193,13 @@ class CompressedTensors24(CompressedTensorsScheme): if self.weight_quant: if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value: - layer.weight_scale = torch.nn.Parameter(convert_to_channelwise( - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths), - requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + convert_to_channelwise( + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ), + requires_grad=False, + ) else: # torch.compile workaround layer.weight_scale = torch.nn.Parameter( @@ -134,20 +209,22 @@ class CompressedTensors24(CompressedTensorsScheme): layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False) layer.meta = torch.nn.Parameter(meta, requires_grad=False) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ - Returns the output tensor for the layer with 2:4 + Returns the output tensor for the layer with 2:4 sparse compressed weights, given the input tensor and bias - :param layer: The layer with 2:4 sparse compressed + :param layer: The layer with 2:4 sparse compressed weights to be used for the computation :param x: The input tensor to the layer :param bias: The bias to be added to the output tensor - :return: The output tensor of the layer + :return: The output tensor of the layer """ if self.quantized: scale = None @@ -171,13 +248,15 @@ class CompressedTensors24(CompressedTensorsScheme): input_scale = layer.input_scale q_input = x - out = ops.cutlass_scaled_sparse_mm(a=q_input, - bt_nzs=layer.weight, - bt_meta=layer.meta, - scale_a=input_scale, - scale_b=layer.weight_scale, - out_dtype=self.output_dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a=q_input, + bt_nzs=layer.weight, + bt_meta=layer.meta, + scale_a=input_scale, + scale_b=layer.weight_scale, + out_dtype=self.output_dtype, + bias=bias, + ) assert out.is_contiguous() return out @@ -203,8 +282,71 @@ class CompressedTensors24(CompressedTensorsScheme): raise ValueError("Quantization type not supported by Cutlass") + def _decompress_bitmask_compressed_weight( + self, + compressed: torch.Tensor, + bitmask: torch.Tensor, + layer: torch.nn.Module, + ) -> torch.Tensor: + """ + Decompress a compressed 2:4 sparse weight tensor using the bitmask and + return the result. -def check_24(tensor): - new_tensor = tensor.view(-1, 4) - zero_counts = (new_tensor == 0).sum(dim=1) - return (zero_counts >= 2).all().item() + This function also supports sharded decompression. + + :param compressed: The 2:4 sparse weight tensor compressed using the + sparse-24-bitmask compressor. This is different from + `cutlass_sparse_compress` which uses a different scheme (2 bits for + every nonzero element that represent the coordinate within the block + of 4). The bitmask compression here uses a bitmask to indicate the + positions of non-zero elements. + :param bitmask: The 2:4 bitmask associated with the compressed weights, + representing the positions of non-zero elements in the compressed + tensor. + :param layer: The layer whose weights need to be processed after + loading. + :return: The decompressed 2:4 sparse weight tensor. + """ + + sparsity_compressor = self.model_compressor.sparsity_compressor + + def _process_split( + bitmask_compressed_weight: torch.Tensor, + shape, + bitmask: torch.Tensor, + ) -> torch.Tensor: + weight_data = dict( + compressed=bitmask_compressed_weight, + shape=shape, + bitmask=bitmask, + ) + return sparsity_compressor.decompress_weight(weight_data) + + split_weights: List[torch.Tensor] = [] + split_bitmask: List[torch.Tensor] = [] + split_shape: List[Tuple[int, int]] = [] + + if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): + split_weights = torch.split(compressed, layer.logical_widths) + split_bitmask = torch.split(bitmask, layer.logical_widths) + split_shape = [(out, layer.input_size_per_partition) + for out in layer.logical_widths] + + if split_weights: + decompressed_shards = [ + _process_split(compressed_weight, shape, bitmask) + for compressed_weight, shape, bitmask in zip( + split_weights, split_shape, split_bitmask) + ] + decompressed = combine_shards(decompressed_shards) + else: + decompressed = sparsity_compressor.decompress_weight( + dict( + compressed=compressed, + shape=( + layer.logical_widths[0], + layer.input_size_per_partition, + ), + bitmask=bitmask, + )) + return decompressed