Compare commits

...

1 Commits

View File

@ -549,15 +549,15 @@ class TFForceTokensLogitsProcessor(TFLogitsProcessor):
`-inf` so that they are sampled at their corresponding index."""
def __init__(self, force_token_map: List[List[int]]):
force_token_map = dict(force_token_map)
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
# Indexes without forced tokens will have an negative value.
force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1
for index, token in force_token_map.items():
if token is not None:
force_token_array[index] = token
self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)
force_token_map = tf.convert_to_tensor(force_token_map)
force_max_token_id = tf.reduce_max(force_token_map[:, 0], axis=0)
tf_token_array_ = tf.fill([force_max_token_id + 1], -1)
self.force_token_array = tf.tensor_scatter_nd_update(
tf_token_array_,
force_token_map[:, 0, tf.newaxis],
force_token_map[:, 1]
)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
def _force_token(generation_idx):