[Model] Remove redundant softmax when using PoolingType.STEP (#10415)

Author: Maybewuss
Date: 2024-11-18 18:05:36 +08:00
Committed by: GitHub
Parent: c7dec926f6
Commit: 01aae1cc68


@@ -118,14 +118,13 @@ class Pooler(nn.Module):
             if returned_token_ids is not None and len(returned_token_ids) > 0:
                 hidden_states = hidden_states[:, returned_token_ids]
-            logits = hidden_states.softmax(dim=-1)
             step_tag_id = self.step_tag_id
             offset = 0
             pooled_data_lst = []
             for prompt_len, seq_data_i in zip(
                     prompt_lens, pooling_metadata.seq_data.values()):
-                pooled_data_i = logits[offset:offset + prompt_len]
+                pooled_data_i = hidden_states[offset:offset + prompt_len]
                 if step_tag_id is not None:
                     token_ids = torch.tensor(seq_data_i.prompt_token_ids)
                     pooled_data_i = pooled_data_i[token_ids == step_tag_id]
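
Why the early softmax was redundant: with PoolingType.STEP the pooler slices each prompt's hidden states and keeps only the step-tag positions, so normalizing the full hidden states up front does extra work for tokens that are then discarded, and if a softmax is also applied after pooling it becomes a softmax-of-softmax that distorts the scores. The sketch below is illustrative only, not vLLM code: the toy tensors, the step-tag id, and the assumption that normalization happens after pooling are all made up for the example.

# Minimal sketch (assumptions: toy 2-class scores, hypothetical step-tag id,
# and a softmax applied after pooling, which is what "redundant" suggests).
import torch

torch.manual_seed(0)

STEP_TAG_ID = 3                                  # hypothetical step-tag token id
prompt_token_ids = torch.tensor([5, 3, 7, 3])    # one toy prompt
hidden_states = torch.randn(4, 2)                # per-token scores over 2 classes

# New behaviour: select the step-tag positions first, normalize once afterwards.
step_scores = hidden_states[prompt_token_ids == STEP_TAG_ID]
probs_once = step_scores.softmax(dim=-1)

# Old behaviour under the assumption above: softmax before selection, then
# another softmax downstream -- a double normalization, not a harmless repeat.
probs_twice = (hidden_states.softmax(dim=-1)
               [prompt_token_ids == STEP_TAG_ID]
               .softmax(dim=-1))

print(torch.allclose(probs_once, probs_twice))   # prints False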