Fix failing MyGemma2Embedding test (#13820)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-02-25 20:33:03 +00:00
committed by GitHub
parent f75aa72732
commit 34e3494e70

View File

@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
@ -37,16 +36,12 @@ class MyGemma2Embedding(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)