mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[ Misc ] non-uniform quantization via compressed-tensors
for Llama
(#6515)
This commit is contained in:
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.758
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.759
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||
|
@ -158,6 +158,7 @@ class FusedMoE(torch.nn.Module):
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -179,15 +181,19 @@ class ReplicatedLinear(LinearBase):
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: Optional[str] = None):
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config)
|
||||
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self, self.input_size,
|
||||
[self.output_size], self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size, [self.output_size],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
prefix=prefix)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
quant_config: Quantization configure.
|
||||
output_sizes: list of output sizes packed into one output, like for QKV
|
||||
the list would be size 3.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -249,7 +257,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[List[int]] = None):
|
||||
output_sizes: Optional[List[int]] = None,
|
||||
prefix: Optional[str] = None):
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config)
|
||||
|
||||
@ -276,7 +285,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
weight_loader=self.weight_loader,
|
||||
prefix=prefix)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -357,7 +369,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: Optional[str] = None):
|
||||
self.output_sizes = output_sizes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||
@ -367,7 +380,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -497,7 +513,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: Optional[str] = None):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
@ -529,7 +546,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
@ -688,7 +706,8 @@ class RowParallelLinear(LinearBase):
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: Optional[str] = None):
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config)
|
||||
|
||||
@ -706,7 +725,8 @@ class RowParallelLinear(LinearBase):
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
weight_loader=self.weight_loader,
|
||||
prefix=prefix)
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
@ -8,23 +8,25 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsWNA16)
|
||||
CompressedTensorsScheme, CompressedTensorsUnquantized,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
CompressionFormat, QuantizationArgs, QuantizationStrategy,
|
||||
QuantizationType, find_first_name_or_class_match,
|
||||
is_activation_quantization_format)
|
||||
QuantizationType, find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
|
||||
def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
|
||||
quant_format: str):
|
||||
|
||||
self.ignore = ignore
|
||||
self.layer_quant_details = layer_quant_details
|
||||
self.quant_format = quant_format
|
||||
# Map from [target -> scheme]
|
||||
self.target_scheme_map = target_scheme_map
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
@ -51,7 +53,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
layer_quant_details: Dict[str, Any] = dict()
|
||||
target_scheme_map: Dict[str, Any] = dict()
|
||||
ignore: List[str] = config.get("ignore", None)
|
||||
quant_format: str = config.get("format", None)
|
||||
|
||||
@ -63,21 +65,21 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# details follow the structure defined by the QuantizationArgs
|
||||
# pydantic model, which is used to verify the structure of the
|
||||
# quant_config and also store the details for later use.
|
||||
for key, quant_config in config["config_groups"].items():
|
||||
for _, quant_config in config["config_groups"].items():
|
||||
targets = quant_config.get("targets")
|
||||
for target in targets:
|
||||
layer_quant_details[target] = {}
|
||||
layer_quant_details[target][
|
||||
target_scheme_map[target] = {}
|
||||
target_scheme_map[target][
|
||||
"weights"] = QuantizationArgs.parse_obj(
|
||||
quant_config.get("weights"))
|
||||
try:
|
||||
layer_quant_details[target][
|
||||
target_scheme_map[target][
|
||||
"input_activations"] = QuantizationArgs.parse_obj(
|
||||
quant_config.get("input_activations"))
|
||||
except Exception:
|
||||
layer_quant_details[target]["input_activations"] = None
|
||||
target_scheme_map[target]["input_activations"] = None
|
||||
|
||||
return cls(layer_quant_details=layer_quant_details,
|
||||
return cls(target_scheme_map=target_scheme_map,
|
||||
ignore=ignore,
|
||||
quant_format=quant_format)
|
||||
|
||||
@ -167,7 +169,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return (is_channel_group and input_quant_none and is_symmetric
|
||||
and is_static)
|
||||
|
||||
def _get_schema(self, weight_quant: BaseModel,
|
||||
def _get_scheme_from_parts(
|
||||
self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> "CompressedTensorsScheme":
|
||||
|
||||
# Detect If Mixed Precision
|
||||
@ -205,26 +208,47 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
raise NotImplementedError(
|
||||
"No compressed-tensors compatible scheme was found.")
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
|
||||
def get_scheme(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
layer_type_name = find_first_name_or_class_match(
|
||||
name="",
|
||||
ignore: List of layer_names or nn.Module names to be ignored.
|
||||
targets of config_groups: There can be N config_groups which each
|
||||
have a quantization scheme. Each config_group has a list of targets
|
||||
which can be a full layer_name, a regex for a layer_name, or
|
||||
an nn.Module name.
|
||||
|
||||
We first check whether a layer is in the ignore group and use
|
||||
CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
|
||||
|
||||
We then detect whether a layer_name is found in any target and
|
||||
use the quantization scheme corresponding to the matched target
|
||||
to select the CompressedTensorsScheme used for infernece.
|
||||
"""
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
# TODO (@robertgshaw2): support module names
|
||||
if should_ignore_layer(layer_name, ignore=self.ignore):
|
||||
return CompressedTensorsUnquantized()
|
||||
|
||||
# Find the "target" in the compressed-tensors config
|
||||
# that our layer conforms to.
|
||||
# TODO (@robertgshaw): add compressed-tensors as dep
|
||||
# so we do not have to re-write these functions
|
||||
matched_target = find_matched_target(
|
||||
layer_name=layer_name,
|
||||
module=layer,
|
||||
targets=self.layer_quant_details.keys(),
|
||||
check_contains=True)
|
||||
targets=self.target_scheme_map.keys())
|
||||
|
||||
if layer_type_name is None:
|
||||
raise ValueError(f"Could not matching target for layer {layer}")
|
||||
# Find the quant_scheme
|
||||
scheme = self.target_scheme_map[matched_target]
|
||||
|
||||
layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
|
||||
layer_type_name, None)
|
||||
if layer_quant_details is None:
|
||||
raise ValueError(
|
||||
f"Could not find quantization details for {layer}.")
|
||||
|
||||
scheme = self._get_schema(
|
||||
weight_quant=layer_quant_details["weights"],
|
||||
input_quant=layer_quant_details["input_activations"])
|
||||
return self._get_scheme_from_parts(
|
||||
weight_quant=scheme["weights"],
|
||||
input_quant=scheme["input_activations"])
|
||||
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
@ -250,11 +274,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
|
||||
Use the CompressedTensorsScheme associated with each layer to create
|
||||
the necessary parameters for the layer. See LinearMethodBase for param
|
||||
details
|
||||
|
||||
"""
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer_name = extra_weight_attrs.get("prefix")
|
||||
|
||||
scheme = self.quantization_config.get_scheme(layer=layer)
|
||||
scheme = self.quantization_config.get_scheme(layer, layer_name)
|
||||
scheme.create_weights(
|
||||
layer=layer,
|
||||
input_size=input_size,
|
||||
|
@ -33,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
|
||||
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
|
||||
|
@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool:
|
||||
return format in _ACTIVATION_QUANTIZATION_FORMATS
|
||||
|
||||
|
||||
def find_first_name_or_class_match(
|
||||
name: str,
|
||||
module: Module,
|
||||
targets: Iterable[str],
|
||||
check_contains: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Helper function to map the quantization details listed in the config
|
||||
for a given list of targets against each model layer. First uses the
|
||||
layer name to try and find a match. If no name match is found, uses
|
||||
the layer class name. Returns None otherwise.
|
||||
# fused_name: List[shard_name]
|
||||
_FUSED_LAYER_NAME_MAPPING = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
:param name: layer name
|
||||
|
||||
def should_ignore_layer(layer_name: Optional[str],
|
||||
ignore: Iterable[str]) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in _FUSED_LAYER_NAME_MAPPING:
|
||||
shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
|
||||
targets=ignore)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str,
|
||||
targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
if target starts with 're:' to any target in list.
|
||||
"""
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(layer_name, target):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def find_matched_target(layer_name: Optional[str], module: Module,
|
||||
targets: Iterable[str]) -> str:
|
||||
"""
|
||||
Helper function to look up which "target" in the compressed-tensors
|
||||
config that a layer corresponds to.
|
||||
|
||||
Recall that a compressed-tensors configs has a concept of
|
||||
config_groups, where each layer can be quantized with with a different
|
||||
scheme.
|
||||
|
||||
targets in each config_group will be a list of either layer names
|
||||
(or regexes corresponding to layer names) or names of torch Modules.
|
||||
|
||||
First, we try to match the layer_name with a target
|
||||
Second, we try to match the module's name with a target
|
||||
|
||||
:param layer_name: layer name
|
||||
:param module: torch.nn.Module
|
||||
:param targets: list of targets to match the layer against
|
||||
:param check_contains: whether or not to do a substring match
|
||||
"""
|
||||
|
||||
return _find_first_match(name, targets) or _find_first_match(
|
||||
module.__class__.__name__, targets, check_contains)
|
||||
if layer_name is None:
|
||||
layer_name = ""
|
||||
|
||||
matched_target = (_find_first_match(layer_name, targets)
|
||||
or _find_first_match(module.__class__.__name__, targets,
|
||||
True))
|
||||
|
||||
if matched_target is None:
|
||||
raise ValueError(f"Unable to find matching target for {module} in the "
|
||||
"compressed-tensors config.")
|
||||
|
||||
return matched_target
|
||||
|
||||
|
||||
def _find_first_match(value: str,
|
||||
@ -121,13 +202,29 @@ def _find_first_match(value: str,
|
||||
"""
|
||||
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(value,
|
||||
target,
|
||||
check_contains=check_contains):
|
||||
return target
|
||||
return None
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(value: str,
|
||||
target: str,
|
||||
check_contains: bool = False) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return target
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return target
|
||||
return True
|
||||
elif target == value:
|
||||
return target
|
||||
return None
|
||||
return True
|
||||
return False
|
||||
|
@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
|
||||
config: GPT2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
|
||||
total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_attn",
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
config: GPT2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function, quant_config,
|
||||
intermediate_size)
|
||||
@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
|
||||
config: GPT2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(config, cache_config, quant_config)
|
||||
self.attn = GPT2Attention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(inner_dim, config, quant_config)
|
||||
self.mlp = GPT2MLP(inner_dim,
|
||||
config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
|
||||
config: GPT2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.start_layer, self.end_layer, self.h = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda: GPT2Block(config, cache_config, quant_config))
|
||||
lambda prefix: GPT2Block(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.h")
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(config, cache_config, quant_config)
|
||||
self.transformer = GPT2Model(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix="transformer")
|
||||
self.lm_head = self.transformer.wte
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
@ -62,17 +62,20 @@ class LlamaMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -99,6 +102,7 @@ class LlamaAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -132,12 +136,14 @@ class LlamaAttention(nn.Module):
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -203,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -210,6 +218,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -253,6 +262,7 @@ class LlamaModel(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -272,9 +282,11 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda: LlamaDecoderLayer(config=config,
|
||||
lambda prefix: LlamaDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config))
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.model = LlamaModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
|
@ -67,7 +67,8 @@ class MixtralMoE(nn.Module):
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None):
|
||||
tp_size: Optional[int] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
@ -76,7 +77,8 @@ class MixtralMoE(nn.Module):
|
||||
num_experts,
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None)
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate")
|
||||
|
||||
self.experts = FusedMoE(num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
@ -86,7 +88,8 @@ class MixtralMoE(nn.Module):
|
||||
reduce_results=True,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size)
|
||||
tp_size=tp_size,
|
||||
prefix=f"{prefix}.experts")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
@ -109,6 +112,7 @@ class MixtralAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -139,12 +143,14 @@ class MixtralAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
config: MixtralConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -194,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.block_sparse_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.block_sparse_moe")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -243,6 +252,7 @@ class MixtralModel(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -258,8 +268,11 @@ class MixtralModel(nn.Module):
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers, lambda: MixtralDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config))
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: MixtralDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.model = MixtralModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
from typing import Dict, List, Protocol, Tuple
|
||||
|
||||
import torch
|
||||
from torch.func import functional_call
|
||||
@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
class LayerFn(Protocol):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prefix="",
|
||||
) -> torch.nn.Module:
|
||||
...
|
||||
|
||||
|
||||
class PPMissingLayer(torch.nn.Identity):
|
||||
"""
|
||||
A placeholder layer for missing layers in a pipeline parallel model.
|
||||
@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
|
||||
|
||||
def make_layers(
|
||||
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
|
||||
num_hidden_layers: int,
|
||||
layer_fn: LayerFn,
|
||||
prefix: str,
|
||||
) -> Tuple[int, int, torch.nn.ModuleList]:
|
||||
"""Make a list of layers with the given layer function, taking
|
||||
pipeline parallelism into account.
|
||||
@ -131,8 +142,8 @@ def make_layers(
|
||||
get_pp_group().world_size)
|
||||
modules = torch.nn.ModuleList(
|
||||
[PPMissingLayer() for _ in range(start_layer)] + [
|
||||
maybe_offload_to_cpu(layer_fn())
|
||||
for _ in range(start_layer, end_layer)
|
||||
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
|
||||
for idx in range(start_layer, end_layer)
|
||||
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
|
||||
return start_layer, end_layer, modules
|
||||
|
||||
|
Reference in New Issue
Block a user