mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[full_graph] Fix query_start_loc padding (#19321)
Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai>
This commit is contained in:
@ -655,7 +655,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache
|
||||
self.seq_lens[num_reqs:].fill_(0)
|
||||
self.query_start_loc[num_reqs + 1:].fill_(-1)
|
||||
# Note: pad query_start_loc to be non-decreasing, as kernels
|
||||
# like FlashAttention requires that
|
||||
self.query_start_loc[num_reqs + 1:].fill_(
|
||||
self.query_start_loc_cpu[num_reqs].item())
|
||||
|
||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
|
Reference in New Issue
Block a user