[ Misc ] non-uniform quantization via compressed-tensors for Llama (#6515)

This commit is contained in:
Robert Shaw
2024-07-18 22:39:18 -04:00
committed by GitHub
parent d4201e06d5
commit dbe5588554
11 changed files with 301 additions and 91 deletions

View File

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.758
- name: "exact_match,flexible-extract"
value: 0.759
limit: 1000
num_fewshot: 5
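For reference, the non-uniform test checkpoint named above can also be loaded directly with vLLM's offline API; the compressed-tensors quantization settings are read from the checkpoint's own quantization_config, so no extra flags are needed. A minimal sketch (prompt and sampling settings are illustrative):

from vllm import LLM, SamplingParams

# Checkpoint referenced in the eval config above; quantization is picked up
# automatically from its compressed-tensors metadata.
llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test")
outputs = llm.generate(["What is 7 * 6?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)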

View File

@@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml

View File

@@ -158,6 +158,7 @@ class FusedMoE(torch.nn.Module):
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
+       prefix: str = "",
    ):
        super().__init__()

View File

@@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
+       prefix: The name of the layer in the state dict, including all parents
+           (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
@@ -179,15 +181,19 @@ class ReplicatedLinear(LinearBase):
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
-                quant_config: Optional[QuantizationConfig] = None):
+                quant_config: Optional[QuantizationConfig] = None,
+                prefix: Optional[str] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
-       self.quant_method.create_weights(self, self.input_size,
-                                        [self.output_size], self.input_size,
-                                        self.output_size, self.params_dtype)
+       self.quant_method.create_weights(self,
+                                        self.input_size,
+                                        [self.output_size],
+                                        self.input_size,
+                                        self.output_size,
+                                        self.params_dtype,
+                                        prefix=prefix)

        if bias:
            self.bias = Parameter(
@@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
            the list would be size 3.
+       prefix: The name of the layer in the state dict, including all parents
+           (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
@@ -249,7 +257,8 @@ class ColumnParallelLinear(LinearBase):
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
-                output_sizes: Optional[List[int]] = None):
+                output_sizes: Optional[List[int]] = None,
+                prefix: Optional[str] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)
@@ -276,7 +285,8 @@ class ColumnParallelLinear(LinearBase):
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
-           weight_loader=self.weight_loader)
+           weight_loader=self.weight_loader,
+           prefix=prefix)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
@@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
+       prefix: The name of the layer in the state dict, including all parents
+           (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
@@ -357,7 +369,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
-                quant_config: Optional[QuantizationConfig] = None):
+                quant_config: Optional[QuantizationConfig] = None,
+                prefix: Optional[str] = None):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -367,7 +380,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
-                        quant_config=quant_config)
+                        quant_config=quant_config,
+                        prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
@@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
            skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
+       prefix: The name of the layer in the state dict, including all parents
+           (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
@@ -497,7 +513,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
-                quant_config: Optional[QuantizationConfig] = None):
+                quant_config: Optional[QuantizationConfig] = None,
+                prefix: Optional[str] = None):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
@@ -529,7 +546,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
-                        quant_config=quant_config)
+                        quant_config=quant_config,
+                        prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
@@ -688,7 +706,8 @@ class RowParallelLinear(LinearBase):
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
-                quant_config: Optional[QuantizationConfig] = None):
+                quant_config: Optional[QuantizationConfig] = None,
+                prefix: Optional[str] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)
@@ -706,7 +725,8 @@ class RowParallelLinear(LinearBase):
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
-           weight_loader=self.weight_loader)
+           weight_loader=self.weight_loader,
+           prefix=prefix)
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

View File

@@ -8,23 +8,25 @@ from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
-   CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-   CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-   CompressedTensorsWNA16)
+   CompressedTensorsScheme, CompressedTensorsUnquantized,
+   CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+   CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
-   QuantizationType, find_first_name_or_class_match,
-   is_activation_quantization_format)
+   QuantizationType, find_matched_target, is_activation_quantization_format,
+   should_ignore_layer)
from vllm.platforms import current_platform


class CompressedTensorsConfig(QuantizationConfig):

-   def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+   def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
                 quant_format: str):
        self.ignore = ignore
-       self.layer_quant_details = layer_quant_details
        self.quant_format = quant_format
+       # Map from [target -> scheme]
+       self.target_scheme_map = target_scheme_map

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)
@@ -51,7 +53,7 @@
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
-       layer_quant_details: Dict[str, Any] = dict()
+       target_scheme_map: Dict[str, Any] = dict()
        ignore: List[str] = config.get("ignore", None)
        quant_format: str = config.get("format", None)
@@ -63,21 +65,21 @@
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.
-       for key, quant_config in config["config_groups"].items():
+       for _, quant_config in config["config_groups"].items():
            targets = quant_config.get("targets")
            for target in targets:
-               layer_quant_details[target] = {}
-               layer_quant_details[target][
+               target_scheme_map[target] = {}
+               target_scheme_map[target][
                    "weights"] = QuantizationArgs.parse_obj(
                        quant_config.get("weights"))
                try:
-                   layer_quant_details[target][
+                   target_scheme_map[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                            quant_config.get("input_activations"))
                except Exception:
-                   layer_quant_details[target]["input_activations"] = None
+                   target_scheme_map[target]["input_activations"] = None

-       return cls(layer_quant_details=layer_quant_details,
+       return cls(target_scheme_map=target_scheme_map,
                   ignore=ignore,
                   quant_format=quant_format)
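For orientation, a checkpoint's "quantization_config" block with the shape that from_config() parses might look like the dict below. The values are illustrative, not copied from the test checkpoint: each config_group lists targets plus weight/activation QuantizationArgs, and the ignore list names layers to leave unquantized.

# Illustrative compressed-tensors config, written as a Python dict.
example_quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "int-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            # "Linear" matches by nn.Module class name; a full layer name or
            # an "re:..." regex would match by name instead.
            "targets": ["Linear"],
            "weights": {
                "num_bits": 8,
                "type": "int",
                "symmetric": True,
                "strategy": "tensor",
            },
            "input_activations": {
                "num_bits": 8,
                "type": "int",
                "symmetric": True,
                "strategy": "tensor",
            },
        },
    },
}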
@@ -167,8 +169,9 @@
        return (is_channel_group and input_quant_none and is_symmetric
                and is_static)

-   def _get_schema(self, weight_quant: BaseModel,
-                   input_quant: BaseModel) -> "CompressedTensorsScheme":
+   def _get_scheme_from_parts(
+           self, weight_quant: BaseModel,
+           input_quant: BaseModel) -> "CompressedTensorsScheme":

        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -205,26 +208,47 @@
            raise NotImplementedError(
                "No compressed-tensors compatible scheme was found.")

-   def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
+   def get_scheme(
+           self,
+           layer: torch.nn.Module,
+           layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
+       """
+       compressed-tensors supports non-uniform quantization in the
+       following way:
+
+       ignore: List of layer_names or nn.Module names to be ignored.
+       targets of config_groups: There can be N config_groups, each with its
+           own quantization scheme. Each config_group has a list of targets
+           which can be a full layer_name, a regex for a layer_name, or
+           an nn.Module name.
+
+       We first check whether a layer is in the ignore group and use the
+       CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for that layer.
+       We then detect whether a layer_name is found in any target and
+       use the quantization scheme corresponding to the matched target
+       to select the CompressedTensorsScheme used for inference.
+       """
+       # Check if the layer is skipped for quantization.
+       # TODO (@robertgshaw2): support module names
+       if should_ignore_layer(layer_name, ignore=self.ignore):
+           return CompressedTensorsUnquantized()

-       layer_type_name = find_first_name_or_class_match(
-           name="",
+       # Find the "target" in the compressed-tensors config
+       # that our layer conforms to.
+       # TODO (@robertgshaw): add compressed-tensors as dep
+       # so we do not have to re-write these functions
+       matched_target = find_matched_target(
+           layer_name=layer_name,
            module=layer,
-           targets=self.layer_quant_details.keys(),
-           check_contains=True)
+           targets=self.target_scheme_map.keys())

-       if layer_type_name is None:
-           raise ValueError(f"Could not matching target for layer {layer}")
-
-       layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
-           layer_type_name, None)
-       if layer_quant_details is None:
-           raise ValueError(
-               f"Could not find quantization details for {layer}.")
-
-       scheme = self._get_schema(
-           weight_quant=layer_quant_details["weights"],
-           input_quant=layer_quant_details["input_activations"])
+       # Find the quant_scheme
+       scheme = self.target_scheme_map[matched_target]
+
+       return self._get_scheme_from_parts(
+           weight_quant=scheme["weights"],
+           input_quant=scheme["input_activations"])

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
@@ -250,11 +274,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
+       layer_name = extra_weight_attrs.get("prefix")

-       scheme = self.quantization_config.get_scheme(layer=layer)
+       scheme = self.quantization_config.get_scheme(layer, layer_name)
        scheme.create_weights(
            layer=layer,
            input_size=input_size,
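Taken together, get_scheme() makes its per-layer decision in two steps: the ignore-list check first, then target matching. A self-contained sketch of that order (it mirrors the logic above rather than importing it; the ignore list and target map are made up):

# Mirror of the decision order in get_scheme(), with toy data.
ignore = ["lm_head"]
target_scheme_map = {"Linear": {"weights": "int8", "input_activations": "int8"}}

def resolve(layer_name: str, module_class_name: str) -> str:
    # Step 1: ignored layers stay unquantized (CompressedTensorsUnquantized).
    if layer_name in ignore:
        return "unquantized (fp16/bf16)"
    # Step 2: otherwise match a target; module class-name matching is a
    # substring check, so a bare "Linear" target covers QKVParallelLinear etc.
    for target in target_scheme_map:
        if target.lower() in module_class_name.lower():
            return f"scheme from target '{target}'"
    raise ValueError("no matching target in the compressed-tensors config")

print(resolve("lm_head", "ParallelLMHead"))                               # unquantized
print(resolve("model.layers.0.self_attn.qkv_proj", "QKVParallelLinear"))  # int8 scheme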

View File

@@ -33,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
-                                      device="cuda",
                                       dtype=params_dtype),
                           requires_grad=False)

View File

@@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool:
    return format in _ACTIVATION_QUANTIZATION_FORMATS


-def find_first_name_or_class_match(
-       name: str,
-       module: Module,
-       targets: Iterable[str],
-       check_contains: bool = False) -> Optional[str]:
-   """
-   Helper function to map the quantization details listed in the config
-   for a given list of targets against each model layer. First uses the
-   layer name to try and find a match. If no name match is found, uses
-   the layer class name. Returns None otherwise.
+# fused_name: List[shard_name]
+_FUSED_LAYER_NAME_MAPPING = {
+   "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+   "gate_up_proj": ["gate_proj", "up_proj"]
+}

-   :param name: layer name
+
+def should_ignore_layer(layer_name: Optional[str],
+                        ignore: Iterable[str]) -> bool:
+   if layer_name is None:
+       return False
+
+   # layer_name = model.layers.0.self_attn.qkv_proj
+   # proj_name = qkv_proj
+   proj_name = layer_name.split(".")[-1]
+
+   # Fused layers like gate_up_proj or qkv_proj will not be fused
+   # in the safetensors checkpoint. So, we convert the name
+   # from the fused version to unfused + check to make sure that
+   # each shard of the fused layer has the same scheme.
+   if proj_name in _FUSED_LAYER_NAME_MAPPING:
+       shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
+
+       # Convert fused_name --> [shard_names]
+       shard_names = [
+           layer_name.replace(proj_name, shard_proj_name)
+           for shard_proj_name in shard_proj_names
+       ]
+
+       # Layer should be ignored if shards are ignored.
+       should_ignore_layer = None
+       for shard_name in shard_names:
+           should_ignore_shard = check_equal_or_regex_match(
+               layer_name=shard_name, targets=ignore)
+
+           # If shard_idx=0, set layer ignore to match shard.
+           if should_ignore_layer is None:
+               should_ignore_layer = should_ignore_shard
+
+           # If shard_idx=1+, confirm the scheme matches prior shards.
+           elif should_ignore_shard != should_ignore_layer:
+               raise ValueError(f"Found different quantization schemes for "
+                                f"{shard_proj_names} in {layer_name}. vLLM "
+                                "requires all to use the same scheme.")
+
+   # Unfused layers like down_proj and o_proj will match
+   # the safetensors checkpoint already.
+   else:
+       should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
+                                                        targets=ignore)
+
+   assert should_ignore_layer is not None
+   return should_ignore_layer
+
+
+def check_equal_or_regex_match(layer_name: str,
+                               targets: Iterable[str]) -> bool:
+   """
+   Checks whether a layer_name exactly matches, or is a regex match for
+   (when the target starts with 're:'), any target in the list.
+   """
+   for target in targets:
+       if _is_equal_or_regex_match(layer_name, target):
+           return True
+   return False
+
+
+def find_matched_target(layer_name: Optional[str], module: Module,
+                        targets: Iterable[str]) -> str:
+   """
+   Helper function to look up which "target" in the compressed-tensors
+   config a layer corresponds to.
+
+   Recall that a compressed-tensors config has a concept of
+   config_groups, where each layer can be quantized with a different
+   scheme.
+
+   targets in each config_group will be a list of either layer names
+   (or regexes corresponding to layer names) or names of torch Modules.
+
+   First, we try to match the layer_name with a target.
+   Second, we try to match the module's class name with a target.
+
+   :param layer_name: layer name
    :param module: torch.nn.Module
    :param targets: list of targets to match the layer against
-   :param check_contains: whether or not to do a substring match
    """
-   return _find_first_match(name, targets) or _find_first_match(
-       module.__class__.__name__, targets, check_contains)
+   if layer_name is None:
+       layer_name = ""
+
+   matched_target = (_find_first_match(layer_name, targets)
+                     or _find_first_match(module.__class__.__name__, targets,
+                                          True))
+
+   if matched_target is None:
+       raise ValueError(f"Unable to find matching target for {module} in the "
+                        "compressed-tensors config.")
+
+   return matched_target


def _find_first_match(value: str,
@@ -121,13 +202,29 @@ def _find_first_match(value: str,
    """
    for target in targets:
-       if target.startswith("re:"):
-           pattern = target[3:]
-           if re.match(pattern, value):
-               return target
-       elif check_contains:
-           if target.lower() in value.lower():
-               return target
-       elif target == value:
+       if _is_equal_or_regex_match(value,
+                                   target,
+                                   check_contains=check_contains):
            return target
    return None


+def _is_equal_or_regex_match(value: str,
+                             target: str,
+                             check_contains: bool = False) -> bool:
+   """
+   Checks whether a value is exactly equal or a regex match for target
+   if target starts with 're:'. If check_contains is set to True,
+   additionally checks if the target string is contained within the value.
+   """
+   if target.startswith("re:"):
+       pattern = target[3:]
+       if re.match(pattern, value):
+           return True
+   elif check_contains:
+       if target.lower() in value.lower():
+           return True
+   elif target == value:
+       return True
+   return False
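A short usage sketch for these helpers (the ignore patterns and the module below are made up): plain strings are exact matches, "re:" entries are regexes, fused names are checked shard by shard, and class-name matching is the fallback in find_matched_target().

import torch
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    find_matched_target, should_ignore_layer)

ignore = ["lm_head", r"re:.*\.mlp\.down_proj$"]

# Unfused layer: a single exact/regex check against the ignore list.
assert not should_ignore_layer("model.layers.0.self_attn.o_proj", ignore=ignore)

# The regex target matches, so this layer is skipped for quantization.
assert should_ignore_layer("model.layers.0.mlp.down_proj", ignore=ignore)

# Fused layer: q_proj/k_proj/v_proj are checked individually and must agree.
assert not should_ignore_layer("model.layers.0.self_attn.qkv_proj", ignore=ignore)

# Target matching falls back to the module class name, so a bare "Linear"
# target covers linear subclasses that do not match by layer name.
assert find_matched_target(layer_name="model.layers.1.mlp.down_proj",
                           module=torch.nn.Linear(8, 8),
                           targets=["Linear"]) == "Linear"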

View File

@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+       prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
@@ -68,12 +69,14 @@
            total_num_heads,
            bias=True,
            quant_config=quant_config,
+           prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
+           prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
+       prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
@@ -109,12 +113,14 @@
            intermediate_size,
            bias=True,
            quant_config=quant_config,
+           prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
+           prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)
@@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+       prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
@@ -140,9 +147,15 @@
                                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-       self.attn = GPT2Attention(config, cache_config, quant_config)
+       self.attn = GPT2Attention(config,
+                                 cache_config,
+                                 quant_config,
+                                 prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-       self.mlp = GPT2MLP(inner_dim, config, quant_config)
+       self.mlp = GPT2MLP(inner_dim,
+                          config,
+                          quant_config,
+                          prefix=f"{prefix}.mlp")

    def forward(
        self,
@@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+       prefix: str = "",
    ):
        super().__init__()
        self.config = config
@@ -186,7 +200,9 @@
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
-           lambda: GPT2Block(config, cache_config, quant_config))
+           lambda prefix: GPT2Block(
+               config, cache_config, quant_config, prefix=prefix),
+           prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
-       self.transformer = GPT2Model(config, cache_config, quant_config)
+       self.transformer = GPT2Model(config,
+                                    cache_config,
+                                    quant_config,
+                                    prefix="transformer")
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

View File

@@ -62,17 +62,20 @@ class LlamaMLP(nn.Module):
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
+       prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
-           quant_config=quant_config)
+           quant_config=quant_config,
+           prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
@@ -99,6 +102,7 @@ class LlamaAttention(nn.Module):
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
+       prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
@@ -132,12 +136,14 @@
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
+           prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
+           prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
@@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module):
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+       prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
@@ -203,6 +210,7 @@
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
+           prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
@@ -210,6 +218,7 @@
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
+           prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
@@ -253,6 +262,7 @@ class LlamaModel(nn.Module):
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
+       prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
@@ -272,9 +282,11 @@
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
-           lambda: LlamaDecoderLayer(config=config,
-                                     cache_config=cache_config,
-                                     quant_config=quant_config))
+           lambda prefix: LlamaDecoderLayer(config=config,
+                                            cache_config=cache_config,
+                                            quant_config=quant_config,
+                                            prefix=prefix),
+           prefix=f"{prefix}.layers")
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
@@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
-                               lora_config=lora_config)
+                               lora_config=lora_config,
+                               prefix="model")
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
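With the prefixes threaded through LlamaForCausalLM -> LlamaModel -> LlamaDecoderLayer, each quantizable Llama linear layer now reports the same fully-qualified name that appears in the checkpoint, which is what the compressed-tensors ignore/targets entries are written against. For decoder layer 0 that works out to (illustrative listing of the names produced above):

# Names produced by the prefix wiring for Llama decoder layer 0.
llama_layer0_linear_names = [
    "model.layers.0.self_attn.qkv_proj",   # fused: q_proj/k_proj/v_proj shards
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.gate_up_proj",     # fused: gate_proj/up_proj shards
    "model.layers.0.mlp.down_proj",
]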

View File

@@ -67,7 +67,8 @@ class MixtralMoE(nn.Module):
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
-                tp_size: Optional[int] = None):
+                tp_size: Optional[int] = None,
+                prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size
@@ -76,7 +77,8 @@
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
-                                    quant_config=None)
+                                    quant_config=None,
+                                    prefix=f"{prefix}.gate")

        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
@@ -86,7 +88,8 @@
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
-                               tp_size=tp_size)
+                               tp_size=tp_size,
+                               prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
@@ -109,6 +112,7 @@ class MixtralAttention(nn.Module):
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+       prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
@@ -139,12 +143,14 @@
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
+           prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
+           prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
@@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+       prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
@@ -194,13 +201,15 @@
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
-           quant_config=quant_config)
+           quant_config=quant_config,
+           prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
-           quant_config=quant_config)
+           quant_config=quant_config,
+           prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -243,6 +252,7 @@ class MixtralModel(nn.Module):
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
+       prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
@@ -258,8 +268,11 @@
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
-           config.num_hidden_layers, lambda: MixtralDecoderLayer(
-               config, cache_config, quant_config=quant_config))
+           config.num_hidden_layers,
+           lambda prefix: MixtralDecoderLayer(
+               config, cache_config, quant_config=quant_config, prefix=prefix
+           ),
+           prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
        self.model = MixtralModel(config,
                                  cache_config,
                                  quant_config,
-                                 lora_config=lora_config)
+                                 lora_config=lora_config,
+                                 prefix="model")
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

View File

@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Tuple
+from typing import Dict, List, Protocol, Tuple

import torch
from torch.func import functional_call
@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
    return inputs_embeds


+class LayerFn(Protocol):
+
+   def __call__(
+       self,
+       prefix="",
+   ) -> torch.nn.Module:
+       ...
+
+
class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
def make_layers(
-   num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+   num_hidden_layers: int,
+   layer_fn: LayerFn,
+   prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.
@@ -131,8 +142,8 @@ def make_layers(
                                get_pp_group().world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] + [
-           maybe_offload_to_cpu(layer_fn())
-           for _ in range(start_layer, end_layer)
+           maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
+           for idx in range(start_layer, end_layer)
        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules
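A small usage sketch for the new make_layers() contract (hypothetical Block class; it also assumes a single-rank setup so get_pp_group() is already initialized and no PPMissingLayer placeholders are inserted): the layer_fn now receives the index-qualified prefix for each layer it builds.

import torch

class Block(torch.nn.Module):

    def __init__(self, prefix: str = ""):
        super().__init__()
        self.prefix = prefix

# With 4 layers and prefix="model.layers", layer_fn is called with
# "model.layers.0" through "model.layers.3".
start, end, layers = make_layers(4,
                                 lambda prefix: Block(prefix=prefix),
                                 prefix="model.layers")
assert [layer.prefix for layer in layers] == [
    "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
]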