🚨 Remove Constrained Beam Search decoding strategy (#40518)
* Squashed remove-constrastive-search * sweeep ready for tests * testing... * whoops * ops * tests fix * tests green, changed handling of deprecated methods * tests gone after green * restore and deprecate beam obkects * restore and deprecate constraint objects * fix ci * review
This commit is contained in:
committed by
GitHub
parent
564be6d895
commit
8564e210ca
@ -39,7 +39,6 @@
|
||||
| [كيفية ضبط نموذج بدقة على التلخيص](https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb)| يوضح كيفية معالجة البيانات مسبقًا وضبط نموذج مُدرَّب مسبقًا بدقة على XSUM. | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)|
|
||||
| [كيفية تدريب نموذج لغة من البداية](https://github.com/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| تسليط الضوء على جميع الخطوات لتدريب نموذج Transformer بشكل فعال على بيانات مخصصة | [](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)|
|
||||
| [كيفية إنشاء نص](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| كيفية استخدام أساليب فك التشفير المختلفة لإنشاء اللغة باستخدام المحولات | [](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)|
|
||||
| [كيفية إنشاء نص (مع قيود)](https://github.com/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| كيفية توجيه إنشاء اللغة باستخدام القيود التي يوفرها المستخدم | [](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)|
|
||||
| [Reformer](https://github.com/huggingface/blog/blob/main/notebooks/03_reformer.ipynb)| كيف يدفع Reformer حدود النمذجة اللغوية | [](https://colab.research.google.com/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)| [](https://studiolab.sagemaker.aws/import/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)|
|
||||
|
||||
#### رؤية الكمبيوتر[[pytorch-cv]]
|
||||
|
@ -300,28 +300,6 @@ generation_output[:2]
|
||||
[[autodoc]] MaxTimeCriteria
|
||||
- __call__
|
||||
|
||||
## Constraints
|
||||
|
||||
[`Constraint`] を使用すると、生成時に出力に特定のトークンまたはシーケンスが含まれるように強制できます。これは PyTorch 実装でのみ利用可能であることに注意してください。
|
||||
|
||||
[[autodoc]] Constraint
|
||||
|
||||
[[autodoc]] PhrasalConstraint
|
||||
|
||||
[[autodoc]] DisjunctiveConstraint
|
||||
|
||||
[[autodoc]] ConstraintListState
|
||||
|
||||
## BeamSearch
|
||||
|
||||
[[autodoc]] BeamScorer
|
||||
- process
|
||||
- finalize
|
||||
|
||||
[[autodoc]] ConstrainedBeamSearchScorer
|
||||
- process
|
||||
- finalize
|
||||
|
||||
## Streamers
|
||||
|
||||
[[autodoc]] TextStreamer
|
||||
|
@ -305,28 +305,6 @@ generation_output[:2]
|
||||
[[autodoc]] EosTokenCriteria
|
||||
- __call__
|
||||
|
||||
## Constraint [[transformers.Constraint]]
|
||||
|
||||
[`Constraint`]는 생성 출력에 특정 토큰이나 시퀀스를 강제로 포함시키는 데 사용됩니다. 이 기능은 PyTorch 구현에만 제공됩니다.
|
||||
|
||||
[[autodoc]] Constraint
|
||||
|
||||
[[autodoc]] PhrasalConstraint
|
||||
|
||||
[[autodoc]] DisjunctiveConstraint
|
||||
|
||||
[[autodoc]] ConstraintListState
|
||||
|
||||
## 빔 검색 (BeamSearch) [[transformers.BeamScorer]]
|
||||
|
||||
[[autodoc]] BeamScorer
|
||||
- process
|
||||
- finalize
|
||||
|
||||
[[autodoc]] ConstrainedBeamSearchScorer
|
||||
- process
|
||||
- finalize
|
||||
|
||||
## 스트리머 (Streamers) [[transformers.TextStreamer]]
|
||||
|
||||
[[autodoc]] TextStreamer
|
||||
|
@ -295,28 +295,6 @@ generation_output[:2]
|
||||
[[autodoc]] MaxTimeCriteria
|
||||
- __call__
|
||||
|
||||
## Constraints
|
||||
|
||||
可以使用[`Constraint`]来强制生成结果包含输出中的特定tokens或序列。请注意,这仅适用于我们的PyTorch实现。
|
||||
|
||||
[[autodoc]] Constraint
|
||||
|
||||
[[autodoc]] PhrasalConstraint
|
||||
|
||||
[[autodoc]] DisjunctiveConstraint
|
||||
|
||||
[[autodoc]] ConstraintListState
|
||||
|
||||
## BeamSearch
|
||||
|
||||
[[autodoc]] BeamScorer
|
||||
- process
|
||||
- finalize
|
||||
|
||||
[[autodoc]] ConstrainedBeamSearchScorer
|
||||
- process
|
||||
- finalize
|
||||
|
||||
## Streamers
|
||||
|
||||
[[autodoc]] TextStreamer
|
||||
|
@ -56,7 +56,6 @@ You can open any page of the documentation as a notebook in Colab (there is a bu
|
||||
| [How to fine-tune a model on summarization](https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb)| Show how to preprocess the data and fine-tune a pretrained model on XSUM. | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)|
|
||||
| [How to train a language model from scratch](https://github.com/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train Transformer model on custom data | [](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)|
|
||||
| [How to generate text](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)|
|
||||
| [How to generate text (with constraints)](https://github.com/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| How to guide language generation with user-provided constraints | [](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| [](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)|
|
||||
| [Reformer](https://github.com/huggingface/blog/blob/main/notebooks/03_reformer.ipynb)| How Reformer pushes the limits of language modeling | [](https://colab.research.google.com/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)| [](https://studiolab.sagemaker.aws/import/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)|
|
||||
|
||||
#### Computer Vision[[pytorch-cv]]
|
||||
|
@ -1,7 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# TODO joao, manuel: remove in v4.58.0
|
||||
class Constraint(ABC):
|
||||
r"""Abstract base class for all constraints that can be applied during generation.
|
||||
It must define how the constraint can be satisfied.
|
||||
@ -18,6 +24,9 @@ class Constraint(ABC):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
logger.warning_once(
|
||||
"Importing `Constraint` classes is deprecated and will be removed in v4.58.0. Constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search. Please import using `from transformers.generation import Constraint` instead."
|
||||
)
|
||||
# test for the above condition
|
||||
self.test()
|
||||
|
||||
|
@ -20,10 +20,13 @@ from typing import Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..utils import add_start_docstrings
|
||||
from ..utils import add_start_docstrings, logging
|
||||
from .beam_constraints import Constraint, ConstraintListState
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
PROCESS_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
||||
@ -118,6 +121,305 @@ class BeamScorer(ABC):
|
||||
raise NotImplementedError("This is an abstract method.")
|
||||
|
||||
|
||||
class BeamSearchScorer(BeamScorer):
|
||||
r"""
|
||||
[`BeamScorer`] implementing standard beam search decoding.
|
||||
|
||||
Adapted in part from [Facebook's XLM beam search
|
||||
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
|
||||
|
||||
Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
|
||||
implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
|
||||
num_beams (`int`):
|
||||
Number of beams for beam search.
|
||||
device (`torch.device`):
|
||||
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
|
||||
allocated.
|
||||
length_penalty (`float`, *optional*, defaults to 1.0):
|
||||
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
||||
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
||||
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
||||
`length_penalty` < 0.0 encourages shorter sequences.
|
||||
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
||||
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
||||
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
||||
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
||||
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
||||
beam search algorithm).
|
||||
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
|
||||
The number of beam hypotheses that shall be returned upon calling
|
||||
[`~transformers.BeamSearchScorer.finalize`].
|
||||
num_beam_groups (`int`, *optional*, defaults to 1):
|
||||
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
||||
See [this paper](https://huggingface.co/papers/1610.02424) for more details.
|
||||
max_length (`int`, *optional*):
|
||||
The maximum length of the sequence to be generated.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_beams: int,
|
||||
device: torch.device,
|
||||
length_penalty: Optional[float] = 1.0,
|
||||
do_early_stopping: Optional[Union[bool, str]] = False,
|
||||
num_beam_hyps_to_keep: Optional[int] = 1,
|
||||
num_beam_groups: Optional[int] = 1,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
logger.warning_once(
|
||||
"`BeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
|
||||
)
|
||||
self.num_beams = num_beams
|
||||
self.device = device
|
||||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
self.num_beam_groups = num_beam_groups
|
||||
self.group_size = self.num_beams // self.num_beam_groups
|
||||
|
||||
self._is_init = False
|
||||
# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
|
||||
# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
|
||||
self._beam_hyps = [
|
||||
BeamHypotheses(
|
||||
num_beams=self.group_size,
|
||||
length_penalty=self.length_penalty,
|
||||
early_stopping=self.do_early_stopping,
|
||||
max_length=max_length,
|
||||
)
|
||||
for _ in range(batch_size * self.num_beam_groups)
|
||||
]
|
||||
# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
|
||||
# in the i-th mini-batch is complete.
|
||||
self._done = torch.tensor(
|
||||
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
if not isinstance(num_beams, int) or num_beams <= 1:
|
||||
raise ValueError(
|
||||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
||||
" one should make use of `greedy_search` instead."
|
||||
)
|
||||
|
||||
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
||||
raise ValueError(
|
||||
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
||||
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return self._done.all()
|
||||
|
||||
def process(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
next_scores: torch.FloatTensor,
|
||||
next_tokens: torch.LongTensor,
|
||||
next_indices: torch.LongTensor,
|
||||
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
group_index: Optional[int] = 0,
|
||||
decoder_prompt_len: Optional[int] = 0,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
||||
cur_len = input_ids.shape[-1] + 1
|
||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||
|
||||
if batch_size != (input_ids.shape[0] // self.group_size):
|
||||
if self.num_beam_groups > 1:
|
||||
raise ValueError(
|
||||
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
||||
f"size of {self.group_size} is expected by the beam scorer."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
|
||||
f"{self.group_size} is expected by the beam scorer."
|
||||
)
|
||||
|
||||
device = input_ids.device
|
||||
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||
|
||||
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_group_idx = batch_idx * self.num_beam_groups + group_index
|
||||
if self._done[batch_group_idx]:
|
||||
if self.num_beams < len(self._beam_hyps[batch_group_idx]):
|
||||
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
||||
if eos_token_id is None or pad_token_id is None:
|
||||
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
||||
# pad the batch
|
||||
next_beam_scores[batch_idx, :] = 0
|
||||
next_beam_tokens[batch_idx, :] = pad_token_id
|
||||
next_beam_indices[batch_idx, :] = 0
|
||||
continue
|
||||
|
||||
# next tokens for this sentence
|
||||
beam_idx = 0
|
||||
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
||||
):
|
||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
if beam_indices is not None:
|
||||
beam_index = beam_indices[batch_beam_idx]
|
||||
beam_index = beam_index + (batch_beam_idx,)
|
||||
else:
|
||||
beam_index = None
|
||||
|
||||
self._beam_hyps[batch_group_idx].add(
|
||||
input_ids[batch_beam_idx].clone(),
|
||||
next_score.item(),
|
||||
beam_indices=beam_index,
|
||||
generated_len=cur_len - decoder_prompt_len,
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
next_beam_scores[batch_idx, beam_idx] = next_score
|
||||
next_beam_tokens[batch_idx, beam_idx] = next_token
|
||||
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
||||
beam_idx += 1
|
||||
|
||||
# once the beam for next step is full, don't add more tokens to it.
|
||||
if beam_idx == self.group_size:
|
||||
break
|
||||
|
||||
if beam_idx < self.group_size:
|
||||
raise ValueError(
|
||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
||||
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
)
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
||||
)
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
"next_beam_scores": next_beam_scores.view(-1),
|
||||
"next_beam_tokens": next_beam_tokens.view(-1),
|
||||
"next_beam_indices": next_beam_indices.view(-1),
|
||||
}
|
||||
)
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
final_beam_scores: torch.FloatTensor,
|
||||
final_beam_tokens: torch.LongTensor,
|
||||
final_beam_indices: torch.LongTensor,
|
||||
max_length: int,
|
||||
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
||||
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
decoder_prompt_len: Optional[int] = 0,
|
||||
) -> tuple[torch.LongTensor]:
|
||||
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
||||
|
||||
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id = torch.tensor(eos_token_id)
|
||||
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_group_idx]:
|
||||
continue
|
||||
|
||||
# all open beam hypotheses are added to the beam hypothesis
|
||||
# beam hypothesis class automatically keeps the best beams
|
||||
for index_per_group in range(self.group_size):
|
||||
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
||||
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||
best = []
|
||||
best_indices = []
|
||||
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
||||
|
||||
# retrieve best hypotheses
|
||||
for i in range(batch_size):
|
||||
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
|
||||
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
|
||||
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
|
||||
for j in range(self.num_beam_hyps_to_keep):
|
||||
best_hyp_tuple = sorted_hyps.pop()
|
||||
best_score = best_hyp_tuple[0]
|
||||
best_hyp = best_hyp_tuple[1]
|
||||
best_index = best_hyp_tuple[2]
|
||||
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
||||
|
||||
# append hyp to lists
|
||||
best.append(best_hyp)
|
||||
|
||||
# append indices to list
|
||||
best_indices.append(best_index)
|
||||
|
||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||
|
||||
# prepare for adding eos
|
||||
sent_lengths_max = sent_lengths.max().item() + 1
|
||||
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
|
||||
if len(best_indices) > 0 and best_indices[0] is not None:
|
||||
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
else:
|
||||
indices = None
|
||||
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
if pad_token_id is None:
|
||||
raise ValueError("`pad_token_id` has to be defined")
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
if indices is not None:
|
||||
indices.fill_(-1)
|
||||
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
|
||||
if indices is not None:
|
||||
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
||||
|
||||
if sent_lengths[i] < sent_max_len:
|
||||
# inserting only the first eos_token_id
|
||||
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
"sequences": decoded,
|
||||
"sequence_scores": best_scores,
|
||||
"beam_indices": indices,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
r"""
|
||||
[`BeamScorer`] implementing constrained beam search decoding.
|
||||
@ -163,6 +465,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
num_beam_hyps_to_keep: Optional[int] = 1,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
logger.warning_once(
|
||||
"`ConstrainedBeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
|
||||
)
|
||||
self.num_beams = num_beams
|
||||
self.device = device
|
||||
self.length_penalty = length_penalty
|
||||
@ -611,6 +916,9 @@ class BeamHypotheses:
|
||||
"""
|
||||
Initialize n-best list of hypotheses.
|
||||
"""
|
||||
logger.warning_once(
|
||||
"`BeamHypotheses` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
|
||||
)
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.max_length = max_length
|
||||
|
@ -89,7 +89,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
|
||||
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
|
||||
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
|
||||
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
|
||||
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
|
||||
|
||||
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
||||
@ -202,18 +201,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
bad_words_ids (`list[list[int]]`, *optional*):
|
||||
List of list of token ids that are not allowed to be generated. Check
|
||||
[`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
|
||||
force_words_ids (`list[list[int]]` or `list[list[list[int]]]`, *optional*):
|
||||
List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple list of
|
||||
words that must be included, the opposite to `bad_words_ids`. If given `list[list[list[int]]]`, this
|
||||
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
|
||||
can allow different forms of each word.
|
||||
renormalize_logits (`bool`, *optional*, defaults to `False`):
|
||||
Whether to renormalize the logits after applying all the logits processors (including the custom
|
||||
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
|
||||
are normalized but some logit processors break the normalization.
|
||||
constraints (`list[Constraint]`, *optional*):
|
||||
Custom constraints that can be added to the generation to ensure that the output will contain the use of
|
||||
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
|
||||
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
|
||||
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
|
||||
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
|
||||
@ -374,9 +365,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
||||
self.force_words_ids = kwargs.pop("force_words_ids", None)
|
||||
self.renormalize_logits = kwargs.pop("renormalize_logits", False)
|
||||
self.constraints = kwargs.pop("constraints", None)
|
||||
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
||||
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
||||
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
||||
@ -434,6 +423,8 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.dola_layers = kwargs.pop("dola_layers", None)
|
||||
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
||||
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
||||
self.constraints = kwargs.pop("constraints", None)
|
||||
self.force_words_ids = kwargs.pop("force_words_ids", None)
|
||||
|
||||
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
||||
# interface.
|
||||
@ -625,24 +616,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="length_penalty", flag_value=self.length_penalty
|
||||
)
|
||||
if self.constraints is not None:
|
||||
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="constraints", flag_value=self.constraints
|
||||
)
|
||||
|
||||
# 2.3. detect incorrect parameterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None or self.force_words_ids is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
|
||||
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
|
||||
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
|
||||
# 2.4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
|
@ -1441,6 +1441,133 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
|
||||
return scores_processed
|
||||
|
||||
|
||||
class HammingDiversityLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] that enforces diverse beam search.
|
||||
Note that this logits processor is only effective for [`PreTrainedModel.group_beam_search`]. See [Diverse Beam
|
||||
Search: Decoding Diverse Solutions from Neural Sequence Models](https://huggingface.co/papers/1610.02424) for more
|
||||
details.
|
||||
Traditional beam search often generates very similar sequences across different beams.
|
||||
`HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other
|
||||
beams in the same time step.
|
||||
Args:
|
||||
diversity_penalty (`float`):
|
||||
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
||||
particular time. A higher `diversity_penalty` will enforce greater diversity among the beams. Adjusting
|
||||
this value can help strike a balance between diversity and natural likelihood.
|
||||
num_beams (`int`):
|
||||
Number of beams for beam search. 1 means no beam search.
|
||||
num_beam_groups (`int`):
|
||||
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
||||
[this paper](https://huggingface.co/papers/1610.02424) for more details.
|
||||
Examples:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
>>> import torch
|
||||
>>> # Initialize the model and tokenizer
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
|
||||
>>> # A long text about the solar system
|
||||
>>> text = (
|
||||
... "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, "
|
||||
... "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight "
|
||||
... "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System "
|
||||
... "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant "
|
||||
... "interstellar molecular cloud."
|
||||
... )
|
||||
>>> inputs = tokenizer("summarize: " + text, return_tensors="pt")
|
||||
>>> # Generate diverse summary
|
||||
>>> outputs_diverse = model.generate(
|
||||
... **inputs,
|
||||
... num_beam_groups=2,
|
||||
... diversity_penalty=10.0,
|
||||
... max_length=100,
|
||||
... num_beams=4,
|
||||
... num_return_sequences=2,
|
||||
... )
|
||||
>>> summaries_diverse = tokenizer.batch_decode(outputs_diverse, skip_special_tokens=True)
|
||||
>>> # Generate non-diverse summary
|
||||
>>> outputs_non_diverse = model.generate(
|
||||
... **inputs,
|
||||
... max_length=100,
|
||||
... num_beams=4,
|
||||
... num_return_sequences=2,
|
||||
... )
|
||||
>>> summary_non_diverse = tokenizer.batch_decode(outputs_non_diverse, skip_special_tokens=True)
|
||||
>>> # With `diversity_penalty`, the resulting beams are much more diverse
|
||||
>>> print(summary_non_diverse)
|
||||
['the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.',
|
||||
'the Solar System formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.']
|
||||
>>> print(summaries_diverse)
|
||||
['the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.',
|
||||
'the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets. the rest of the objects are smaller objects, such as the five dwarf planets and small solar system bodies.']
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
|
||||
logger.warning_once(
|
||||
"`HammingDiversityLogitsProcessor` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
|
||||
)
|
||||
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
|
||||
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
|
||||
self._diversity_penalty = diversity_penalty
|
||||
if not isinstance(num_beams, int) or num_beams < 2:
|
||||
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
|
||||
self._num_beams = num_beams
|
||||
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
|
||||
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
|
||||
self._num_sub_beams = num_beams // num_beam_groups
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
current_tokens: torch.LongTensor,
|
||||
beam_group_idx: int,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
||||
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
|
||||
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
|
||||
beam search or log softmax for each vocabulary token when using beam search
|
||||
current_tokens (`torch.LongTensor` of shape `(batch_size)`):
|
||||
Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
|
||||
beam groups in the current generation step.
|
||||
beam_group_idx (`int`):
|
||||
The index of the beam group currently being processed.
|
||||
Return:
|
||||
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
|
||||
The processed prediction scores.
|
||||
"""
|
||||
# hamming diversity: penalise using same token in current group which was used in previous groups at
|
||||
# the same time step
|
||||
batch_size = current_tokens.shape[0] // self._num_beams
|
||||
group_start_idx = beam_group_idx * self._num_sub_beams
|
||||
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
|
||||
group_size = group_end_idx - group_start_idx
|
||||
vocab_size = scores.shape[-1]
|
||||
|
||||
if group_start_idx == 0:
|
||||
return scores
|
||||
|
||||
scores_processed = scores.clone()
|
||||
for batch_idx in range(batch_size):
|
||||
# predicted tokens of last time step of previous groups
|
||||
previous_group_tokens = current_tokens[
|
||||
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
|
||||
]
|
||||
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
|
||||
scores_processed[batch_idx * group_size : (batch_idx + 1) * group_size] -= (
|
||||
self._diversity_penalty * token_frequency
|
||||
)
|
||||
|
||||
return scores_processed
|
||||
|
||||
|
||||
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder
|
||||
|
@ -33,7 +33,6 @@ from ..cache_utils import (
|
||||
QuantizedCache,
|
||||
StaticCache,
|
||||
)
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..dynamic_module_utils import (
|
||||
check_python_requirements,
|
||||
get_cached_module_file,
|
||||
@ -53,8 +52,6 @@ from ..utils import (
|
||||
is_torchdynamo_exporting,
|
||||
logging,
|
||||
)
|
||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import ConstrainedBeamSearchScorer
|
||||
from .candidate_generator import (
|
||||
AssistantVocabTranslatorCache,
|
||||
AssistedCandidateGenerator,
|
||||
@ -370,7 +367,6 @@ class GenerationMixin(ContinuousMixin):
|
||||
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
|
||||
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
|
||||
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
|
||||
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
|
||||
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
|
||||
|
||||
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
||||
@ -2127,6 +2123,39 @@ class GenerationMixin(ContinuousMixin):
|
||||
|
||||
return can_compile
|
||||
|
||||
def _get_deprecated_gen_repo(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
trust_remote_code: bool,
|
||||
custom_generate: Optional[str] = None,
|
||||
assistant_model: Optional["PreTrainedModel"] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Returns the Hub repo for a deprecated generation strategy, if any.
|
||||
"""
|
||||
generation_mode = generation_config.get_generation_mode(assistant_model)
|
||||
moved_to_hub_modes = {
|
||||
GenerationMode.DOLA_GENERATION: "transformers-community/dola",
|
||||
GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
|
||||
GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
|
||||
GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
|
||||
}
|
||||
if custom_generate is not None or generation_mode not in moved_to_hub_modes:
|
||||
return None
|
||||
|
||||
repo = moved_to_hub_modes[generation_mode]
|
||||
logger.warning_once(
|
||||
f"{generation_mode.name.replace('_', ' ').title()} was moved to a `custom_generate` repo: https://hf.co/{repo}. "
|
||||
f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` "
|
||||
"to your `generate` call before v4.62.0."
|
||||
)
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"{generation_mode.name.replace('_', ' ').title()} requires `trust_remote_code=True` in your `generate` call, "
|
||||
f"since it loads https://hf.co/{repo}."
|
||||
)
|
||||
return repo
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
@ -2243,6 +2272,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
"""
|
||||
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
|
||||
if custom_generate is not None and isinstance(custom_generate, str):
|
||||
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
|
||||
# they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
|
||||
@ -2272,6 +2302,33 @@ class GenerationMixin(ContinuousMixin):
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
|
||||
|
||||
# Deprecation-related step: set Hub repo for deprecated strategies.
|
||||
# NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
|
||||
# It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
|
||||
# TODO joao, manuel: remove this in v4.62.0
|
||||
if deprecate_mode_repo := self._get_deprecated_gen_repo(
|
||||
generation_config, trust_remote_code, custom_generate, assistant_model
|
||||
):
|
||||
return GenerationMixin.generate(
|
||||
self,
|
||||
inputs,
|
||||
generation_config,
|
||||
logits_processor,
|
||||
stopping_criteria,
|
||||
prefix_allowed_tokens_fn,
|
||||
synced_gpus,
|
||||
assistant_model,
|
||||
streamer,
|
||||
negative_prompt_ids,
|
||||
negative_prompt_attention_mask,
|
||||
use_model_defaults,
|
||||
custom_generate=deprecate_mode_repo,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer=tokenizer,
|
||||
assistant_tokenizer=assistant_tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
if synced_gpus is None:
|
||||
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
|
||||
@ -2482,47 +2539,6 @@ class GenerationMixin(ContinuousMixin):
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
)
|
||||
# TODO joao, manuel: remove this in v4.62.0
|
||||
elif generation_mode == GenerationMode.DOLA_GENERATION:
|
||||
logger.warning_once(
|
||||
"DoLa generation was moved to a `custom_generate` repo: https://hf.co/transformers-community/dola. "
|
||||
"To prevent loss of backward compatibility, add `custom_generate='transformers-community/dola'` "
|
||||
"to your `generate` call before v4.62.0."
|
||||
)
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
|
||||
"DoLa generation requires `trust_remote_code=True` in your `generate` call, since "
|
||||
"it loads https://hf.co/transformers-community/dola."
|
||||
)
|
||||
return GenerationMixin.generate(
|
||||
self,
|
||||
inputs,
|
||||
custom_generate="transformers-community/dola",
|
||||
generation_config=generation_config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
# TODO joao, manuel: remove this in v4.62.0
|
||||
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
|
||||
logger.warning_once(
|
||||
"Contrastive search was moved to a `custom_generate` repo: https://hf.co/transformers-community/contrastive-search. "
|
||||
"To prevent loss of backward compatibility, add `custom_generate='transformers-community/contrastive-search'` "
|
||||
"to your `generate` call before v4.62.0."
|
||||
)
|
||||
if not trust_remote_code:
|
||||
logger.warning_once(
|
||||
"Contrastive search requires `trust_remote_code=True` in your `generate` call, since "
|
||||
"it loads https://hf.co/transformers-community/contrastive-search."
|
||||
)
|
||||
# Avoid calling the model-defined `generate` method, since some models (e.g. Janus, Whisper) override it.
|
||||
return GenerationMixin.generate(
|
||||
self,
|
||||
inputs,
|
||||
custom_generate="transformers-community/contrastive-search",
|
||||
generation_config=generation_config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||
# 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
|
||||
@ -2547,93 +2563,6 @@ class GenerationMixin(ContinuousMixin):
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
|
||||
logger.warning_once(
|
||||
"Group Beam Search was moved to a `custom_generate` repo: https://hf.co/transformers-community/group-beam-search. "
|
||||
"To prevent loss of backward compatibility, add `custom_generate='transformers-community/group-beam-search'` "
|
||||
"to your `generate` call before v4.62.0."
|
||||
)
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
|
||||
"Group Beam Search requires `trust_remote_code=True` in your `generate` call, since "
|
||||
"it loads https://hf.co/transformers-community/group-beam-search."
|
||||
)
|
||||
return GenerationMixin.generate(
|
||||
self,
|
||||
inputs,
|
||||
custom_generate="transformers-community/group-beam-search",
|
||||
generation_config=generation_config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
|
||||
logger.warning_once(
|
||||
"Constrained Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
|
||||
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
|
||||
)
|
||||
final_constraints = []
|
||||
if generation_config.constraints is not None:
|
||||
final_constraints = generation_config.constraints
|
||||
|
||||
if generation_config.force_words_ids is not None:
|
||||
|
||||
def typeerror():
|
||||
raise ValueError(
|
||||
"`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
|
||||
f"of positive integers, but is {generation_config.force_words_ids}."
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(generation_config.force_words_ids, list)
|
||||
or len(generation_config.force_words_ids) == 0
|
||||
):
|
||||
typeerror()
|
||||
|
||||
for word_ids in generation_config.force_words_ids:
|
||||
if isinstance(word_ids[0], list):
|
||||
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
||||
typeerror()
|
||||
if any(not isinstance(token_ids, list) for token_ids in word_ids):
|
||||
typeerror()
|
||||
if any(
|
||||
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
||||
for token_ids in word_ids
|
||||
):
|
||||
typeerror()
|
||||
|
||||
constraint = DisjunctiveConstraint(word_ids)
|
||||
else:
|
||||
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
||||
typeerror()
|
||||
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
|
||||
typeerror()
|
||||
|
||||
constraint = PhrasalConstraint(word_ids)
|
||||
final_constraints.append(constraint)
|
||||
|
||||
# 11. prepare beam search scorer
|
||||
constrained_beam_scorer = ConstrainedBeamSearchScorer(
|
||||
constraints=final_constraints,
|
||||
batch_size=batch_size,
|
||||
num_beams=generation_config.num_beams,
|
||||
device=inputs_tensor.device,
|
||||
length_penalty=generation_config.length_penalty,
|
||||
do_early_stopping=generation_config.early_stopping,
|
||||
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
||||
max_length=generation_config.max_length,
|
||||
)
|
||||
# 12. run beam search
|
||||
result = self._constrained_beam_search(
|
||||
input_ids,
|
||||
constrained_beam_scorer=constrained_beam_scorer,
|
||||
logits_processor=prepared_logits_processor,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
generation_config=generation_config,
|
||||
synced_gpus=synced_gpus,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Convert to legacy cache format if requested
|
||||
if (
|
||||
generation_config.return_legacy_cache is True
|
||||
@ -3511,246 +3440,6 @@ class GenerationMixin(ContinuousMixin):
|
||||
else:
|
||||
return sequences
|
||||
|
||||
def _constrained_beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
constrained_beam_scorer: ConstrainedBeamSearchScorer,
|
||||
logits_processor: LogitsProcessorList,
|
||||
stopping_criteria: StoppingCriteriaList,
|
||||
generation_config: GenerationConfig,
|
||||
synced_gpus: bool,
|
||||
**model_kwargs,
|
||||
) -> Union[GenerateBeamOutput, torch.LongTensor]:
|
||||
r"""
|
||||
Generates sequences of token ids for models with a language modeling head using **constrained beam search
|
||||
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
|
||||
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
||||
sorted during generation, while satisfying a list of positive constraints. For more information, the
|
||||
documentation of [`ConstrainedBeamSearchScorer`] should be read.
|
||||
logits_processor (`LogitsProcessorList`):
|
||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||
stopping_criteria (`StoppingCriteriaList`):
|
||||
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
||||
used to tell if the generation loop should stop.
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
|
||||
Return:
|
||||
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
|
||||
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
||||
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
||||
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
||||
`model.config.is_encoder_decoder=True`.
|
||||
"""
|
||||
# init values
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
eos_token_id = generation_config._eos_token_tensor
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
output_logits = generation_config.output_logits
|
||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||
|
||||
batch_size = len(constrained_beam_scorer._beam_hyps)
|
||||
num_beams = constrained_beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape[:2]
|
||||
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
||||
beam_indices = (
|
||||
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
if return_dict_in_generate and self.config.is_encoder_decoder:
|
||||
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||
encoder_hidden_states = (
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
|
||||
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False
|
||||
|
||||
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
# prepare variable output controls (note: some models won't accept all output controls)
|
||||
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
||||
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
||||
|
||||
outputs = self(**model_inputs, return_dict=True)
|
||||
|
||||
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs,
|
||||
model_kwargs,
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
)
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
continue
|
||||
|
||||
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
||||
# (the clone itself is always small)
|
||||
# .float() is needed to retain precision for later logits manipulations
|
||||
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
||||
next_token_scores = nn.functional.log_softmax(
|
||||
next_token_logits, dim=-1
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
||||
next_token_scores_processed
|
||||
)
|
||||
|
||||
scores_for_all_vocab = next_token_scores.clone()
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
if output_logits:
|
||||
raw_logits += (next_token_logits,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
(outputs.decoder_hidden_states,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.hidden_states,)
|
||||
)
|
||||
|
||||
# reshape for beam search
|
||||
vocab_size = next_token_scores.shape[-1]
|
||||
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||
|
||||
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
||||
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
||||
next_token_scores, next_tokens = torch.topk(
|
||||
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = (next_tokens / vocab_size).long()
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
beam_outputs = constrained_beam_scorer.process(
|
||||
input_ids,
|
||||
next_token_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
scores_for_all_vocab,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
beam_indices=beam_indices,
|
||||
decoder_prompt_len=decoder_prompt_len,
|
||||
)
|
||||
beam_scores = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
||||
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
||||
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
|
||||
# (that way the memory peak does not include outputs.logits)
|
||||
del outputs
|
||||
|
||||
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
|
||||
if model_kwargs.get("past_key_values", None) is not None:
|
||||
if hasattr(self, "_reorder_cache"):
|
||||
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
||||
else:
|
||||
model_kwargs["past_key_values"].reorder_cache(beam_idx)
|
||||
|
||||
if return_dict_in_generate and output_scores:
|
||||
beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = constrained_beam_scorer.finalize(
|
||||
input_ids,
|
||||
beam_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
beam_indices=beam_indices,
|
||||
decoder_prompt_len=decoder_prompt_len,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
if self.config.is_encoder_decoder:
|
||||
return GenerateBeamEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
logits=raw_logits,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
cross_attentions=cross_attentions,
|
||||
decoder_hidden_states=decoder_hidden_states,
|
||||
past_key_values=model_kwargs.get("past_key_values"),
|
||||
)
|
||||
else:
|
||||
return GenerateBeamDecoderOnlyOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
logits=raw_logits,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
past_key_values=model_kwargs.get("past_key_values"),
|
||||
)
|
||||
else:
|
||||
return sequence_outputs["sequences"]
|
||||
|
||||
def _assisted_decoding(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@ -4142,52 +3831,3 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at
|
||||
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
|
||||
outputs += (new_tuple,)
|
||||
return outputs
|
||||
|
||||
|
||||
def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
|
||||
"""
|
||||
Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
|
||||
specific ModelOutput subclass from the list provided.
|
||||
"""
|
||||
if not model_outputs:
|
||||
raise ValueError("Input list is empty.")
|
||||
|
||||
# Infer the class from the first object in the list
|
||||
model_output_cls = type(model_outputs[0])
|
||||
|
||||
# Ensure all objects are of the same type
|
||||
if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
|
||||
raise ValueError("All elements in the list should be of the same type.")
|
||||
|
||||
# Helper function to concat tensors or tuples of tensors
|
||||
def _concat(data):
|
||||
"""
|
||||
Reverse of `_split` function above.
|
||||
"""
|
||||
if any(data is None for data in data):
|
||||
return None
|
||||
if isinstance(data[0], torch.Tensor):
|
||||
return torch.cat(data, dim=0)
|
||||
elif isinstance(data[0], tuple):
|
||||
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
|
||||
if isinstance(data[0][0], tuple):
|
||||
return tuple(
|
||||
tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
|
||||
for i in range(len(data[0]))
|
||||
)
|
||||
else:
|
||||
return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
|
||||
elif isinstance(data[0], (int, float)):
|
||||
# If the elements are integers or floats, return a tensor
|
||||
return torch.tensor(data)
|
||||
else:
|
||||
raise TypeError(f"Unexpected attribute type: {type(data[0])}")
|
||||
|
||||
# Use a dictionary comprehension to gather attributes from all objects and concatenate them
|
||||
concatenated_data = {
|
||||
k: _concat([getattr(model_output, k) for model_output in model_outputs])
|
||||
for k in model_output_cls.__dataclass_fields__
|
||||
}
|
||||
|
||||
# Return a new object of the inferred class with the concatenated attributes
|
||||
return model_output_cls(**concatenated_data)
|
||||
|
@ -1,340 +0,0 @@
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import (
|
||||
BeamHypotheses,
|
||||
ConstrainedBeamSearchScorer,
|
||||
DisjunctiveConstraint,
|
||||
PhrasalConstraint,
|
||||
)
|
||||
|
||||
|
||||
class ConstrainedBeamSearchTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
constraints=None,
|
||||
batch_size=3,
|
||||
sequence_length=10,
|
||||
vocab_size=99,
|
||||
pad_token_id=0,
|
||||
max_length=20,
|
||||
num_beams=4,
|
||||
length_penalty=2.0,
|
||||
do_early_stopping=True,
|
||||
num_beam_hyps_to_keep=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.vocab_size = vocab_size
|
||||
self.pad_token_id = pad_token_id
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
|
||||
if constraints is None:
|
||||
force_tokens = torch.randint(10, 50, (1, 2))[0].tolist()
|
||||
disjunctive_tokens = torch.randint(10, 50, (2, 2)).tolist()
|
||||
|
||||
constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)]
|
||||
self.constraints = constraints
|
||||
# cannot be randomly generated
|
||||
self.eos_token_id = vocab_size + 1
|
||||
|
||||
def prepare_constrained_beam_scorer(self, **kwargs):
|
||||
return ConstrainedBeamSearchScorer(
|
||||
constraints=kwargs.get("constraints", self.constraints),
|
||||
batch_size=kwargs.get("batch_size", self.batch_size),
|
||||
num_beams=kwargs.get("num_beams", self.num_beams),
|
||||
device=torch_device,
|
||||
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
||||
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
|
||||
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
|
||||
)
|
||||
|
||||
def prepare_inputs(self):
|
||||
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
|
||||
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
|
||||
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
|
||||
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
|
||||
scores_for_all_vocab, _ = (
|
||||
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device)
|
||||
).sort(descending=True)
|
||||
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)
|
||||
|
||||
def check_beam_hypotheses(self, input_ids, *args):
|
||||
# check that correct number of beam hypotheses is set in beam scorer
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
|
||||
beam_hyp = constrained_beam_scorer._beam_hyps[0]
|
||||
|
||||
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
|
||||
|
||||
# check correct type
|
||||
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
|
||||
|
||||
# check that num_beams is correctly set
|
||||
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
|
||||
|
||||
# check for early stopping deactivated
|
||||
for beam_idx in range(self.num_beams):
|
||||
beam_hyp.add(input_ids[beam_idx], -10.0)
|
||||
|
||||
# if early stopping True -> score does not matter
|
||||
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
|
||||
|
||||
# re-init
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
|
||||
beam_hyp = constrained_beam_scorer._beam_hyps[0]
|
||||
|
||||
# add `num_beams + 1` beams to change `worst_score`
|
||||
for beam_idx in range(self.num_beams + 1):
|
||||
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
|
||||
|
||||
# -10.0 is removed => -9.0 is worst score
|
||||
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty))
|
||||
|
||||
# -5.0 is better than worst score => should not be finished
|
||||
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
|
||||
|
||||
# -20.0 is worse than worst score => should be finished
|
||||
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
|
||||
|
||||
def check_constrained_beam_scorer_update(
|
||||
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
|
||||
):
|
||||
# check too many eos tokens
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||
stacked_token_ids = []
|
||||
for constraint in self.constraints:
|
||||
token_ids = constraint.token_ids
|
||||
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
|
||||
stacked_token_ids = stacked_token_ids + token_ids
|
||||
|
||||
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
|
||||
fulfill_len = fulfilling_sequence.size(0)
|
||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[0, :] = self.eos_token_id
|
||||
|
||||
with self.parent.assertRaises(ValueError):
|
||||
constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
|
||||
# check all batches are done
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, : self.num_beams] = self.eos_token_id
|
||||
constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
# beam scorer should be done
|
||||
self.parent.assertTrue(constrained_beam_scorer.is_done)
|
||||
|
||||
# check
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, 1] = self.eos_token_id
|
||||
beam_outputs = constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
|
||||
def cut_expected_tensor(tensor):
|
||||
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
|
||||
|
||||
# check all outptus
|
||||
# cut out id of eos token and take best `num_beams` outputs
|
||||
expected_output_tokens = cut_expected_tensor(tokens)
|
||||
expected_output_scores = cut_expected_tensor(next_scores)
|
||||
|
||||
# add num_beams * batch_idx
|
||||
offset = torch.div(
|
||||
torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams, rounding_mode="floor"
|
||||
)
|
||||
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
|
||||
|
||||
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
|
||||
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
|
||||
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
||||
|
||||
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||
)
|
||||
|
||||
def check_constrained_beam_scorer_finalize(
|
||||
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
|
||||
):
|
||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||
max_length = self.sequence_length + 1
|
||||
|
||||
# for testing finalize, we do want to have fulfilled constraints
|
||||
stacked_token_ids = []
|
||||
for constraint in self.constraints:
|
||||
token_ids = constraint.token_ids
|
||||
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
|
||||
stacked_token_ids = stacked_token_ids + token_ids
|
||||
|
||||
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
|
||||
|
||||
fulfill_len = fulfilling_sequence.size(0)
|
||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
|
||||
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
|
||||
constraints = constrained_beam_scorer.constraints
|
||||
# update beams and append to input_ids
|
||||
tokens = next_tokens.clone()
|
||||
# first batch, first output has to finish with eos token id since scores are correctly sorted
|
||||
tokens[0, 0] = self.eos_token_id
|
||||
# make sure corresponding score is as good as possible to surely be picked first
|
||||
next_scores[0, 0] = 0.0
|
||||
|
||||
beam_outputs = constrained_beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# finalize
|
||||
sequence_output = constrained_beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
|
||||
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
|
||||
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
|
||||
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
|
||||
|
||||
# check sequence_scores
|
||||
self.parent.assertFalse((sequence_scores > 0).any().item())
|
||||
|
||||
# first batch has to finish with eos_token
|
||||
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
|
||||
|
||||
# other batches cannot finish with eos token
|
||||
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
|
||||
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
|
||||
|
||||
# test that the constraint is indeed fulfilled
|
||||
for output, constraint in [(s, c) for s in sequences for c in constraints]:
|
||||
forced_token_ids = constraint.token_ids
|
||||
if isinstance(forced_token_ids[0], list):
|
||||
# disjunctive case
|
||||
flag = False
|
||||
for token_ids in forced_token_ids:
|
||||
if self._check_sequence_inside_sequence(output, token_ids):
|
||||
flag = True
|
||||
break
|
||||
self.parent.assertEqual(flag, True)
|
||||
else:
|
||||
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
|
||||
|
||||
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
|
||||
|
||||
# constrained_beam_scorer.num_beam_hyps_to_keep = self.num_beams
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
|
||||
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
|
||||
sequence_output = constrained_beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
|
||||
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
|
||||
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||
# set to same device. we don't care what device.
|
||||
|
||||
if not isinstance(tensor_1, list):
|
||||
tensor_1 = tensor_1.tolist()
|
||||
if not isinstance(tensor_2, list):
|
||||
tensor_2 = tensor_2.tolist()
|
||||
|
||||
in_order = len(tensor_1) <= len(tensor_2)
|
||||
longer = tensor_2 if in_order else tensor_1
|
||||
shorter = tensor_1 if in_order else tensor_2
|
||||
|
||||
flag = False
|
||||
chunk_size = len(shorter)
|
||||
for chunk_idx in range(len(longer) - chunk_size + 1):
|
||||
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
||||
if subseq == shorter:
|
||||
flag = True
|
||||
break
|
||||
|
||||
return flag
|
||||
|
||||
|
||||
@require_torch
|
||||
class ConstrainedBeamSearchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self)
|
||||
|
||||
def test_constrained_beam_hypotheses(self):
|
||||
inputs = self.constrained_beam_search_tester.prepare_inputs()
|
||||
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs)
|
||||
|
||||
def test_constrained_beam_scorer_update(self):
|
||||
inputs = self.constrained_beam_search_tester.prepare_inputs()
|
||||
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs)
|
||||
|
||||
def test_constrained_beam_scorer_finalize(self):
|
||||
inputs = self.constrained_beam_search_tester.prepare_inputs()
|
||||
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs)
|
@ -189,14 +189,9 @@ class GenerationConfigTest(unittest.TestCase):
|
||||
generation_config_bad_temperature.update(temperature=None)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Impossible sets of constraints/parameters will raise an exception
|
||||
# Impossible sets of parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
with self.assertRaises(ValueError):
|
||||
# dummy constraint
|
||||
GenerationConfig(do_sample=True, num_beams=2, constraints=["dummy"])
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(do_sample=True, num_beams=2, force_words_ids=[[[1, 2, 3]]])
|
||||
|
||||
# Passing `generate()`-only flags to `validate` will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
|
@ -40,7 +40,6 @@ from transformers import (
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_flash_attn,
|
||||
require_flash_attn_3,
|
||||
@ -89,7 +88,6 @@ if is_torch_available():
|
||||
BeamSearchDecoderOnlyOutput,
|
||||
BeamSearchEncoderDecoderOutput,
|
||||
CompileConfig,
|
||||
DisjunctiveConstraint,
|
||||
GenerateBeamDecoderOnlyOutput,
|
||||
GenerateBeamEncoderDecoderOutput,
|
||||
GenerateDecoderOnlyOutput,
|
||||
@ -101,7 +99,6 @@ if is_torch_available():
|
||||
LogitsProcessorList,
|
||||
MaxLengthCriteria,
|
||||
MinLengthLogitsProcessor,
|
||||
PhrasalConstraint,
|
||||
PromptLookupCandidateGenerator,
|
||||
SampleDecoderOnlyOutput,
|
||||
SampleEncoderDecoderOutput,
|
||||
@ -209,15 +206,6 @@ class GenerationTesterMixin:
|
||||
}
|
||||
return beam_kwargs
|
||||
|
||||
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
"num_beams": num_return_sequences * 4,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
}
|
||||
return beam_kwargs
|
||||
|
||||
def _greedy_generate(
|
||||
self,
|
||||
model,
|
||||
@ -340,38 +328,6 @@ class GenerationTesterMixin:
|
||||
|
||||
return output_generate
|
||||
|
||||
def _constrained_beam_search_generate(
|
||||
self,
|
||||
model,
|
||||
inputs_dict,
|
||||
constraints,
|
||||
beam_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
output_generate = model.generate(
|
||||
do_sample=False,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
min_new_tokens=self.max_new_tokens,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
constraints=constraints,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@ -706,115 +662,6 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@is_flaky() # Some models have position-specific tokens, this test may try to force them in an invalid position
|
||||
@pytest.mark.generate
|
||||
def test_constrained_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# Sample constraints
|
||||
min_id = 3
|
||||
max_id = config.get_text_config(decoder=True).vocab_size
|
||||
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
beam_kwargs = self._get_constrained_beam_kwargs()
|
||||
output_generate = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
inputs_dict=inputs_dict,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
|
||||
# check`constrained_beam_search` for higher than 1 `num_return_sequences`
|
||||
# Sample constraints
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
|
||||
|
||||
output_generate = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
inputs_dict=inputs_dict,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
|
||||
@is_flaky() # Some models have position-specific tokens, this test may try to force them in an invalid position
|
||||
@pytest.mark.generate
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# Sample constraints
|
||||
min_id = 3
|
||||
max_id = model.config.get_text_config(decoder=True).vocab_size
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
beam_kwargs = self._get_constrained_beam_kwargs()
|
||||
output_generate = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
inputs_dict=inputs_dict,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_generate_outputs(
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
)
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
@ -2881,120 +2728,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(outputs, ["Wie alt bist du?"])
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
PhrasalConstraint(force_tokens_2),
|
||||
]
|
||||
|
||||
starting_text = ["The soldiers were not prepared and"]
|
||||
|
||||
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
constraints=constraints,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
max_length=30,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The soldiers were not prepared and didn't know what to do. They had no idea how they would react if"
|
||||
" the enemy attacked them, big weapons scared"
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
flexible_phrases = tokenizer(
|
||||
["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False
|
||||
).input_ids
|
||||
|
||||
constraints = [
|
||||
PhrasalConstraint(force_phrase),
|
||||
DisjunctiveConstraint(flexible_phrases),
|
||||
]
|
||||
|
||||
starting_text = ["The soldiers", "The child"]
|
||||
|
||||
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
constraints=constraints,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
# max_length=20,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The soldiers, who had been stationed at the base for more than a year before being evacuated"
|
||||
" screaming scared",
|
||||
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
force_word = "scared"
|
||||
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
||||
|
||||
force_words_ids = [
|
||||
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
|
||||
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
|
||||
]
|
||||
|
||||
starting_text = ["The soldiers", "The child"]
|
||||
|
||||
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The soldiers, who had been stationed at the base for more than a year before being evacuated"
|
||||
" screaming scared",
|
||||
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_cfg_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
@ -3035,30 +2768,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_translation_mixin(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
|
||||
|
||||
encoder_input_str = "translate English to German: How old are you?"
|
||||
force_words = ["sind"]
|
||||
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(outputs, ["Wie alt sind Sie?"])
|
||||
|
||||
# TODO joao, manuel: remove in v4.62.0
|
||||
@slow
|
||||
def test_constrained_beam_search_example_integration(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
|
||||
@ -3085,6 +2795,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
force_words_ids=[constraint_token_ids],
|
||||
min_length=5,
|
||||
eos_token_id=model.config.eos_token_id,
|
||||
trust_remote_code=True,
|
||||
custom_generate="transformers-community/constrained-beam-search",
|
||||
**model_kwargs,
|
||||
)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
@ -3128,46 +2840,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
self.assertListEqual(out_text, expected_out)
|
||||
|
||||
def test_constrained_beam_search_mixin_type_checks(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
|
||||
|
||||
encoder_input_str = "translate English to German: How old are you?"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
force_words = ["sind"]
|
||||
force_words_ids = tokenizer(force_words, return_tensors="pt").input_ids
|
||||
model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
force_words = ["sind"]
|
||||
force_words_ids = [tokenizer(force_words, return_tensors="pt").input_ids]
|
||||
model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[[-1]])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||
|
||||
def test_batched_decoder_start_id(self):
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
@ -4742,6 +4414,10 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
"length_penalty": 2.0,
|
||||
},
|
||||
),
|
||||
(
|
||||
"transformers-community/constrained-beam-search",
|
||||
{"do_sample": False, "num_beams": 2, "force_words_ids": [[167, 168, 169]]},
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_hub_gen_strategies(self, custom_generate, extra_kwargs):
|
||||
|
@ -272,16 +272,6 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support constrained beam search.")
|
||||
def test_constrained_beam_search_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support constrained beam search.")
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
|
@ -237,7 +237,6 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
skippable_tests = [
|
||||
"test_sample_generate_dict_output", # return sequences > 1
|
||||
"test_beam",
|
||||
"test_constrained_beam",
|
||||
"test_contrastive",
|
||||
"test_assisted",
|
||||
"test_prompt_lookup",
|
||||
|
@ -430,10 +430,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot do constraint generation, has custom `generate()`")
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot generate from inputs embeds")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
@ -128,21 +128,11 @@ class RecurrentGemmaModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
|
||||
@pytest.mark.generate
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
|
||||
@pytest.mark.generate
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
|
||||
@pytest.mark.generate
|
||||
def test_constrained_beam_search_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
|
@ -387,13 +387,6 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
super().test_beam_search_generate_dict_output()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
self.has_attentions = False
|
||||
super().test_constrained_beam_search_generate_dict_output()
|
||||
self.has_attentions = old_has_attentions
|
||||
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
# This model has a custom attention output shape AND config flags, let's skip those checks
|
||||
old_has_attentions = self.has_attentions
|
||||
|
@ -404,12 +404,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||
return beam_kwargs
|
||||
|
||||
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
|
||||
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
|
||||
beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
|
||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||
return beam_kwargs
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = WhisperModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
||||
|
Reference in New Issue
Block a user