[ Misc ] non-uniform quantization via compressed-tensors for Llama (#6515)

This commit is contained in:
Robert Shaw
2024-07-18 22:39:18 -04:00
committed by GitHub
parent d4201e06d5
commit dbe5588554
11 changed files with 301 additions and 91 deletions

View File

@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.758
- name: "exact_match,flexible-extract"
value: 0.759
limit: 1000
num_fewshot: 5

View File

@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml

View File

@ -158,6 +158,7 @@ class FusedMoE(torch.nn.Module):
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
):
super().__init__()

View File

@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
@ -179,15 +181,19 @@ class ReplicatedLinear(LinearBase):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
prefix=prefix)
if bias:
self.bias = Parameter(
@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
@ -249,7 +257,8 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None):
output_sizes: Optional[List[int]] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
@ -276,7 +285,8 @@ class ColumnParallelLinear(LinearBase):
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
weight_loader=self.weight_loader,
prefix=prefix)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
@ -357,7 +369,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
@ -367,7 +380,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
quant_config=quant_config,
prefix=prefix)
def weight_loader(self,
param: Parameter,
@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
@ -497,7 +513,8 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@ -529,7 +546,8 @@ class QKVParallelLinear(ColumnParallelLinear):
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config)
quant_config=quant_config,
prefix=prefix)
def weight_loader(self,
param: Parameter,
@ -688,7 +706,8 @@ class RowParallelLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
@ -706,7 +725,8 @@ class RowParallelLinear(LinearBase):
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
weight_loader=self.weight_loader,
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")

View File

@ -8,23 +8,25 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_first_name_or_class_match,
is_activation_quantization_format)
QuantizationType, find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.platforms import current_platform
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
quant_format: str):
self.ignore = ignore
self.layer_quant_details = layer_quant_details
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
@ -51,7 +53,7 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict()
target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)
@ -63,21 +65,21 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for key, quant_config in config["config_groups"].items():
for _, quant_config in config["config_groups"].items():
targets = quant_config.get("targets")
for target in targets:
layer_quant_details[target] = {}
layer_quant_details[target][
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights"))
try:
layer_quant_details[target][
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
except Exception:
layer_quant_details[target]["input_activations"] = None
target_scheme_map[target]["input_activations"] = None
return cls(layer_quant_details=layer_quant_details,
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format)
@ -167,7 +169,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_channel_group and input_quant_none and is_symmetric
and is_static)
def _get_schema(self, weight_quant: BaseModel,
def _get_scheme_from_parts(
self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":
# Detect If Mixed Precision
@ -205,26 +208,47 @@ class CompressedTensorsConfig(QuantizationConfig):
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
def get_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
"""
compressed-tensors supports non uniform in the following way:
layer_type_name = find_first_name_or_class_match(
name="",
ignore: List of layer_names or nn.Module names to be ignored.
targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
an nn.Module name.
We first check whether a layer is in the ignore group and use
CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
We then detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for infernece.
"""
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(layer_name, ignore=self.ignore):
return CompressedTensorsUnquantized()
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.layer_quant_details.keys(),
check_contains=True)
targets=self.target_scheme_map.keys())
if layer_type_name is None:
raise ValueError(f"Could not matching target for layer {layer}")
# Find the quant_scheme
scheme = self.target_scheme_map[matched_target]
layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
layer_type_name, None)
if layer_quant_details is None:
raise ValueError(
f"Could not find quantization details for {layer}.")
scheme = self._get_schema(
weight_quant=layer_quant_details["weights"],
input_quant=layer_quant_details["input_activations"])
return self._get_scheme_from_parts(
weight_quant=scheme["weights"],
input_quant=scheme["input_activations"])
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
@ -250,11 +274,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer_name = extra_weight_attrs.get("prefix")
scheme = self.quantization_config.get_scheme(layer=layer)
scheme = self.quantization_config.get_scheme(layer, layer_name)
scheme.create_weights(
layer=layer,
input_size=input_size,

View File

@ -33,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
device="cuda",
dtype=params_dtype),
requires_grad=False)

View File

