Compare commits

...

1 Commits

Author SHA1 Message Date
10bc8b4e02 fix EosTokenCriteria 2024-05-20 10:43:14 +02:00

View File

@ -481,6 +481,7 @@ class EosTokenCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
self.eos_token_id = self.eos_token_id.to(input_ids.device)
if input_ids.device.type == "mps":
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
@ -492,7 +493,7 @@ class EosTokenCriteria(StoppingCriteria):
.squeeze()
)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
return is_done