[Bugfix] Fix embedding model bug when pooling_type is ALL with multiple inputs (#10494)

Author: Xiaoyu Zhang
Date: 2024-11-21 22:40:02 +08:00
Committed by: GitHub
Parent: d5ec121f95
Commit: 4d676f0852

@@ -94,14 +94,10 @@ class Pooler(nn.Module):
             pooled_data = hidden_states[last_token_flat_indices]
         elif self.pooling_type == PoolingType.ALL:
             offset = 0
-            pooled_data_lst = []
+            pooled_data = []
             for prompt_len in prompt_lens:
-                pooled_data_i = hidden_states[offset:offset + prompt_len]
-                pooled_data_lst.append(pooled_data_i)
+                pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
-            pooled_data = torch.stack(pooled_data_lst)
         elif self.pooling_type == PoolingType.MEAN:
             # Calculate mean pooling
             cumsum = torch.cumsum(hidden_states, dim=0)
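
For readers following the change: the removed torch.stack call only works when every prompt in the batch has the same length, since torch.stack requires equal-sized tensors. A minimal standalone sketch of the failure and of the list-based fix (tensor sizes invented for illustration, not taken from the commit):

import torch

# Hidden states for two prompts of different lengths (3 and 5 tokens),
# concatenated along dim 0 as the Pooler receives them; sizes are illustrative.
hidden_states = torch.randn(8, 16)
prompt_lens = [3, 5]

offset = 0
pooled_data = []
for prompt_len in prompt_lens:
    pooled_data.append(hidden_states[offset:offset + prompt_len])
    offset += prompt_len

# The old torch.stack(pooled_data_lst) would fail here because the per-prompt
# slices have shapes (3, 16) and (5, 16):
#   torch.stack(pooled_data)  # RuntimeError: stack expects each tensor to be equal size
# Keeping the slices in a plain Python list, as the new code does, avoids that
# constraint; each entry is normalized or softmaxed individually further down.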
@@ -121,7 +117,7 @@ class Pooler(nn.Module):
             step_tag_id = self.step_tag_id
             offset = 0
-            pooled_data_lst = []
+            pooled_data = []
             for prompt_len, seq_data_i in zip(
                     prompt_lens, pooling_metadata.seq_data.values()):
                 pooled_data_i = hidden_states[offset:offset + prompt_len]
@@ -130,17 +126,26 @@ class Pooler(nn.Module):
                     pooled_data_i = pooled_data_i[token_ids == step_tag_id]
                 offset += prompt_len
-                pooled_data_lst.append(pooled_data_i)
-            pooled_data = torch.stack(pooled_data_lst)
+                pooled_data.append(pooled_data_i)
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
         if self.normalize:
-            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    nn.functional.normalize(data, p=2, dim=1)
+                    for data in pooled_data
+                ]
+            else:
+                pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
         if self.softmax:
-            pooled_data = nn.functional.softmax(pooled_data, dim=-1)
+            if isinstance(pooled_data, list):
+                pooled_data = [
+                    nn.functional.softmax(data, dim=-1) for data in pooled_data
+                ]
+            else:
+                pooled_data = nn.functional.softmax(pooled_data, dim=-1)
         pooled_outputs = [
             EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
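
Because pooled_data may now be either a single tensor or a list of per-prompt tensors, the normalize and softmax steps branch on isinstance and apply to each entry. A rough standalone sketch of what the list branch produces, continuing the illustrative shapes above (not taken from the commit):

import torch
import torch.nn as nn

# One tensor per prompt, with differing first dimensions, so the operations
# must be applied entry by entry rather than on a single stacked tensor.
pooled_data = [torch.randn(3, 16), torch.randn(5, 16)]

pooled_data = [
    nn.functional.normalize(data, p=2, dim=1) for data in pooled_data
]

# Every row is unit-norm regardless of prompt length, matching what the
# stacked-tensor path produces for equal-length prompts.
for data in pooled_data:
    assert torch.allclose(data.norm(dim=1), torch.ones(data.size(0)))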