@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS
def find_first_name_or_class_match(
name: str,
module: Module,
targets: Iterable[str],
check_contains: bool = False) -> Optional[str]:
"""
Helper function to map the quantization details listed in the config
for a given list of targets against each model layer. First uses the
layer name to try and find a match. If no name match is found, uses
the layer class name. Returns None otherwise.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
:param name: layer name
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
if layer_name is None:
return False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name = layer_name.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in _FUSED_LAYER_NAME_MAPPING:
shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer = None
for shard_name in shard_names:
should_ignore_shard = check_equal_or_regex_match(
layer_name=shard_name, targets=ignore)
# If shard_idx=0, set layer ignore to match shard.
if should_ignore_layer is None:
should_ignore_layer = should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif should_ignore_shard != should_ignore_layer:
raise ValueError(f"Found a different quantization schemes for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme.")
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else:
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
targets=ignore)
assert should_ignore_layer is not None
return should_ignore_layer
def check_equal_or_regex_match(layer_name: str,
targets: Iterable[str]) -> bool:
"""
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
Recall that a compressed-tensors configs has a concept of
config_groups, where each layer can be quantized with with a different
scheme.
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name__, targets, check_contains)
if layer_name is None:
layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True))
if matched_target is None:
raise ValueError(f"Unable to find matching target for {module} in the "
"compressed-tensors config.")
return matched_target
def _find_first_match(value: str,
@ -121,13 +202,29 @@ def _find_first_match(value: str,
"""
for target in targets:
if _is_equal_or_regex_match(value,
target,
check_contains=check_contains):
return target
return None
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return target
return True
elif check_contains:
if target.lower() in value.lower():
return target
return True
elif target == value:
return target
return None
return True
return False

View File

@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_attn",
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.attn = Attention(self.num_heads,
self.head_dim,
@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
intermediate_size: int,
config: GPT2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc",
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
hidden_size = config.hidden_size
@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, cache_config, quant_config)
self.attn = GPT2Attention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attn")
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, quant_config)
self.mlp = GPT2MLP(inner_dim,
config,
quant_config,
prefix=f"{prefix}.mlp")
def forward(
self,
@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda: GPT2Block(config, cache_config, quant_config))
lambda prefix: GPT2Block(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(config, cache_config, quant_config)
self.transformer = GPT2Model(config,
cache_config,
quant_config,
prefix="transformer")
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -62,17 +62,20 @@ class LlamaMLP(nn.Module):
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@ -99,6 +102,7 @@ class LlamaAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -132,12 +136,14 @@ class LlamaAttention(nn.Module):
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module):
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -203,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
@ -210,6 +218,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -253,6 +262,7 @@ class LlamaModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -272,9 +282,11 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda: LlamaDecoderLayer(config=config,
lambda prefix: LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config))
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config=lora_config)
lora_config=lora_config,
prefix="model")
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:

View File

@ -67,7 +67,8 @@ class MixtralMoE(nn.Module):
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None):
tp_size: Optional[int] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
@ -76,7 +77,8 @@ class MixtralMoE(nn.Module):
num_experts,
bias=False,
params_dtype=params_dtype,
quant_config=None)
quant_config=None,
prefix=f"{prefix}.gate")
self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
@ -86,7 +88,8 @@ class MixtralMoE(nn.Module):
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size)
tp_size=tp_size,
prefix=f"{prefix}.experts")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
@ -109,6 +112,7 @@ class MixtralAttention(nn.Module):
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -139,12 +143,14 @@ class MixtralAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -194,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@ -243,6 +252,7 @@ class MixtralModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@ -258,8 +268,11 @@ class MixtralModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, lambda: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config))
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self.model = MixtralModel(config,
cache_config,
quant_config,
lora_config=lora_config)
lora_config=lora_config,
prefix="model")
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

View File

@ -1,4 +1,4 @@
from typing import Callable, Dict, List, Tuple
from typing import Dict, List, Protocol, Tuple
import torch
from torch.func import functional_call
@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds
class LayerFn(Protocol):
def __call__(
self,
prefix="",
) -> torch.nn.Module:
...
class PPMissingLayer(torch.nn.Identity):
"""
A placeholder layer for missing layers in a pipeline parallel model.
@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
def make_layers(
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
@ -131,8 +142,8 @@ def make_layers(
get_pp_group().world_size)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn())
for _ in range(start_layer, end_layer)
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules