[BugFix] Fix MinPLogitsProcessor.update_states() (#23401)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-08-22 17:22:11 -07:00
committed by GitHub
parent c80c53a30f
commit add1adfec7

View File

@ -53,29 +53,37 @@ class MinPLogitsProcessor(LogitsProcessor):
# Process added requests.
for index, params, _, _ in batch_update.added:
min_p = params.min_p
if self.min_p_cpu[index] != min_p:
min_p_before = self.min_p_cpu[index]
if min_p_before != min_p:
needs_update = True
self.min_p_cpu[index] = min_p
if min_p:
if min_p and not min_p_before:
self.min_p_count += 1
elif not min_p and min_p_before:
self.min_p_count -= 1
if self.min_p_count:
# Process removed requests.
needs_update |= bool(batch_update.removed)
if batch_update.removed:
needs_update = True
for index in batch_update.removed:
if self.min_p_cpu[index]:
self.min_p_cpu[index] = 0
self.min_p_count -= 1
# Process moved requests, unidirectional (a->b) and swap (a<->b)
# Process moved requests, unidirectional (a->b) and swap (a<->b).
for adx, bdx, direct in batch_update.moved:
change = (min_p_a :=
self.min_p_cpu[adx]) != (min_p_b :=
self.min_p_cpu[bdx])
needs_update |= change
if change:
min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
if min_p_a != min_p_b:
needs_update = True
self.min_p_cpu[bdx] = min_p_a
if direct == MoveDirectionality.SWAP:
self.min_p_cpu[adx] = min_p_b
if direct == MoveDirectionality.UNIDIRECTIONAL:
if min_p_a:
self.min_p_cpu[adx] = 0
if min_p_b:
self.min_p_count -= 1
# Update tensors if needed.
size = batch_update.batch_size