Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[CORE] Quantized lm-head Framework (#4442)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
Committed by GitHub
parent 7c008c51a9
commit ee93f4f92a
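In short, this commit routes logits computation through the LM head module itself: `LogitsProcessor` and each model's `compute_logits` now receive the `ParallelLMHead` (a `VocabParallelEmbedding`) instead of its weight tensor, so a quantized linear method (GPTQ, GPTQ-Marlin, Marlin) can produce the logits when the checkpoint marks `lm_head` as quantized. A minimal sketch of the resulting call pattern follows; it is not part of the diff, and `ExampleForCausalLM` is a hypothetical wrapper used only for illustration.

import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata


class ExampleForCausalLM(torch.nn.Module):
    # Hypothetical model wrapper, for illustration only.

    def __init__(self, config, quant_config=None):
        super().__init__()
        # Passing quant_config lets the LM head pick up a quantized linear
        # method when the model's quantization config sets "lm_head": true.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        # New convention: pass the lm_head module, not self.lm_head.weight;
        # the processor calls lm_head.linear_method.apply(...) internally.
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)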
@@ -475,10 +475,10 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
     lora_result = lora_logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
-        embedding=linear.weight,
+        lm_head=linear,
         embedding_bias=None)

-    original_weight = linear.weight.clone()
+    original_lm_head = deepcopy(linear)

     linear.weight[logits_processor.
                   org_vocab_size:logits_processor.org_vocab_size +
@@ -490,7 +490,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
     for input_, lora_id in zip(inputs, prompt_mapping):
         lora = lora_dict[lora_id]
         result = logits_processor._get_logits(hidden_states=input_,
-                                               embedding=linear.weight,
+                                               lm_head=linear,
                                                embedding_bias=None)
         result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
         result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
@@ -519,11 +519,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
     lora_result = lora_logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
-        embedding=original_weight,
+        lm_head=original_lm_head,
         embedding_bias=None)[:, :vocab_size]
     expected_result = logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
-        embedding=original_weight,
+        lm_head=original_lm_head,
         embedding_bias=None)

     rtol, atol = TOLERANCES[lora_result.dtype]
tests/quantization/test_lm_head.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+"""Tests whether gptq models with quantized lm_head can be loaded.
+
+Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
+"""
+from typing import Tuple
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinLinearMethod)
+from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+
+PROMPT = "On the surface of Mars, we found"
+
+MODELS_QUANT = [(
+    "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
+    True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
+    ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
+
+
+@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
+def test_lm_head(
+    vllm_runner,
+    model_lm_head_quant: Tuple[str, bool],
+) -> None:
+    model, lm_head_quantized = model_lm_head_quant
+    vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
+
+    lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
+                     model_runner.model.lm_head)
+
+    if lm_head_quantized:
+        assert isinstance(
+            lm_head_layer.linear_method,
+            (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
+    else:
+        assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)
+
+    print(
+        vllm_model.generate_greedy(prompts=["Hello my name is"],
+                                   max_tokens=10)[0][1])
+    del vllm_model
@@ -34,7 +34,7 @@ SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"
 MAX_SPEC_TOKENS = 5

 # precision
-PRECISION = "float16"
+PRECISION = "float32"


 @pytest.mark.parametrize(
@@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str):
                                      device=device,
                                      pin_memory=is_pin_memory_available())
     logits_processor_output = logits_processor(
-        embedding=None,
+        lm_head=None,
         hidden_states=input_tensor,
         sampling_metadata=sampling_metadata)
@@ -1172,11 +1172,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     def _get_logits(
         self,
         hidden_states: torch.Tensor,
-        embedding: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
+        logits = lm_head.linear_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
         logits = tensor_model_parallel_gather(logits)
@@ -6,6 +6,8 @@ import torch
 import torch.nn as nn

 from vllm.distributed import tensor_model_parallel_gather
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -40,7 +42,7 @@ class LogitsProcessor(nn.Module):

     def forward(
         self,
-        embedding: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
@@ -52,8 +54,7 @@ class LogitsProcessor(nn.Module):
                                                   sampling_metadata)

-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+        logits = self._get_logits(hidden_states, lm_head, embedding_bias)

         if logits is not None:
             if self.soft_cap is not None:
                 logits = logits / self.soft_cap
@@ -68,12 +69,13 @@ class LogitsProcessor(nn.Module):

         return logits

-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+    def _get_logits(self, hidden_states: torch.Tensor,
+                    lm_head: VocabParallelEmbedding,
                     embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
+        logits = lm_head.linear_method.apply(lm_head,
+                                             hidden_states,
+                                             bias=embedding_bias)
         logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
@@ -87,6 +87,15 @@ class QuantizationConfig(ABC):
         raise ValueError(f"Cannot find any of {keys} in the model's "
                          "quantization config.")

+    @staticmethod
+    def get_from_keys_or(config: Dict[str, Any], keys: List[str],
+                         default: Any) -> Any:
+        """Get a optional value from the model's quantization config."""
+        try:
+            return QuantizationConfig.get_from_keys(config, keys)
+        except ValueError:
+            return default
+
     @abstractmethod
     def get_quant_method(
         self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.utils import set_weight_attrs
@@ -24,10 +25,12 @@ class GPTQConfig(QuantizationConfig):
         weight_bits: int,
         group_size: int,
         desc_act: bool,
+        lm_head_quantized: bool,
     ) -> None:
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
         self.pack_factor = Fraction(32, self.weight_bits)
         if self.weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
@@ -37,7 +40,8 @@ class GPTQConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}),"
+                f"lm_head_quantized={self.lm_head_quantized}")

     @classmethod
     def get_name(cls) -> str:
@@ -61,11 +65,14 @@ class GPTQConfig(QuantizationConfig):
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
-        return cls(weight_bits, group_size, desc_act)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized)

     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQLinearMethod(self)
         return None
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.utils import get_device_capability_stateless

 logger = init_logger(__name__)
@@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""

     def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
-                 is_sym: bool) -> None:
+                 is_sym: bool, lm_head_quantized: bool) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -69,6 +70,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         self.group_size = group_size
         self.desc_act = desc_act
         self.is_sym = is_sym
+        self.lm_head_quantized = lm_head_quantized

         # Verify
         if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
@@ -96,7 +98,8 @@ class GPTQMarlinConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"desc_act={self.desc_act})")
+                f"desc_act={self.desc_act}, "
+                f"lm_head_quantized={self.lm_head_quantized})")

     @classmethod
     def get_name(cls) -> str:
@@ -120,7 +123,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
         is_sym = cls.get_from_keys(config, ["sym"])
-        return cls(weight_bits, group_size, desc_act, is_sym)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(weight_bits, group_size, desc_act, is_sym,
+                   lm_head_quantized)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -145,7 +151,8 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
             self,
             layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return GPTQMarlinLinearMethod(self)
         return None
@@ -8,6 +8,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.utils import set_weight_attrs

 logger = init_logger(__name__)
@@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
     def __init__(
         self,
         group_size: int,
+        lm_head_quantized: bool,
     ) -> None:
         # Group size for the quantization.
         self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
         if self.group_size != 128 and self.group_size != -1:
             raise ValueError(
                 "Currently, only group size 128 and -1 (channelwise) "
@@ -51,7 +54,8 @@ class MarlinConfig(QuantizationConfig):
         self.perm_len = 1024

     def __repr__(self) -> str:
-        return f"MarlinConfig(group_size={self.group_size})"
+        return (f"MarlinConfig(group_size={self.group_size}, "
+                f"lm_head_quantized={self.lm_head_quantized})")

     @classmethod
     def get_name(cls) -> str:
@@ -73,7 +77,9 @@ class MarlinConfig(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
         group_size = cls.get_from_keys(config, ["group_size"])
-        return cls(group_size)
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
+                                                 default=False)
+        return cls(group_size, lm_head_quantized)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -96,7 +102,8 @@ class MarlinConfig(QuantizationConfig):

     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
-        if isinstance(layer, LinearBase):
+        if (isinstance(layer, LinearBase) or
+                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return MarlinLinearMethod(self)
         return None
@@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -157,6 +160,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         params_dtype: type of the parameters.
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
+        quant_config: quant config for the layer
     """  # noqa: E501

     def __init__(self,
@@ -164,7 +168,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                  embedding_dim: int,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()

         # Keep the input dimensions.
@@ -187,6 +192,14 @@ class VocabParallelEmbedding(torch.nn.Module):
                                                      self.org_vocab_size, tp_rank,
                                                      self.tp_size)
         self.embedding_dim = embedding_dim
+
+        linear_method = None
+        if quant_config is not None:
+            linear_method = quant_config.get_quant_method(self)
+        if linear_method is None:
+            linear_method = UnquantizedLinearMethod()
+        self.linear_method: QuantizeMethodBase = linear_method
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         # Divide the weight matrix along the vocaburaly dimension.
@@ -201,14 +214,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.num_added_embeddings_per_partition = (
             self.shard_indices.added_vocab_end_index -
             self.shard_indices.added_vocab_start_index)
-        self.weight = Parameter(
-            torch.empty(self.num_embeddings_per_partition,
-                        self.embedding_dim,
-                        dtype=params_dtype))
-        set_weight_attrs(self.weight, {
-            "parallel_dim": 0,
-            "weight_loader": self.weight_loader
-        })
+
+        self.linear_method.create_weights(self,
+                                          self.embedding_dim,
+                                          [self.num_embeddings_per_partition],
+                                          self.embedding_dim,
+                                          self.num_embeddings_padded,
+                                          params_dtype=params_dtype,
+                                          weight_loader=self.weight_loader)

     @classmethod
     def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@@ -288,10 +301,32 @@ class VocabParallelEmbedding(torch.nn.Module):
         return ret

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        parallel_dim = param.parallel_dim
-        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
-        loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
-                                      self.shard_indices.org_vocab_end_index]
+        output_dim = getattr(param, "output_dim", None)
+        packed_dim = getattr(param, "packed_dim", None)
+
+        # If parameter does not have output dim, then it should
+        # be copied onto all gpus (e.g. g_idx for act_order gptq).
+        if output_dim is None:
+            assert param.data.shape == loaded_weight.shape
+            param.data.copy_(loaded_weight)
+            return
+
+        # Shard indexes for loading the weight
+        start_idx = self.shard_indices.org_vocab_start_index
+        shard_size = self.shard_indices.org_vocab_end_index - start_idx
+
+        # If param packed on the same dim we are sharding on, then
+        # need to adjust offsets of loaded weight by pack_factor.
+        if packed_dim is not None and packed_dim == output_dim:
+            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
+                                                       param.pack_factor)
+            start_idx = start_idx // param.pack_factor
+            shard_size = shard_size // param.pack_factor
+        else:
+            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+
+        # Copy the data.
+        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0]:].data.fill_(0)
@@ -346,16 +381,17 @@ class ParallelLMHead(VocabParallelEmbedding):
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size)
+                         org_num_embeddings, padding_size, quant_config)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
-                "parallel_dim": 0,
-                "weight_loader": self.weight_loader
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
             })
         else:
             self.register_parameter("bias", None)
@ -412,6 +412,7 @@ class ArcticForCausalLM(nn.Module):
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.num_experts = config.num_local_experts
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
@ -434,7 +435,7 @@ class ArcticForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -328,7 +328,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(config, position_embedding, cache_config,
|
||||
quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -346,7 +348,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -276,7 +276,7 @@ class BloomForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = BloomModel(config, cache_config, quant_config)
|
||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||
self.lm_head = self.transformer.word_embeddings
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -294,7 +294,7 @@ class BloomForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -303,7 +303,8 @@ class ChatGLMModel(nn.Module):
|
||||
self.encoder = GLMTransformer(config, cache_config, quant_config)
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||
config.hidden_size)
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -355,7 +356,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
||||
8192)
|
||||
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||
self.lm_head_weight = self.transformer.output_layer.weight
|
||||
self.lm_head = self.transformer.output_layer
|
||||
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -373,7 +374,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -363,12 +363,12 @@ class CohereForCausalLM(nn.Module):
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
|
||||
if is_not_lora:
|
||||
embedding_weights = self.model.embed_tokens.weight
|
||||
logits = self.logits_processor(self.model.embed_tokens,
|
||||
hidden_states, sampling_metadata)
|
||||
else:
|
||||
embedding_weights = self.model.embed_tokens.base_layer.weight
|
||||
logits = self.logits_processor(self.model.embed_tokens.base_layer,
|
||||
hidden_states, sampling_metadata)
|
||||
|
||||
logits = self.logits_processor(embedding_weights, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
|
@ -370,6 +370,7 @@ class DbrxForCausalLM(nn.Module):
|
||||
config.d_model,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
@ -389,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -377,7 +377,9 @@ class DeepseekForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -395,7 +397,7 @@ class DeepseekForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -465,7 +465,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekV2Model(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -483,7 +485,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -394,13 +394,13 @@ class FalconForCausalLM(nn.Module):
|
||||
if config.tie_word_embeddings is not None
|
||||
else True)
|
||||
if self.tie_word_embeddings:
|
||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||
self.lm_head = self.transformer.word_embeddings
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.lm_head_weight = self.lm_head.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -422,7 +422,7 @@ class FalconForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -347,8 +347,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.model.embed_tokens.weight,
|
||||
hidden_states, sampling_metadata)
|
||||
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
|
@ -346,8 +346,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.model.embed_tokens.weight,
|
||||
hidden_states, sampling_metadata)
|
||||
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
|
@ -238,7 +238,7 @@ class GPT2LMHeadModel(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(config, cache_config, quant_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.lm_head = self.transformer.wte
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -256,7 +256,7 @@ class GPT2LMHeadModel(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -259,7 +259,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
|
||||
lora_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.lm_head = self.transformer.wte
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
@ -281,7 +281,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -229,6 +229,7 @@ class GPTJForCausalLM(nn.Module):
|
||||
config.vocab_size,
|
||||
config.n_embd,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@ -247,7 +248,7 @@ class GPTJForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, self.lm_head.bias)
|
||||
return logits
|
||||
|
||||
|
@ -241,6 +241,7 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
self.embed_out = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@ -259,7 +260,7 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.embed_out.weight, hidden_states,
|
||||
logits = self.logits_processor(self.embed_out, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -253,7 +253,9 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = InternLM2Model(config, cache_config, quant_config)
|
||||
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.output = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -271,7 +273,7 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.output.weight, hidden_states,
|
||||
logits = self.logits_processor(self.output, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -273,7 +273,7 @@ class JAISLMHeadModel(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = JAISModel(config, cache_config, quant_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.lm_head = self.transformer.wte
|
||||
if hasattr(config, "width_scale"):
|
||||
self.output_logits_scale = config.width_scale
|
||||
else:
|
||||
@ -297,7 +297,7 @@ class JAISLMHeadModel(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -380,6 +380,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
@ -403,7 +404,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -125,7 +125,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.text_config.hidden_size,
|
||||
org_num_embeddings=self.language_model.org_vocab_size)
|
||||
org_num_embeddings=self.language_model.org_vocab_size,
|
||||
quant_config=quant_config)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size, logit_scale)
|
||||
@ -255,7 +256,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -186,7 +186,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.text_config.hidden_size,
|
||||
org_num_embeddings=self.language_model.org_vocab_size)
|
||||
org_num_embeddings=self.language_model.org_vocab_size,
|
||||
quant_config=quant_config)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size, logit_scale)
|
||||
@ -438,7 +439,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -449,6 +449,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
||||
|
||||
@ -472,10 +473,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head_weight = self.model.embed_tokens.weight
|
||||
lm_head = self.model.embed_tokens
|
||||
else:
|
||||
lm_head_weight = self.lm_head.weight
|
||||
logits = self.logits_processor(lm_head_weight, hidden_states,
|
||||
lm_head = self.lm_head
|
||||
logits = self.logits_processor(lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -331,6 +331,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
@ -350,7 +351,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -344,7 +344,9 @@ class MixtralForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = MixtralModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -362,7 +364,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -8,7 +8,7 @@ from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import MLPSpeculatorConfig
|
||||
@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module):
|
||||
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
|
||||
(self.max_speculative_tokens - 1))
|
||||
|
||||
head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
|
||||
head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
|
||||
self.head = nn.ModuleList([head] * self.max_speculative_tokens)
|
||||
|
||||
ln = MLPSpeculatorLayerNorm(self.inner_dim,
|
||||
@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module):
|
||||
# TODO: not yet supporting top_k_tokens_per_head
|
||||
previous_hidden_states = states
|
||||
|
||||
logits = self.logits_processor(self.head[head_index].weight,
|
||||
states, sampling_metadata)
|
||||
logits = self.logits_processor(self.head[head_index], states,
|
||||
sampling_metadata)
|
||||
|
||||
output = self.sampler(logits.flatten(0, 1), sampling_metadata)
|
||||
last_tokens = output.sampled_token_ids
|
||||
|
@ -263,7 +263,7 @@ class MPTForCausalLM(nn.Module):
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.transformer = MPTModel(config, cache_config, quant_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.lm_head = self.transformer.wte
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -281,7 +281,7 @@ class MPTForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -283,15 +283,15 @@ class OlmoForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.model = OlmoModel(config, cache_config, quant_config)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head_weight = self.model.embed_tokens.weight
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.lm_head_weight = self.lm_head.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -313,7 +313,7 @@ class OlmoForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -294,7 +294,7 @@ class OPTForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = OPTModel(config, cache_config, quant_config)
|
||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||
self.lm_head = self.model.decoder.embed_tokens
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -312,7 +312,7 @@ class OPTForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -259,7 +259,9 @@ class OrionForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = OrionModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -277,7 +279,7 @@ class OrionForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -268,7 +268,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
bias=True)
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -287,7 +288,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, self.lm_head.bias)
|
||||
return logits
|
||||
|
||||
|
@ -366,6 +366,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@ -400,7 +401,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
if self.dummy_token_indices is not None and logits is not None:
|
||||
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
|
||||
|
@ -365,7 +365,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
self.model = LlamaModel(config, cache_config, quant_config)
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(
|
||||
vlm_config, config, self.model.embed_tokens)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -409,7 +411,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -235,7 +235,9 @@ class QWenLMHeadModel(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = QWenModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -253,7 +255,7 @@ class QWenLMHeadModel(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -316,11 +316,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
self.model = Qwen2Model(config, cache_config, quant_config)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head_weight = self.model.embed_tokens.weight
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.lm_head_weight = self.lm_head.weight
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@ -339,7 +339,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -362,7 +362,9 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2MoeModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -380,7 +382,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -240,7 +240,9 @@ class StablelmForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = StableLMEpochModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -258,7 +260,7 @@ class StablelmForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -242,7 +242,7 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
self.vocab_size = config.vocab_size
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head_weight = self.model.embed_tokens.weight
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
@ -250,8 +250,8 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.lm_head_weight = self.lm_head.weight
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@ -270,7 +270,7 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
@ -310,7 +310,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = XverseModel(config, cache_config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -328,7 +330,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|