mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Embedding model pooling_type equals ALL and multi input's bug (#10494)
This commit is contained in:
@ -94,14 +94,10 @@ class Pooler(nn.Module):
|
||||
pooled_data = hidden_states[last_token_flat_indices]
|
||||
elif self.pooling_type == PoolingType.ALL:
|
||||
offset = 0
|
||||
pooled_data_lst = []
|
||||
pooled_data = []
|
||||
for prompt_len in prompt_lens:
|
||||
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||
|
||||
pooled_data_lst.append(pooled_data_i)
|
||||
pooled_data.append(hidden_states[offset:offset + prompt_len])
|
||||
offset += prompt_len
|
||||
|
||||
pooled_data = torch.stack(pooled_data_lst)
|
||||
elif self.pooling_type == PoolingType.MEAN:
|
||||
# Calculate mean pooling
|
||||
cumsum = torch.cumsum(hidden_states, dim=0)
|
||||
@ -121,7 +117,7 @@ class Pooler(nn.Module):
|
||||
step_tag_id = self.step_tag_id
|
||||
|
||||
offset = 0
|
||||
pooled_data_lst = []
|
||||
pooled_data = []
|
||||
for prompt_len, seq_data_i in zip(
|
||||
prompt_lens, pooling_metadata.seq_data.values()):
|
||||
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||
@ -130,17 +126,26 @@ class Pooler(nn.Module):
|
||||
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
|
||||
|
||||
offset += prompt_len
|
||||
pooled_data_lst.append(pooled_data_i)
|
||||
|
||||
pooled_data = torch.stack(pooled_data_lst)
|
||||
pooled_data.append(pooled_data_i)
|
||||
else:
|
||||
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
|
||||
|
||||
if self.normalize:
|
||||
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = [
|
||||
nn.functional.normalize(data, p=2, dim=1)
|
||||
for data in pooled_data
|
||||
]
|
||||
else:
|
||||
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
|
||||
|
||||
if self.softmax:
|
||||
pooled_data = nn.functional.softmax(pooled_data, dim=-1)
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = [
|
||||
nn.functional.softmax(data, dim=-1) for data in pooled_data
|
||||
]
|
||||
else:
|
||||
pooled_data = nn.functional.softmax(pooled_data, dim=-1)
|
||||
|
||||
pooled_outputs = [
|
||||
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
|
||||
|
Reference in New Issue
Block a user