[V1] Support bad_words in sampler (#13376)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
@@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
                        PlaceholderModule, StoreBoolean, bind_kv_cache,
                        deprecate_kwargs, get_open_port, memory_profiling,
                        merge_async_iterators, supports_kw)
                        merge_async_iterators, supports_kw, swap_dict_values)

from .utils import error_on_warning, fork_new_process_for_each_test
@@ -449,3 +449,26 @@ def test_placeholder_module_error_handling():
    with build_ctx():
        # Test conflict with internal __module attribute
        _ = placeholder_attr.module


@pytest.mark.parametrize(
    "obj,key1,key2",
    [
        # Tests for both keys exist
        ({1: "a", 2: "b"}, 1, 2),
        # Tests for one key does not exist
        ({1: "a", 2: "b"}, 1, 3),
        # Tests for both keys do not exist
        ({1: "a", 2: "b"}, 3, 4),
    ])
def test_swap_dict_values(obj, key1, key2):
    original_obj = obj.copy()
    swap_dict_values(obj, key1, key2)
    if key1 in original_obj:
        assert obj[key2] == original_obj[key1]
    else:
        assert key2 not in obj
    if key2 in original_obj:
        assert obj[key1] == original_obj[key2]
    else:
        assert key1 not in obj
@@ -42,6 +42,7 @@ def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
        min_tokens={},
        logit_bias=[None] * batch_size,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
    )
@@ -77,6 +77,49 @@ def _create_allowed_token_ids(
    return mask


def _create_bad_words_token_ids(
        batch_size: int, vocab_size: int,
        bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
    bad_words_token_ids = {}
    for batch_idx in range(batch_size):
        token_ids_single_batch = []
        for bad_words_length in bad_words_lengths:
            token_ids = np.random.choice(vocab_size,
                                         size=bad_words_length,
                                         replace=True).tolist()
            token_ids_single_batch.append(token_ids)
        bad_words_token_ids[batch_idx] = token_ids_single_batch
    if batch_size >= 2:
        # Test no bad_words for some batch
        no_bad_words_batch_idx = np.random.choice(batch_size)
        bad_words_token_ids.pop(no_bad_words_batch_idx, None)
    return bad_words_token_ids


def _update_output_token_ids_for_bad_words(
        metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
    bad_words_last_tokens = {}
    for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
        output_token_ids = metadata.output_token_ids[batch_idx]
        bad_words_last_token: list[int] = []
        for i, bad_word_token_ids in enumerate(bad_words_token_ids):
            if len(bad_word_token_ids) == 1:
                # Single token id always affects logits
                bad_words_last_token.append(bad_word_token_ids[0])
            else:
                prefix_length = len(bad_word_token_ids) - 1
                has_bad_words = np.random.choice([True, False])
                if has_bad_words:
                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
                    bad_words_last_token.append(bad_word_token_ids[-1])
                    break  # Maximum one update to output_token_ids
                else:  # Make sure no accidental match to bad words
                    output_token_ids[-1] = (bad_word_token_ids[-2] +
                                            1) % vocab_size
        bad_words_last_tokens[batch_idx] = bad_words_last_token
    return bad_words_last_tokens


def _create_default_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
@@ -112,6 +155,7 @@ def _create_default_sampling_metadata(
        min_tokens={},
        logit_bias=[None] * batch_size,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
    )
    return fake_sampling_metadata
@@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
                    "inf"), f"{batch_idx}, {token_id}"
            else:
                assert logits_for_req[token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
def test_sampler_bad_words(device: str, batch_size: int,
                           bad_words_lengths: list[tuple[int]]):
    """
    Test to verify that when the bad words restriction is present, tokens
    are penalized based on their match with the bad words.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
        batch_size, VOCAB_SIZE, bad_words_lengths)
    bad_words_last_tokens = _update_output_token_ids_for_bad_words(
        sampling_metadata, VOCAB_SIZE)
    sampler = Sampler()
    logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        for token_id in range(VOCAB_SIZE):
            if (batch_idx in bad_words_last_tokens
                    and token_id in bad_words_last_tokens[batch_idx]):
                assert logits_for_req[token_id] == -float("inf")
            else:
                assert logits_for_req[token_id] != -float("inf")
@@ -120,8 +120,22 @@ def test_detokenize_false(model):
def test_bad_words(model):
    """Check that we respect bad words."""

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))
    output = model.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    bad_words_1 = " ".join(split_text[:2])
    params = SamplingParams(temperature=0, bad_words=[bad_words_1])
    output = model.generate(PROMPT, params)
    new_text = output[0].outputs[0].text
    assert bad_words_1 not in new_text

    bad_words_2 = new_text.split()[-1]
    params = SamplingParams(temperature=0,
                            bad_words=[bad_words_1, bad_words_2])
    output = model.generate(PROMPT, params)
    new_text = output[0].outputs[0].text
    assert bad_words_1 not in new_text
    assert bad_words_2 not in new_text


def test_logits_processor(model):
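For context, this is how the new option is exercised from the public offline-inference API. The snippet below is an illustrative sketch, not part of the diff, and the model name and prompt are only examples.

# Illustrative usage of SamplingParams.bad_words (model name is an example).
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0, bad_words=["Hello", "bad word"])
output = llm.generate("The capital of France is", params)
print(output[0].outputs[0].text)  # generated text should not contain the banned words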
@@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata(
                                           VOCAB_SIZE,
                                           dtype=torch.bool,
                                           device=device)
    bad_words_token_ids = {}
    for req in reqs:
        if req.req_id not in req_ids_retained:
            continue
@@ -123,6 +124,8 @@ def _construct_expected_sampling_metadata(
        if req.sampling_params.allowed_token_ids:
            allowed_token_ids_mask[index_in_input_batch][
                req.sampling_params.allowed_token_ids] = True
        bad_words_token_ids[
            index_in_input_batch] = req.sampling_params.bad_words_token_ids

    return SamplingMetadata(
        temperature=torch.tensor(temperature, dtype=torch.float,
@@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata(
                      and all(x == 1 for x in repetition_penalties)),
        logit_bias=logit_bias,
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=bad_words_token_ids,
    )
@@ -284,6 +288,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
        assert torch.allclose(
            expected_sampling_metadata.allowed_token_ids_mask,
            sampling_metadata.allowed_token_ids_mask)
    assert expected_sampling_metadata.bad_words_token_ids == \
        sampling_metadata.bad_words_token_ids


@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -11,6 +11,8 @@ from pydantic import BaseModel

from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

logger = init_logger(__name__)
@@ -202,7 +204,6 @@ class SamplingParams(
    seed: Optional[int] = None
    stop: Optional[Union[str, list[str]]] = None
    stop_token_ids: Optional[list[int]] = None
    bad_words: Optional[list[str]] = None
    ignore_eos: bool = False
    max_tokens: Optional[int] = 16
    min_tokens: int = 0
@@ -232,6 +233,10 @@ class SamplingParams(
    allowed_token_ids: Optional[list[int]] = None
    extra_args: Optional[dict[str, Any]] = None

    # Fields used for bad words
    bad_words: Optional[list[str]] = None
    _bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)

    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
@@ -464,6 +469,46 @@ class SamplingParams(
                eos_ids.update(self.stop_token_ids)
                self.stop_token_ids = list(eos_ids)

    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        if self.bad_words is None:
            return
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()

                if isinstance(tokenizer, MistralTokenizer):
                    # Mistral tokenizers should not add special tokens
                    prompt_token_ids = tokenizer.encode(text=prompt)
                else:
                    prompt_token_ids = tokenizer.encode(
                        text=prompt, add_special_tokens=False)

                # If no space at the beginning
                # or if prefix space produces a new word token
                if (not add_prefix_space) or (
                        add_prefix_space and prompt_token_ids[0]
                        != self._bad_words_token_ids[-1][0]
                        and len(prompt_token_ids) == len(
                            self._bad_words_token_ids[-1])):
                    self._bad_words_token_ids.append(prompt_token_ids)

        invalid_token_ids = [
            token_id for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {tokenizer.max_token_id+1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.")

    @cached_property
    def sampling_type(self) -> SamplingType:
        if self.temperature < _SAMPLING_EPS:
@@ -476,6 +521,11 @@ class SamplingParams(
    def all_stop_token_ids(self) -> set[int]:
        return self._all_stop_token_ids

    @property
    def bad_words_token_ids(self) -> list[list[int]]:
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.
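The loop in update_from_tokenizer records up to two token-id sequences per bad word because many tokenizers encode a word differently with and without a leading space. A small illustration of that difference, not part of the commit, assuming the transformers package and an arbitrary Hugging Face tokenizer:

# Compare the two encodings that update_from_tokenizer inspects for each bad word.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer; example only
bad_word = "Hello"
for add_prefix_space in [False, True]:
    prefix = " " if add_prefix_space else ""
    ids = tok.encode(prefix + bad_word.lstrip(), add_special_tokens=False)
    print(add_prefix_space, ids)
# Banning both variants blocks the word both at the start of the output and
# in the middle of a sentence.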
@@ -2361,3 +2361,19 @@ class LazyLoader(types.ModuleType):
        if self._module is None:
            self._module = self._load()
        return dir(self._module)


def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
    """
    Helper function to swap values for two keys
    """
    v1 = obj.get(key1)
    v2 = obj.get(key2)
    if v1 is not None:
        obj[key2] = v1
    else:
        obj.pop(key2, None)
    if v2 is not None:
        obj[key1] = v2
    else:
        obj.pop(key1, None)
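A quick sketch of the semantics (illustration only, assuming swap_dict_values is imported from vllm.utils): a key that is absent on one side ends up removed on the other, which is exactly what test_swap_dict_values above asserts.

from vllm.utils import swap_dict_values

d = {1: "a", 2: "b"}
swap_dict_values(d, 1, 3)      # key 3 is missing
assert d == {2: "b", 3: "a"}   # value moved to key 3, key 1 removed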
@@ -94,9 +94,6 @@ class Processor:
        # Best of not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("VLLM V1 does not yet support best_of.")
        # Bad words not yet supported.
        if params.bad_words:
            raise ValueError("VLLM V1 does not yet support bad_words.")
        # Logits processors not supported.
        if params.logits_processors:
            raise ValueError("VLLM V1 does not support per request "
@@ -203,6 +200,8 @@ class Processor:
        sampling_params = params.clone()
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        sampling_params.update_from_tokenizer(
            self.tokenizer.get_lora_tokenizer(lora_request))

        # Multimodal related.
        # Compute MM hashes (if enabled)
@@ -38,3 +38,6 @@ class SamplingMetadata:
    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]
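For concreteness, a hypothetical value of this field (the token ids are made up): each request index maps to the list of banned token-id sequences, and requests without bad words simply have no entry.

# Hypothetical contents of SamplingMetadata.bad_words_token_ids.
bad_words_token_ids = {
    0: [[4217], [318, 7721]],  # req 0: one single-token and one two-token bad word
    1: [[902, 13, 5]],         # req 1: a three-token bad word
}                              # req 2: no entry, so nothing is masked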
vllm/v1/sample/ops/bad_words.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0

import torch

_SMALLEST_LOGIT = float("-inf")


def _apply_bad_words_single_batch(
    logits: torch.Tensor,
    bad_words_token_ids: list[list[int]],
    past_tokens_ids: list[int],
) -> None:
    for bad_word_ids in bad_words_token_ids:
        if len(bad_word_ids) > len(past_tokens_ids) + 1:
            continue

        prefix_length = len(bad_word_ids) - 1
        last_token_id = bad_word_ids[-1]
        if prefix_length > 0:
            actual_prefix = past_tokens_ids[-prefix_length:]
        else:
            actual_prefix = []
        expected_prefix = bad_word_ids[:prefix_length]

        assert len(actual_prefix) == len(expected_prefix)

        if actual_prefix == expected_prefix:
            logits[last_token_id] = _SMALLEST_LOGIT


def apply_bad_words(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
) -> None:
    for i, bad_words_ids in bad_words_token_ids.items():
        _apply_bad_words_single_batch(logits[i], bad_words_ids,
                                      past_tokens_ids[i])
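A minimal sketch of the masking behavior (illustration only, using the module added above): a multi-token bad word only bans its final token once the preceding tokens match the end of the already generated sequence.

# Ban the sequence [3, 7]; the last generated token is 3, so token 7 is masked.
import torch
from vllm.v1.sample.ops.bad_words import apply_bad_words

logits = torch.zeros(1, 10)                 # one request, vocab of 10
apply_bad_words(logits,
                bad_words_token_ids={0: [[3, 7]]},
                past_tokens_ids=[[1, 3]])
assert logits[0, 7] == float("-inf")        # banned continuation
assert torch.isfinite(logits[0, :7]).all()  # other tokens untouched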
@@ -6,6 +6,7 @@ import torch.nn as nn

from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                          apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
@@ -38,6 +39,8 @@ class Sampler(nn.Module):
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)
        # Apply logits bias.
        logits = self.apply_logits_bias(logits, sampling_metadata)
        # Apply penalties (e.g., min_tokens, freq_penalties).
@@ -237,3 +240,16 @@ class Sampler(nn.Module):
        logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                            float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits
@@ -10,6 +10,7 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
@@ -204,6 +205,9 @@ class InputBatch:
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
@@ -320,6 +324,9 @@ class InputBatch:
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False

        self.bad_words_token_ids[
            req_index] = sampling_params.bad_words_token_ids

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
@@ -369,6 +376,7 @@ class InputBatch:
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
@@ -413,27 +421,9 @@ class InputBatch:
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        g1 = self.generators.get(i1)
        g2 = self.generators.get(i2)
        if g1 is not None:
            self.generators[i2] = g1
        else:
            self.generators.pop(i2, None)
        if g2 is not None:
            self.generators[i1] = g2
        else:
            self.generators.pop(i1, None)

        t1 = self.min_tokens.get(i1)
        t2 = self.min_tokens.get(i2)
        if t1 is not None:
            self.min_tokens[i2] = t1
        else:
            self.min_tokens.pop(i2, None)
        if t2 is not None:
            self.min_tokens[i1] = t2
        else:
            self.min_tokens.pop(i1, None)
        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.min_tokens, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
@@ -518,6 +508,10 @@ class InputBatch:
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1
@@ -585,6 +579,7 @@ class InputBatch:
            no_penalties=self.no_penalties,
            logit_bias=self.logit_bias[:num_reqs],
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
@@ -1268,6 +1268,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            min_tokens={},
            logit_bias=[None for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
        )
        sampler_output = self.model.sample(logits=logits,
                                           sampling_metadata=dummy_metadata)