[CORE] Quantized lm-head Framework (#4442)

Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
Authored by Qubitium-ModelCloud on 2024-07-03 06:25:17 +08:00, committed by GitHub
parent 7c008c51a9
commit ee93f4f92a
48 changed files with 268 additions and 121 deletions

View File

@ -475,10 +475,10 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=linear.weight,
lm_head=linear,
embedding_bias=None)
original_weight = linear.weight.clone()
original_lm_head = deepcopy(linear)
linear.weight[logits_processor.
org_vocab_size:logits_processor.org_vocab_size +
@ -490,7 +490,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
@ -519,11 +519,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
lm_head=original_lm_head,
embedding_bias=None)[:, :vocab_size]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
lm_head=original_lm_head,
embedding_bias=None)
rtol, atol = TOLERANCES[lora_result.dtype]

View File

@ -0,0 +1,45 @@
"""Tests whether gptq models with quantized lm_head can be loaded.
Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
"""
from typing import Tuple
import pytest
import torch
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
PROMPT = "On the surface of Mars, we found"
MODELS_QUANT = [(
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
def test_lm_head(
vllm_runner,
model_lm_head_quant: Tuple[str, bool],
) -> None:
model, lm_head_quantized = model_lm_head_quant
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model.lm_head)
if lm_head_quantized:
assert isinstance(
lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod)
print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)[0][1])
del vllm_model
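
Outside the pytest harness the same check can be reproduced by hand. A minimal sketch mirroring what the test asserts, assuming the engine internals it traverses (driver_worker.model_runner.model.lm_head) stay reachable on a single-GPU instance:

from vllm import LLM
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinLinearMethod)

# Illustrative manual check; model name taken from MODELS_QUANT above.
llm = LLM(model="LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
          dtype="float16",
          max_model_len=2048)
lm_head = (llm.llm_engine.model_executor.driver_worker.
           model_runner.model.lm_head)
# A checkpoint whose quantization config sets "lm_head": true should carry a
# quantized linear method on the head instead of the unquantized default.
assert isinstance(lm_head.linear_method,
                  (GPTQLinearMethod, GPTQMarlinLinearMethod))
print(llm.generate("On the surface of Mars, we found")[0].outputs[0].text)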

View File

@ -34,7 +34,7 @@ SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"
MAX_SPEC_TOKENS = 5
# precision
PRECISION = "float16"
PRECISION = "float32"
@pytest.mark.parametrize(

View File

@ -83,7 +83,7 @@ def test_logits_processors(seed: int, device: str):
device=device,
pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
embedding=None,
lm_head=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)

View File

@ -1172,11 +1172,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def _get_logits(
self,
hidden_states: torch.Tensor,
embedding: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)

View File

@ -6,6 +6,8 @@ import torch
import torch.nn as nn
from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -40,7 +42,7 @@ class LogitsProcessor(nn.Module):
def forward(
self,
embedding: torch.Tensor,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
@ -52,8 +54,7 @@ class LogitsProcessor(nn.Module):
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
@ -68,12 +69,13 @@ class LogitsProcessor(nn.Module):
return logits
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
def _get_logits(self, hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
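
The functional change here is that the explicit matmul is replaced by a dispatch through the head's linear method: an unquantized ParallelLMHead reproduces the old hidden_states @ weight.t() + bias exactly, while a quantized head runs its GPTQ/Marlin kernel instead. A minimal sketch of that equivalence for the unquantized case, using a hypothetical stand-in module (UnquantizedLinearMethod.apply reduces to F.linear on layer.weight):

import torch
from vllm.model_executor.layers.linear import UnquantizedLinearMethod

class _FakeHead(torch.nn.Module):
    """Hypothetical stand-in exposing only the .weight that apply() reads."""
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)

hidden_states = torch.randn(4, 1024)
head = _FakeHead(torch.randn(32000, 1024))
head.linear_method = UnquantizedLinearMethod()

# New path: logits = lm_head.linear_method.apply(lm_head, hidden_states, bias=...)
logits_new = head.linear_method.apply(head, hidden_states, bias=None)
# Old path: logits = torch.matmul(hidden_states, embedding.t())
logits_old = torch.matmul(hidden_states, head.weight.t())
assert torch.allclose(logits_new, logits_old)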

View File

@ -87,6 +87,15 @@ class QuantizationConfig(ABC):
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
default: Any) -> Any:
"""Get a optional value from the model's quantization config."""
try:
return QuantizationConfig.get_from_keys(config, keys)
except ValueError:
return default
@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
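
A short usage sketch of the new helper (values illustrative): a present key returns its value, an absent key returns the supplied default instead of raising ValueError.

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

hf_quant_cfg = {"bits": 4, "group_size": 128, "desc_act": False}

# Key absent -> default; key present -> its value.
assert QuantizationConfig.get_from_keys_or(hf_quant_cfg, ["lm_head"],
                                           default=False) is False
assert QuantizationConfig.get_from_keys_or({**hf_quant_cfg, "lm_head": True},
                                           ["lm_head"],
                                           default=False) is True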

View File

@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs
@ -24,10 +25,12 @@ class GPTQConfig(QuantizationConfig):
weight_bits: int,
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
@ -37,7 +40,8 @@ class GPTQConfig(QuantizationConfig):
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
@classmethod
def get_name(cls) -> str:
@ -61,11 +65,14 @@ class GPTQConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self)
return None
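
In practice a GPTQ checkpoint opts its lm_head into quantization through an "lm_head" flag in its quantization config; when the flag is absent, the default of False keeps existing checkpoints unchanged. A hedged sketch of the parsing (field values are illustrative, not taken from a specific checkpoint):

from vllm.model_executor.layers.quantization.gptq import GPTQConfig

# e.g. the quantization section of a checkpoint's config (illustrative values).
quant_cfg = {
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "lm_head": True,   # new flag: lm_head weights are quantized too
}
cfg = GPTQConfig.from_config(quant_cfg)
assert cfg.lm_head_quantized
# get_quant_method() now returns GPTQLinearMethod for ParallelLMHead layers as
# well as LinearBase; gptq_marlin and marlin below read the same "lm_head" key.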

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.utils import get_device_capability_stateless
logger = init_logger(__name__)
@ -59,7 +60,7 @@ class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
is_sym: bool, lm_head_quantized: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
@ -69,6 +70,7 @@ class GPTQMarlinConfig(QuantizationConfig):
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
@ -96,7 +98,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
@ -120,7 +123,10 @@ class GPTQMarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
@ -145,7 +151,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMarlinLinearMethod(self)
return None

View File

@ -8,6 +8,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
def __init__(
self,
group_size: int,
lm_head_quantized: bool,
) -> None:
# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
@ -51,7 +54,8 @@ class MarlinConfig(QuantizationConfig):
self.perm_len = 1024
def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size})"
return (f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
@ -73,7 +77,9 @@ class MarlinConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(group_size, lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
@ -96,7 +102,8 @@ class MarlinConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return MarlinLinearMethod(self)
return None

View File

@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
@ -157,6 +160,7 @@ class VocabParallelEmbedding(torch.nn.Module):
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
""" # noqa: E501
def __init__(self,
@ -164,7 +168,8 @@ class VocabParallelEmbedding(torch.nn.Module):
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# Keep the input dimensions.
@ -187,6 +192,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
linear_method = None
if quant_config is not None:
linear_method = quant_config.get_quant_method(self)
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method: QuantizeMethodBase = linear_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocabulary dimension.
@ -201,14 +214,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
self.linear_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@ -288,10 +301,32 @@ class VocabParallelEmbedding(torch.nn.Module):
return ret
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
self.shard_indices.org_vocab_end_index]
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return
# Shard indexes for loading the weight
start_idx = self.shard_indices.org_vocab_start_index
shard_size = self.shard_indices.org_vocab_end_index - start_idx
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim:
assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
param.pack_factor)
start_idx = start_idx // param.pack_factor
shard_size = shard_size // param.pack_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size
# Copy the data.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)
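
The packed-dim branch covers parameters that are bit-packed along the same dim the vocab is sharded on (packed_dim == output_dim), where pack_factor = 32 // weight_bits for GPTQ-style formats, so both the shard offset and the shard width shrink by that factor. A worked sketch of the offset arithmetic with illustrative numbers (vocab padding ignored for brevity):

# Illustrative shard arithmetic for the packed_dim == output_dim case.
org_vocab_size = 32000
pack_factor = 8                   # e.g. 32 // 4 for 4-bit weights (assumed)
tp_rank, tp_size = 1, 2

# Logical (unpacked) view: this rank owns vocab rows [16000, 32000).
start_idx = (org_vocab_size // tp_size) * tp_rank    # 16000
shard_size = org_vocab_size // tp_size               # 16000

# Packed view: the loaded tensor only has 32000 // 8 = 4000 entries along the
# vocab dim, so the offsets are divided by pack_factor before narrow().
start_idx //= pack_factor                            # 2000
shard_size //= pack_factor                           # 2000
print(start_idx, shard_size)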
@ -346,16 +381,17 @@ class ParallelLMHead(VocabParallelEmbedding):
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size)
org_num_embeddings, padding_size, quant_config)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
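
Taken together, the per-model pattern repeated in the files below is: construct ParallelLMHead with the model's quant_config, and hand the module itself (not lm_head.weight) to the logits processor. A hypothetical skeleton of that wiring, not a real vLLM model class:

import torch
import torch.nn as nn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

class TinyCausalLM(nn.Module):
    """Hypothetical skeleton showing only the lm_head wiring."""

    def __init__(self, config, quant_config=None):
        super().__init__()
        # quant_config flows into the head so get_quant_method() can attach a
        # quantized linear method to the ParallelLMHead when lm_head is quantized.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata) -> torch.Tensor:
        # Pass the module, not its .weight, so the head's linear_method
        # performs the projection (quantized or not).
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)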

View File

@ -412,6 +412,7 @@ class ArcticForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(
self.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
@ -434,7 +435,7 @@ class ArcticForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -328,7 +328,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -346,7 +348,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -276,7 +276,7 @@ class BloomForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.lm_head = self.transformer.word_embeddings
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -294,7 +294,7 @@ class BloomForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -303,7 +303,8 @@ class ChatGLMModel(nn.Module):
self.encoder = GLMTransformer(config, cache_config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
config.hidden_size,
quant_config=quant_config)
def forward(
self,
@ -355,7 +356,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
@ -373,7 +374,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -363,12 +363,12 @@ class CohereForCausalLM(nn.Module):
sampling_metadata: SamplingMetadata) -> torch.Tensor:
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora:
embedding_weights = self.model.embed_tokens.weight
logits = self.logits_processor(self.model.embed_tokens,
hidden_states, sampling_metadata)
else:
embedding_weights = self.model.embed_tokens.base_layer.weight
logits = self.logits_processor(self.model.embed_tokens.base_layer,
hidden_states, sampling_metadata)
logits = self.logits_processor(embedding_weights, hidden_states,
sampling_metadata)
return logits
def sample(

View File

@ -370,6 +370,7 @@ class DbrxForCausalLM(nn.Module):
config.d_model,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
@ -389,7 +390,7 @@ class DbrxForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -377,7 +377,9 @@ class DeepseekForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = DeepseekModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -395,7 +397,7 @@ class DeepseekForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -465,7 +465,9 @@ class DeepseekV2ForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = DeepseekV2Model(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -483,7 +485,7 @@ class DeepseekV2ForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -394,13 +394,13 @@ class FalconForCausalLM(nn.Module):
if config.tie_word_embeddings is not None
else True)
if self.tie_word_embeddings:
self.lm_head_weight = self.transformer.word_embeddings.weight
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -422,7 +422,7 @@ class FalconForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -347,8 +347,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def sample(

View File

@ -346,8 +346,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def sample(

View File

@ -238,7 +238,7 @@ class GPT2LMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -256,7 +256,7 @@ class GPT2LMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -259,7 +259,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@ -281,7 +281,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -229,6 +229,7 @@ class GPTJForCausalLM(nn.Module):
config.vocab_size,
config.n_embd,
bias=True,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -247,7 +248,7 @@ class GPTJForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits

View File

@ -241,6 +241,7 @@ class GPTNeoXForCausalLM(nn.Module):
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -259,7 +260,7 @@ class GPTNeoXForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.embed_out.weight, hidden_states,
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits

View File

@ -253,7 +253,9 @@ class InternLM2ForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(config, cache_config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -271,7 +273,7 @@ class InternLM2ForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.output.weight, hidden_states,
logits = self.logits_processor(self.output, hidden_states,
sampling_metadata)
return logits

View File

@ -273,7 +273,7 @@ class JAISLMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
else:
@ -297,7 +297,7 @@ class JAISLMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -380,6 +380,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
@ -403,7 +404,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -125,7 +125,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size)
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
@ -255,7 +256,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -186,7 +186,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size)
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
@ -438,7 +439,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -449,6 +449,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
self.scale_width = self.config.hidden_size / self.config.dim_model_base
@ -472,10 +473,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> torch.Tensor:
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
lm_head = self.model.embed_tokens
else:
lm_head_weight = self.lm_head.weight
logits = self.logits_processor(lm_head_weight, hidden_states,
lm_head = self.lm_head
logits = self.logits_processor(lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -331,6 +331,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
@ -350,7 +351,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -344,7 +344,9 @@ class MixtralForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -362,7 +364,7 @@ class MixtralForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -8,7 +8,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import MLPSpeculatorConfig
@ -87,7 +87,7 @@ class MLPSpeculator(nn.Module):
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
(self.max_speculative_tokens - 1))
head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
self.head = nn.ModuleList([head] * self.max_speculative_tokens)
ln = MLPSpeculatorLayerNorm(self.inner_dim,
@ -169,8 +169,8 @@ class MLPSpeculator(nn.Module):
# TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states = states
logits = self.logits_processor(self.head[head_index].weight,
states, sampling_metadata)
logits = self.logits_processor(self.head[head_index], states,
sampling_metadata)
output = self.sampler(logits.flatten(0, 1), sampling_metadata)
last_tokens = output.sampled_token_ids

View File

@ -263,7 +263,7 @@ class MPTForCausalLM(nn.Module):
self.quant_config = quant_config
self.transformer = MPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -281,7 +281,7 @@ class MPTForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -283,15 +283,15 @@ class OlmoForCausalLM(nn.Module):
self.config = config
self.model = OlmoModel(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
self.lm_head = self.model.embed_tokens
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -313,7 +313,7 @@ class OlmoForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -294,7 +294,7 @@ class OPTForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config, cache_config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.lm_head = self.model.decoder.embed_tokens
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -312,7 +312,7 @@ class OPTForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -259,7 +259,9 @@ class OrionForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -277,7 +279,7 @@ class OrionForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -268,7 +268,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
bias=True,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -287,7 +288,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits

View File

@ -366,6 +366,7 @@ class Phi3SmallForCausalLM(nn.Module):
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -400,7 +401,7 @@ class Phi3SmallForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.dummy_token_indices is not None and logits is not None:
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)

View File

@ -365,7 +365,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.model = LlamaModel(config, cache_config, quant_config)
self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -409,7 +411,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -235,7 +235,9 @@ class QWenLMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -253,7 +255,7 @@ class QWenLMHeadModel(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -316,11 +316,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size)
self.lm_head_weight = self.lm_head.weight
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -339,7 +339,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -362,7 +362,9 @@ class Qwen2MoeForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -380,7 +382,7 @@ class Qwen2MoeForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -240,7 +240,9 @@ class StablelmForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = StableLMEpochModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -258,7 +260,7 @@ class StablelmForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -242,7 +242,7 @@ class Starcoder2ForCausalLM(nn.Module):
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
self.lm_head = self.model.embed_tokens
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
@ -250,8 +250,8 @@ class Starcoder2ForCausalLM(nn.Module):
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
@ -270,7 +270,7 @@ class Starcoder2ForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -310,7 +310,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -328,7 +330,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits