mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] Fix MinPLogitsProcessor.update_states() (#23401)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -53,29 +53,37 @@ class MinPLogitsProcessor(LogitsProcessor):
|
||||
# Process added requests.
|
||||
for index, params, _, _ in batch_update.added:
|
||||
min_p = params.min_p
|
||||
if self.min_p_cpu[index] != min_p:
|
||||
min_p_before = self.min_p_cpu[index]
|
||||
if min_p_before != min_p:
|
||||
needs_update = True
|
||||
self.min_p_cpu[index] = min_p
|
||||
if min_p:
|
||||
self.min_p_count += 1
|
||||
if min_p and not min_p_before:
|
||||
self.min_p_count += 1
|
||||
elif not min_p and min_p_before:
|
||||
self.min_p_count -= 1
|
||||
|
||||
if self.min_p_count:
|
||||
# Process removed requests.
|
||||
needs_update |= bool(batch_update.removed)
|
||||
for index in batch_update.removed:
|
||||
if self.min_p_cpu[index]:
|
||||
self.min_p_count -= 1
|
||||
if batch_update.removed:
|
||||
needs_update = True
|
||||
for index in batch_update.removed:
|
||||
if self.min_p_cpu[index]:
|
||||
self.min_p_cpu[index] = 0
|
||||
self.min_p_count -= 1
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and swap (a<->b)
|
||||
# Process moved requests, unidirectional (a->b) and swap (a<->b).
|
||||
for adx, bdx, direct in batch_update.moved:
|
||||
change = (min_p_a :=
|
||||
self.min_p_cpu[adx]) != (min_p_b :=
|
||||
self.min_p_cpu[bdx])
|
||||
needs_update |= change
|
||||
if change:
|
||||
min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
|
||||
if min_p_a != min_p_b:
|
||||
needs_update = True
|
||||
self.min_p_cpu[bdx] = min_p_a
|
||||
if direct == MoveDirectionality.SWAP:
|
||||
self.min_p_cpu[adx] = min_p_b
|
||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
||||
if min_p_a:
|
||||
self.min_p_cpu[adx] = 0
|
||||
if min_p_b:
|
||||
self.min_p_count -= 1
|
||||
|
||||
# Update tensors if needed.
|
||||
size = batch_update.batch_size
|
||||
|
Reference in New Issue
Block a user