🚨 Remove Constrained Beam Search decoding strategy (#40518)

* Squashed remove-constrastive-search

* sweep ready for tests

* testing...

* whoops

* ops

* tests fix

* tests green, changed handling of deprecated methods

* tests gone after green

* restore and deprecate beam objects

* restore and deprecate constraint objects

* fix ci

* review
Manuel de Prada Corral
2025-09-01 14:34:48 +02:00
committed by GitHub
parent 564be6d895
commit 8564e210ca
19 changed files with 516 additions and 1234 deletions
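A minimal migration sketch, assuming the Hub recipe keeps accepting the same `force_words_ids` argument as before; the `gpt2` checkpoint and the prompt are placeholders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The full name of the river is", return_tensors="pt")
# Simple `list[list[int]]` form: the phrase below must appear in the output.
force_words_ids = [tokenizer("Mississippi", add_special_tokens=False).input_ids]

# Setting `force_words_ids` used to select constrained beam search automatically;
# the strategy now lives in a Hub `custom_generate` repo and must be trusted explicitly.
outputs = model.generate(
    **inputs,
    num_beams=4,
    force_words_ids=force_words_ids,
    custom_generate="transformers-community/constrained-beam-search",
    trust_remote_code=True,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

Without an explicit `custom_generate`, `generate()` redirects to the same repo as long as `trust_remote_code=True` is passed (see the deprecation shim further down in this diff).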

View File

@@ -39,7 +39,6 @@
| [How to fine-tune a model on summarization](https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb)| Shows how to preprocess the data and fine-tune a pretrained model on XSUM. | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)|
| [How to train a language model from scratch](https://github.com/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| Highlights all the steps to effectively train a Transformer model on custom data | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)|
| [How to generate text](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)|
| [How to generate text (with constraints)](https://github.com/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| How to guide language generation with user-provided constraints | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)|
| [Reformer](https://github.com/huggingface/blog/blob/main/notebooks/03_reformer.ipynb)| How Reformer pushes the limits of language modeling | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)|
#### Computer Vision[[pytorch-cv]]

View File

@@ -300,28 +300,6 @@ generation_output[:2]
[[autodoc]] MaxTimeCriteria
- __call__
## Constraints
[`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Note that this is only available in the PyTorch implementation.
[[autodoc]] Constraint
[[autodoc]] PhrasalConstraint
[[autodoc]] DisjunctiveConstraint
[[autodoc]] ConstraintListState
## BeamSearch
[[autodoc]] BeamScorer
- process
- finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize
## Streamers
[[autodoc]] TextStreamer

View File

@@ -305,28 +305,6 @@ generation_output[:2]
[[autodoc]] EosTokenCriteria
- __call__
## Constraint [[transformers.Constraint]]
[`Constraint`] is used to force specific tokens or sequences to be included in the generation output. This feature is only available in the PyTorch implementation.
[[autodoc]] Constraint
[[autodoc]] PhrasalConstraint
[[autodoc]] DisjunctiveConstraint
[[autodoc]] ConstraintListState
## Beam Search (BeamSearch) [[transformers.BeamScorer]]
[[autodoc]] BeamScorer
- process
- finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize
## Streamers [[transformers.TextStreamer]]
[[autodoc]] TextStreamer

View File

@@ -295,28 +295,6 @@ generation_output[:2]
[[autodoc]] MaxTimeCriteria
- __call__
## Constraints
[`Constraint`] can be used to force the generated result to contain specific tokens or sequences in the output. Note that this only applies to our PyTorch implementation.
[[autodoc]] Constraint
[[autodoc]] PhrasalConstraint
[[autodoc]] DisjunctiveConstraint
[[autodoc]] ConstraintListState
## BeamSearch
[[autodoc]] BeamScorer
- process
- finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize
## Streamers
[[autodoc]] TextStreamer

View File

@@ -56,7 +56,6 @@ You can open any page of the documentation as a notebook in Colab (there is a bu
| [How to fine-tune a model on summarization](https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb)| Show how to preprocess the data and fine-tune a pretrained model on XSUM. | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)|
| [How to train a language model from scratch](https://github.com/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train Transformer model on custom data | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)|
| [How to generate text](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)|
| [How to generate text (with constraints)](https://github.com/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| How to guide language generation with user-provided constraints | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)|
| [Reformer](https://github.com/huggingface/blog/blob/main/notebooks/03_reformer.ipynb)| How Reformer pushes the limits of language modeling | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)|
#### Computer Vision[[pytorch-cv]]

View File

@@ -1,7 +1,13 @@
from abc import ABC, abstractmethod
from typing import Optional
from ..utils import logging
logger = logging.get_logger(__name__)
# TODO joao, manuel: remove in v4.58.0
class Constraint(ABC):
r"""Abstract base class for all constraints that can be applied during generation.
It must define how the constraint can be satisfied.
@@ -18,6 +24,9 @@ class Constraint(ABC):
"""
def __init__(self):
logger.warning_once(
"Importing `Constraint` classes is deprecated and will be removed in v4.58.0. Constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search. Please import using `from transformers.generation import Constraint` instead."
)
# test for the above condition
self.test()
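A minimal sketch of the restored-but-deprecated constraint objects (the token ids below are placeholders); they remain importable for now, with the Hub recipe referenced in the warning above as the forward path:

```python
from transformers.generation import DisjunctiveConstraint, PhrasalConstraint

# This exact token sequence must appear in the generated output.
must_include = PhrasalConstraint([318, 1049])

# Any one of the listed token sequences satisfies the constraint.
any_of_these = DisjunctiveConstraint([[318], [1049]])
```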

View File

@@ -20,10 +20,13 @@ from typing import Optional, Union
import numpy as np
import torch
from ..utils import add_start_docstrings
from ..utils import add_start_docstrings, logging
from .beam_constraints import Constraint, ConstraintListState
logger = logging.get_logger(__name__)
PROCESS_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
@@ -118,6 +121,305 @@ class BeamScorer(ABC):
raise NotImplementedError("This is an abstract method.")
class BeamSearchScorer(BeamScorer):
r"""
[`BeamScorer`] implementing standard beam search decoding.
Adapted in part from [Facebook's XLM beam search
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
Args:
batch_size (`int`):
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
num_beams (`int`):
Number of beams for beam search.
device (`torch.device`):
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
allocated.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very unlikely to find better candidates;
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformers.BeamSearchScorer.finalize`].
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://huggingface.co/papers/1610.02424) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def __init__(
self,
batch_size: int,
num_beams: int,
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[Union[bool, str]] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
max_length: Optional[int] = None,
):
logger.warning_once(
"`BeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
)
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups
self._is_init = False
# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
self._beam_hyps = [
BeamHypotheses(
num_beams=self.group_size,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
for _ in range(batch_size * self.num_beam_groups)
]
# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
# in the i-th mini-batch is complete.
self._done = torch.tensor(
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
" one should make use of `greedy_search` instead."
)
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
raise ValueError(
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
@property
def is_done(self) -> bool:
return self._done.all()
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
) -> dict[str, torch.Tensor]:
# add up to the length which the next_scores is calculated on (including decoder prompt)
cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups
if batch_size != (input_ids.shape[0] // self.group_size):
if self.num_beam_groups > 1:
raise ValueError(
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
f"size of {self.group_size} is expected by the beam scorer."
)
else:
raise ValueError(
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
f"{self.group_size} is expected by the beam scorer."
)
device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
for batch_idx in range(batch_size):
batch_group_idx = batch_idx * self.num_beam_groups + group_index
if self._done[batch_group_idx]:
if self.num_beams < len(self._beam_hyps[batch_group_idx]):
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
if eos_token_id is None or pad_token_id is None:
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
# pad the batch
next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id
next_beam_indices[batch_idx, :] = 0
continue
# next tokens for this sentence
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (batch_beam_idx,)
else:
beam_index = None
self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.group_size:
break
if beam_idx < self.group_size:
raise ValueError(
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
)
return UserDict(
{
"next_beam_scores": next_beam_scores.view(-1),
"next_beam_tokens": next_beam_tokens.view(-1),
"next_beam_indices": next_beam_indices.view(-1),
}
)
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_group_idx]:
continue
# all open beam hypotheses are added to the beam hypothesis
# beam hypothesis class automatically keeps the best beams
for index_per_group in range(self.group_size):
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
best_indices = []
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
# retrieve best hypotheses
for i in range(batch_size):
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]
best_hyp = best_hyp_tuple[1]
best_index = best_hyp_tuple[2]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
# append hyp to lists
best.append(best_hyp)
# append indices to list
best_indices.append(best_index)
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
if len(best_indices) > 0 and best_indices[0] is not None:
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
else:
indices = None
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
if pad_token_id is None:
raise ValueError("`pad_token_id` has to be defined")
decoded.fill_(pad_token_id)
if indices is not None:
indices.fill_(-1)
# fill with hypotheses and eos_token_id if the latter fits in
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
decoded[i, : sent_lengths[i]] = hypo
if indices is not None:
indices[i, : len(best_idx)] = torch.tensor(best_idx)
if sent_lengths[i] < sent_max_len:
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]
return UserDict(
{
"sequences": decoded,
"sequence_scores": best_scores,
"beam_indices": indices,
}
)
class ConstrainedBeamSearchScorer(BeamScorer):
r"""
[`BeamScorer`] implementing constrained beam search decoding.
@@ -163,6 +465,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
num_beam_hyps_to_keep: Optional[int] = 1,
max_length: Optional[int] = None,
):
logger.warning_once(
"`ConstrainedBeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
)
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
@@ -611,6 +916,9 @@ class BeamHypotheses:
"""
Initialize n-best list of hypotheses.
"""
logger.warning_once(
"`BeamHypotheses` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
)
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.max_length = max_length
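A minimal sketch of the restored scorer (the CPU device is chosen purely for illustration); constructing it still works but now triggers the `warning_once` calls added above:

```python
import torch
from transformers.generation import BeamSearchScorer

# Emits the deprecation warning on construction; behaviour is otherwise unchanged.
scorer = BeamSearchScorer(batch_size=2, num_beams=4, device=torch.device("cpu"))
print(scorer.is_done)  # tensor(False) until every batch entry has enough finished beams
```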

View File

@@ -89,7 +89,6 @@ class GenerationConfig(PushToHubMixin):
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
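A hedged sketch of how the remaining built-in strategies listed above are selected through these flags (`gpt2` and the prompt are placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Beam search is", return_tensors="pt")

greedy = model.generate(**inputs, num_beams=1, do_sample=False, max_new_tokens=20)
multinomial = model.generate(**inputs, num_beams=1, do_sample=True, max_new_tokens=20)
beam_search = model.generate(**inputs, num_beams=4, do_sample=False, max_new_tokens=20)
beam_sample = model.generate(**inputs, num_beams=4, do_sample=True, max_new_tokens=20)
```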
@@ -202,18 +201,10 @@
bad_words_ids (`list[list[int]]`, *optional*):
List of list of token ids that are not allowed to be generated. Check
[`~generation.NoBadWordsLogitsProcessor`] for further documentation and examples.
force_words_ids (`list[list[int]]` or `list[list[list[int]]]`, *optional*):
List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple list of
words that must be included, the opposite to `bad_words_ids`. If given `list[list[list[int]]]`, this
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word.
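A hedged sketch of the two accepted shapes described just above (tokenizer, words, and checkpoint are placeholders):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint

# `list[list[int]]`: every listed phrase must appear in the output.
force_words_ids = [tokenizer("screen reader", add_special_tokens=False).input_ids]

# `list[list[list[int]]]`: disjunctive constraint, any one surface form per group suffices.
flexible_words_ids = [tokenizer(["screen", "screens", "screening"], add_special_tokens=False).input_ids]
```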
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors (including the custom
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
are normalized but some logit processors break the normalization.
constraints (`list[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use of
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
forced_bos_token_id (`int`, *optional*, defaults to `model.config.forced_bos_token_id`):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
@@ -374,9 +365,7 @@
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.force_words_ids = kwargs.pop("force_words_ids", None)
self.renormalize_logits = kwargs.pop("renormalize_logits", False)
self.constraints = kwargs.pop("constraints", None)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
@@ -434,6 +423,8 @@
self.dola_layers = kwargs.pop("dola_layers", None)
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.constraints = kwargs.pop("constraints", None)
self.force_words_ids = kwargs.pop("force_words_ids", None)
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
# interface.
@@ -625,24 +616,6 @@
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
flag_name="length_penalty", flag_value=self.length_penalty
)
if self.constraints is not None:
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
flag_name="constraints", flag_value=self.constraints
)
# 2.3. detect incorrect parameterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
)
if self.do_sample is True:
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
# 2.4. check `num_return_sequences`
if self.num_return_sequences != 1:

View File

@@ -1441,6 +1441,133 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
return scores_processed
class HammingDiversityLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces diverse beam search.
Note that this logits processor is only effective for [`PreTrainedModel.group_beam_search`]. See [Diverse Beam
Search: Decoding Diverse Solutions from Neural Sequence Models](https://huggingface.co/papers/1610.02424) for more
details.
Traditional beam search often generates very similar sequences across different beams.
`HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other
beams in the same time step.
Args:
diversity_penalty (`float`):
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
particular time. A higher `diversity_penalty` will enforce greater diversity among the beams. Adjusting
this value can help strike a balance between diversity and natural likelihood.
num_beams (`int`):
Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://huggingface.co/papers/1610.02424) for more details.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> import torch
>>> # Initialize the model and tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # A long text about the solar system
>>> text = (
... "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, "
... "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight "
... "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System "
... "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant "
... "interstellar molecular cloud."
... )
>>> inputs = tokenizer("summarize: " + text, return_tensors="pt")
>>> # Generate diverse summary
>>> outputs_diverse = model.generate(
... **inputs,
... num_beam_groups=2,
... diversity_penalty=10.0,
... max_length=100,
... num_beams=4,
... num_return_sequences=2,
... )
>>> summaries_diverse = tokenizer.batch_decode(outputs_diverse, skip_special_tokens=True)
>>> # Generate non-diverse summary
>>> outputs_non_diverse = model.generate(
... **inputs,
... max_length=100,
... num_beams=4,
... num_return_sequences=2,
... )
>>> summary_non_diverse = tokenizer.batch_decode(outputs_non_diverse, skip_special_tokens=True)
>>> # With `diversity_penalty`, the resulting beams are much more diverse
>>> print(summary_non_diverse)
['the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.',
'the Solar System formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.']
>>> print(summaries_diverse)
['the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.',
'the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets. the rest of the objects are smaller objects, such as the five dwarf planets and small solar system bodies.']
```
"""
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
logger.warning_once(
"`HammingDiversityLogitsProcessor` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
)
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
self._diversity_penalty = diversity_penalty
if not isinstance(num_beams, int) or num_beams < 2:
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
self._num_beams = num_beams
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
if num_beam_groups > num_beams:
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
self._num_sub_beams = num_beams // num_beam_groups
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
current_tokens: torch.LongTensor,
beam_group_idx: int,
) -> torch.FloatTensor:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
current_tokens (`torch.LongTensor` of shape `(batch_size)`):
Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
beam groups in the current generation step.
beam_group_idx (`int`):
The index of the beam group currently being processed.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
The processed prediction scores.
"""
# hamming diversity: penalise using same token in current group which was used in previous groups at
# the same time step
batch_size = current_tokens.shape[0] // self._num_beams
group_start_idx = beam_group_idx * self._num_sub_beams
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
group_size = group_end_idx - group_start_idx
vocab_size = scores.shape[-1]
if group_start_idx == 0:
return scores
scores_processed = scores.clone()
for batch_idx in range(batch_size):
# predicted tokens of last time step of previous groups
previous_group_tokens = current_tokens[
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
]
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
scores_processed[batch_idx * group_size : (batch_idx + 1) * group_size] -= (
self._diversity_penalty * token_frequency
)
return scores_processed
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder

View File

@@ -33,7 +33,6 @@ from ..cache_utils import (
QuantizedCache,
StaticCache,
)
from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import (
check_python_requirements,
get_cached_module_file,
@@ -53,8 +52,6 @@ from ..utils import (
is_torchdynamo_exporting,
logging,
)
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import ConstrainedBeamSearchScorer
from .candidate_generator import (
AssistantVocabTranslatorCache,
AssistedCandidateGenerator,
@@ -370,7 +367,6 @@ class GenerationMixin(ContinuousMixin):
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -2127,6 +2123,39 @@
return can_compile
def _get_deprecated_gen_repo(
self,
generation_config: GenerationConfig,
trust_remote_code: bool,
custom_generate: Optional[str] = None,
assistant_model: Optional["PreTrainedModel"] = None,
) -> Optional[str]:
"""
Returns the Hub repo for a deprecated generation strategy, if any.
"""
generation_mode = generation_config.get_generation_mode(assistant_model)
moved_to_hub_modes = {
GenerationMode.DOLA_GENERATION: "transformers-community/dola",
GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
}
if custom_generate is not None or generation_mode not in moved_to_hub_modes:
return None
repo = moved_to_hub_modes[generation_mode]
logger.warning_once(
f"{generation_mode.name.replace('_', ' ').title()} was moved to a `custom_generate` repo: https://hf.co/{repo}. "
f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` "
"to your `generate` call before v4.62.0."
)
if not trust_remote_code:
raise ValueError(
f"{generation_mode.name.replace('_', ' ').title()} requires `trust_remote_code=True` in your `generate` call, "
f"since it loads https://hf.co/{repo}."
)
return repo
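A hedged sketch of the redirect this helper enables (checkpoint, prompt, and forced word are placeholders); the same call without `trust_remote_code=True` raises the `ValueError` above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The tallest mountain is", return_tensors="pt")

# `force_words_ids` still resolves to the constrained-beam-search mode, so generate()
# warns and re-dispatches itself with
# custom_generate="transformers-community/constrained-beam-search".
outputs = model.generate(
    **inputs,
    num_beams=4,
    force_words_ids=[tokenizer("Everest", add_special_tokens=False).input_ids],
    trust_remote_code=True,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```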
@torch.no_grad()
def generate(
self,
@@ -2243,6 +2272,7 @@
"""
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
trust_remote_code = kwargs.pop("trust_remote_code", None)
if custom_generate is not None and isinstance(custom_generate, str):
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
# they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
@@ -2272,6 +2302,33 @@
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
# Deprecation-related step: set Hub repo for deprecated strategies.
# NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
# It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
# TODO joao, manuel: remove this in v4.62.0
if deprecate_mode_repo := self._get_deprecated_gen_repo(
generation_config, trust_remote_code, custom_generate, assistant_model
):
return GenerationMixin.generate(
self,
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
assistant_model,
streamer,
negative_prompt_ids,
negative_prompt_attention_mask,
use_model_defaults,
custom_generate=deprecate_mode_repo,
trust_remote_code=trust_remote_code,
tokenizer=tokenizer,
assistant_tokenizer=assistant_tokenizer,
**kwargs,
)
# 2. Set generation parameters if not already defined
if synced_gpus is None:
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
@@ -2482,47 +2539,6 @@
streamer=streamer,
**model_kwargs,
)
# TODO joao, manuel: remove this in v4.62.0
elif generation_mode == GenerationMode.DOLA_GENERATION:
logger.warning_once(
"DoLa generation was moved to a `custom_generate` repo: https://hf.co/transformers-community/dola. "
"To prevent loss of backward compatibility, add `custom_generate='transformers-community/dola'` "
"to your `generate` call before v4.62.0."
)
if not trust_remote_code:
raise ValueError(
"DoLa generation requires `trust_remote_code=True` in your `generate` call, since "
"it loads https://hf.co/transformers-community/dola."
)
return GenerationMixin.generate(
self,
inputs,
custom_generate="transformers-community/dola",
generation_config=generation_config,
trust_remote_code=trust_remote_code,
**kwargs,
)
# TODO joao, manuel: remove this in v4.62.0
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
logger.warning_once(
"Contrastive search was moved to a `custom_generate` repo: https://hf.co/transformers-community/contrastive-search. "
"To prevent loss of backward compatibility, add `custom_generate='transformers-community/contrastive-search'` "
"to your `generate` call before v4.62.0."
)
if not trust_remote_code:
logger.warning_once(
"Contrastive search requires `trust_remote_code=True` in your `generate` call, since "
"it loads https://hf.co/transformers-community/contrastive-search."
)
# Avoid calling the model-defined `generate` method, since some models (e.g. Janus, Whisper) override it.
return GenerationMixin.generate(
self,
inputs,
custom_generate="transformers-community/contrastive-search",
generation_config=generation_config,
trust_remote_code=trust_remote_code,
**kwargs,
)
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
@@ -2547,93 +2563,6 @@
**model_kwargs,
)
elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
logger.warning_once(
"Group Beam Search was moved to a `custom_generate` repo: https://hf.co/transformers-community/group-beam-search. "
"To prevent loss of backward compatibility, add `custom_generate='transformers-community/group-beam-search'` "
"to your `generate` call before v4.62.0."
)
if not trust_remote_code:
raise ValueError(
"Group Beam Search requires `trust_remote_code=True` in your `generate` call, since "
"it loads https://hf.co/transformers-community/group-beam-search."
)
return GenerationMixin.generate(
self,
inputs,
custom_generate="transformers-community/group-beam-search",
generation_config=generation_config,
trust_remote_code=trust_remote_code,
**kwargs,
)
elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
logger.warning_once(
"Constrained Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. "
"To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call."
)
final_constraints = []
if generation_config.constraints is not None:
final_constraints = generation_config.constraints
if generation_config.force_words_ids is not None:
def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
f"of positive integers, but is {generation_config.force_words_ids}."
)
if (
not isinstance(generation_config.force_words_ids, list)
or len(generation_config.force_words_ids) == 0
):
typeerror()
for word_ids in generation_config.force_words_ids:
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(not isinstance(token_ids, list) for token_ids in word_ids):
typeerror()
if any(
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in word_ids
):
typeerror()
constraint = DisjunctiveConstraint(word_ids)
else:
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
typeerror()
constraint = PhrasalConstraint(word_ids)
final_constraints.append(constraint)
# 11. prepare beam search scorer
constrained_beam_scorer = ConstrainedBeamSearchScorer(
constraints=final_constraints,
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 12. run beam search
result = self._constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
**model_kwargs,
)
# Convert to legacy cache format if requested
if (
generation_config.return_legacy_cache is True
@@ -3511,246 +3440,6 @@
else:
return sequences
def _constrained_beam_search(
self,
input_ids: torch.LongTensor,
constrained_beam_scorer: ConstrainedBeamSearchScorer,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **constrained beam search
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
The sequence used as a prompt for the generation.
constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
sorted during generation, while satisfying a list of positive constraints. For more information, the
documentation of [`ConstrainedBeamSearchScorer`] should be read.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
batch_size = len(constrained_beam_scorer._beam_hyps)
num_beams = constrained_beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape[:2]
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
beam_indices = (
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
outputs = self(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores_processed
)
scores_for_all_vocab = next_token_scores.clone()
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
next_token_scores, next_tokens = torch.topk(
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
)
next_indices = (next_tokens / vocab_size).long()
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = constrained_beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
scores_for_all_vocab,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
# (that way the memory peak does not include outputs.logits)
del outputs
# NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
if model_kwargs.get("past_key_values", None) is not None:
if hasattr(self, "_reorder_cache"):
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
else:
model_kwargs["past_key_values"].reorder_cache(beam_idx)
if return_dict_in_generate and output_scores:
beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))
# increase cur_len
cur_len = cur_len + 1
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
this_peer_finished = True
sequence_outputs = constrained_beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder:
return GenerateBeamEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
logits=raw_logits,
beam_indices=sequence_outputs["beam_indices"],
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GenerateBeamDecoderOnlyOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
logits=raw_logits,
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
def _assisted_decoding(
self,
input_ids: torch.LongTensor,
@@ -4142,52 +3831,3 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
outputs += (new_tuple,)
return outputs
def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
"""
Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
specific ModelOutput subclass from the list provided.
"""
if not model_outputs:
raise ValueError("Input list is empty.")
# Infer the class from the first object in the list
model_output_cls = type(model_outputs[0])
# Ensure all objects are of the same type
if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
raise ValueError("All elements in the list should be of the same type.")
# Helper function to concat tensors or tuples of tensors
def _concat(data):
"""
Reverse of `_split` function above.
"""
if any(data is None for data in data):
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
return tuple(
tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
for i in range(len(data[0]))
)
else:
return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
elif isinstance(data[0], (int, float)):
# If the elements are integers or floats, return a tensor
return torch.tensor(data)
else:
raise TypeError(f"Unexpected attribute type: {type(data[0])}")
# Use a dictionary comprehension to gather attributes from all objects and concatenate them
concatenated_data = {
k: _concat([getattr(model_output, k) for model_output in model_outputs])
for k in model_output_cls.__dataclass_fields__
}
# Return a new object of the inferred class with the concatenated attributes
return model_output_cls(**concatenated_data)

View File

@@ -1,340 +0,0 @@
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers.generation import (
BeamHypotheses,
ConstrainedBeamSearchScorer,
DisjunctiveConstraint,
PhrasalConstraint,
)
class ConstrainedBeamSearchTester:
def __init__(
self,
parent,
constraints=None,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
if constraints is None:
force_tokens = torch.randint(10, 50, (1, 2))[0].tolist()
disjunctive_tokens = torch.randint(10, 50, (2, 2)).tolist()
constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)]
self.constraints = constraints
# cannot be randomly generated
self.eos_token_id = vocab_size + 1
def prepare_constrained_beam_scorer(self, **kwargs):
return ConstrainedBeamSearchScorer(
constraints=kwargs.get("constraints", self.constraints),
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
scores_for_all_vocab, _ = (
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device)
).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_constrained_beam_scorer_update(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# check too many eos tokens
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
stacked_token_ids = []
for constraint in self.constraints:
token_ids = constraint.token_ids
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
stacked_token_ids = stacked_token_ids + token_ids
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# check all batches are done
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# beam scorer should be done
self.parent.assertTrue(constrained_beam_scorer.is_done)
# check regular update: one beam candidate per batch ends with eos and is skipped in the selected beams
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outputs
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
offset = torch.div(
torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams, rounding_mode="floor"
)
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
def check_constrained_beam_scorer_finalize(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
# for testing finalize, we do want to have fulfilled constraints
stacked_token_ids = []
for constraint in self.constraints:
token_ids = constraint.token_ids
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
stacked_token_ids = stacked_token_ids + token_ids
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
)
constraints = constrained_beam_scorer.constraints
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
# check sequence_scores
self.parent.assertFalse((sequence_scores > 0).any().item())
# first batch has to finish with eos_token
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled
for output, constraint in [(s, c) for s in sequences for c in constraints]:
forced_token_ids = constraint.token_ids
if isinstance(forced_token_ids[0], list):
# disjunctive case
flag = False
for token_ids in forced_token_ids:
if self._check_sequence_inside_sequence(output, token_ids):
flag = True
break
self.parent.assertEqual(flag, True)
else:
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
# now test that if `num_beam_hyps_to_keep` equals `num_beams` => all beams are returned
# constrained_beam_scorer.num_beam_hyps_to_keep = self.num_beams
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
)
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# convert to plain python lists so device placement is irrelevant for the comparison
if not isinstance(tensor_1, list):
tensor_1 = tensor_1.tolist()
if not isinstance(tensor_2, list):
tensor_2 = tensor_2.tolist()
in_order = len(tensor_1) <= len(tensor_2)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = len(shorter)
for chunk_idx in range(len(longer) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if subseq == shorter:
flag = True
break
return flag
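For clarity, here is the same sliding-window containment check as a standalone sketch (contains_subsequence is a hypothetical name, not a helper from the test suite): it verifies that the forced token ids appear contiguously and in order inside the generated sequence.
def contains_subsequence(a, b):
    # put the shorter list in `shorter` and slide it over `longer`
    shorter, longer = (a, b) if len(a) <= len(b) else (b, a)
    n = len(shorter)
    return any(longer[i : i + n] == shorter for i in range(len(longer) - n + 1))
assert contains_subsequence([2, 3], [1, 2, 3, 4])      # forced phrase is present
assert not contains_subsequence([3, 2], [1, 2, 3, 4])  # order matters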
@require_torch
class ConstrainedBeamSearchTest(unittest.TestCase):
def setUp(self):
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self)
def test_constrained_beam_hypotheses(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs)
def test_constrained_beam_scorer_update(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs)
def test_constrained_beam_scorer_finalize(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs)

View File

@ -189,14 +189,9 @@ class GenerationConfigTest(unittest.TestCase):
generation_config_bad_temperature.update(temperature=None)
self.assertEqual(len(captured_logs.out), 0)
# Impossible sets of constraints/parameters will raise an exception
# Impossible sets of parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
with self.assertRaises(ValueError):
# dummy constraint
GenerationConfig(do_sample=True, num_beams=2, constraints=["dummy"])
with self.assertRaises(ValueError):
GenerationConfig(do_sample=True, num_beams=2, force_words_ids=[[[1, 2, 3]]])
# Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):

View File

@ -40,7 +40,6 @@ from transformers import (
)
from transformers.testing_utils import (
CaptureLogger,
is_flaky,
require_accelerate,
require_flash_attn,
require_flash_attn_3,
@ -89,7 +88,6 @@ if is_torch_available():
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
CompileConfig,
DisjunctiveConstraint,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
@ -101,7 +99,6 @@ if is_torch_available():
LogitsProcessorList,
MaxLengthCriteria,
MinLengthLogitsProcessor,
PhrasalConstraint,
PromptLookupCandidateGenerator,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
@ -209,15 +206,6 @@ class GenerationTesterMixin:
}
return beam_kwargs
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": num_return_sequences * 4,
"num_return_sequences": num_return_sequences,
}
return beam_kwargs
def _greedy_generate(
self,
model,
@ -340,38 +328,6 @@ class GenerationTesterMixin:
return output_generate
def _constrained_beam_search_generate(
self,
model,
inputs_dict,
constraints,
beam_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
output_generate = model.generate(
do_sample=False,
max_new_tokens=self.max_new_tokens,
min_new_tokens=self.max_new_tokens,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
constraints=constraints,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**inputs_dict,
)
return output_generate
@pytest.mark.generate
def test_greedy_generate(self):
for model_class in self.all_generative_model_classes:
@ -706,115 +662,6 @@ class GenerationTesterMixin:
)
self.assertIsNotNone(output_ids_generate)
@is_flaky() # Some models have position-specific tokens, this test may try to force them in an invalid position
@pytest.mark.generate
def test_constrained_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
model = model_class(config).to(torch_device).eval()
# Sample constraints
min_id = 3
max_id = config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
inputs_dict=inputs_dict,
constraints=constraints,
beam_kwargs=beam_kwargs,
)
if model.config.get_text_config(decoder=True).is_encoder_decoder:
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
# check `constrained_beam_search` with `num_return_sequences` > 1
# Sample constraints
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
output_generate = self._constrained_beam_search_generate(
model=model,
inputs_dict=inputs_dict,
constraints=constraints,
beam_kwargs=beam_kwargs,
)
if model.config.get_text_config(decoder=True).is_encoder_decoder:
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@is_flaky() # Some models have position-specific tokens, this test may try to force them in an invalid position
@pytest.mark.generate
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval()
# Sample constraints
min_id = 3
max_id = model.config.get_text_config(decoder=True).vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
inputs_dict=inputs_dict,
constraints=constraints,
beam_kwargs=beam_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.get_text_config(decoder=True).is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_generate_outputs(
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
@ -2881,120 +2728,6 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(outputs, ["Wie alt bist du?"])
@slow
def test_constrained_beam_search(self):
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
constraints = [
PhrasalConstraint(force_tokens),
PhrasalConstraint(force_tokens_2),
]
starting_text = ["The soldiers were not prepared and"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
constraints=constraints,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
max_length=30,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers were not prepared and didn't know what to do. They had no idea how they would react if"
" the enemy attacked them, big weapons scared"
],
)
@slow
def test_constrained_beam_search_mixed(self):
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
flexible_phrases = tokenizer(
["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False
).input_ids
constraints = [
PhrasalConstraint(force_phrase),
DisjunctiveConstraint(flexible_phrases),
]
starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
constraints=constraints,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
# max_length=20,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers, who had been stationed at the base for more than a year before being evacuated"
" screaming scared",
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
],
)
@slow
def test_constrained_beam_search_mixed_mixin(self):
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]
force_words_ids = [
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
]
starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The soldiers, who had been stationed at the base for more than a year before being evacuated"
" screaming scared",
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
],
)
@slow
def test_cfg_mixin(self):
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
@ -3035,30 +2768,7 @@ class GenerationIntegrationTests(unittest.TestCase):
],
)
@slow
def test_constrained_beam_search_example_translation_mixin(self):
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
encoder_input_str = "translate English to German: How old are you?"
force_words = ["sind"]
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
outputs = model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alt sind Sie?"])
# TODO joao, manuel: remove in v4.62.0
@slow
def test_constrained_beam_search_example_integration(self):
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
@ -3085,6 +2795,8 @@ class GenerationIntegrationTests(unittest.TestCase):
force_words_ids=[constraint_token_ids],
min_length=5,
eos_token_id=model.config.eos_token_id,
trust_remote_code=True,
custom_generate="transformers-community/constrained-beam-search",
**model_kwargs,
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
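Since constrained beam search now lives in the transformers-community/constrained-beam-search Hub repository rather than in core generate(), the integration test above routes through custom_generate with trust_remote_code=True. The following is a hedged sketch of the equivalent user-facing call, reusing the model, inputs, and flags that appear elsewhere in this diff; the exact decoded output depends on the checkpoint and is not asserted here.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids
# token ids of the word(s) that must appear in the generated output
force_words_ids = tokenizer(["sind"], add_special_tokens=False).input_ids
outputs = model.generate(
    input_ids,
    force_words_ids=force_words_ids,
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    # the constrained decoding loop is loaded from the Hub instead of core generate()
    trust_remote_code=True,
    custom_generate="transformers-community/constrained-beam-search",
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))  # e.g. ["Wie alt sind Sie?"]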
@ -3128,46 +2840,6 @@ class GenerationIntegrationTests(unittest.TestCase):
]
self.assertListEqual(out_text, expected_out)
def test_constrained_beam_search_mixin_type_checks(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
with self.assertRaises(ValueError):
force_words = ["sind"]
force_words_ids = tokenizer(force_words, return_tensors="pt").input_ids
model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
with self.assertRaises(ValueError):
force_words = ["sind"]
force_words_ids = [tokenizer(force_words, return_tensors="pt").input_ids]
model.generate(
input_ids,
force_words_ids=force_words_ids,
num_beams=10,
num_return_sequences=1,
no_repeat_ngram_size=1,
remove_invalid_values=True,
)
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[])
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[-1]])
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])
def test_batched_decoder_start_id(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
@ -4742,6 +4414,10 @@ class GenerationIntegrationTests(unittest.TestCase):
"length_penalty": 2.0,
},
),
(
"transformers-community/constrained-beam-search",
{"do_sample": False, "num_beams": 2, "force_words_ids": [[167, 168, 169]]},
),
]
)
def test_hub_gen_strategies(self, custom_generate, extra_kwargs):

View File

@ -272,16 +272,6 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
def test_beam_sample_generate_dict_output(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support constrained beam search.")
def test_constrained_beam_search_generate(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support constrained beam search.")
def test_constrained_beam_search_generate_dict_output(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
def test_prompt_lookup_decoding_matches_greedy_search(self):

View File

@ -237,7 +237,6 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
skippable_tests = [
"test_sample_generate_dict_output", # return sequences > 1
"test_beam",
"test_constrained_beam",
"test_contrastive",
"test_assisted",
"test_prompt_lookup",

View File

@ -430,10 +430,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Cannot do constraint generation, has custom `generate()`")
def test_constrained_beam_search_generate_dict_output(self):
pass
@unittest.skip("Cannot generate from inputs embeds")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

View File

@ -128,21 +128,11 @@ class RecurrentGemmaModelTest(CausalLMModelTest, unittest.TestCase):
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_constrained_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_constrained_beam_search_generate(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_greedy_generate_dict_outputs(self):

View File

@ -387,13 +387,6 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
super().test_beam_search_generate_dict_output()
self.has_attentions = old_has_attentions
def test_constrained_beam_search_generate_dict_output(self):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions = self.has_attentions
self.has_attentions = False
super().test_constrained_beam_search_generate_dict_output()
self.has_attentions = old_has_attentions
def test_greedy_generate_dict_outputs(self):
# This model has a custom attention output shape AND config flags, let's skip those checks
old_has_attentions = self.has_attentions

View File

@ -404,12 +404,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def setUp(self):
self.model_tester = WhisperModelTester(self)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)