fixes
@@ -504,4 +504,6 @@ def ensure_decodes_first(b: InputBatch):
             break

         # Swap
+        print("Swapping first_prompt_index = {} with last_decode_index = {}".
+              format(first_prompt_index, last_decode_index))
         swap_positions(b, first_prompt_index, last_decode_index)
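For context, ensure_decodes_first reorders the input batch so that decode requests come before prompt (prefill) requests: it scans from the left for the first prompt, from the right for the last decode, swaps the pair, and repeats until the two scans cross. The sketch below reproduces that two-pointer pattern in isolation; the list-of-strings batch and the is_decode predicate are illustrative stand-ins, not vLLM's InputBatch or swap_positions.

def ensure_decodes_first_sketch(reqs, is_decode):
    """Reorder reqs in place so all decode requests precede prompts."""
    first_prompt_index = 0
    last_decode_index = len(reqs) - 1
    while True:
        # Scan from the left for the first prompt (non-decode) request.
        while first_prompt_index < len(reqs) and is_decode(reqs[first_prompt_index]):
            first_prompt_index += 1
        # Scan from the right for the last decode request.
        while last_decode_index >= 0 and not is_decode(reqs[last_decode_index]):
            last_decode_index -= 1
        if first_prompt_index >= last_decode_index:
            break
        # Swap the misplaced pair, as swap_positions does in the diff above.
        reqs[first_prompt_index], reqs[last_decode_index] = (
            reqs[last_decode_index], reqs[first_prompt_index])

reqs = ["P0", "D0", "P1", "D1"]
ensure_decodes_first_sketch(reqs, is_decode=lambda r: r.startswith("D"))
print(reqs)  # ['D1', 'D0', 'P1', 'P0']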
@@ -218,10 +218,9 @@ class TPUModelRunner(ModelRunnerBase):
             block_table = block_table_cpu.reshape(1, -1).to(
                 self.device) if block_table_cpu is not None else None

-            context_lens = self.prompt_context_lens_cpu.reshape(1,
-                                                                 -1).to(self.device)
-            effective_query_lens = self.prompt_effective_query_lens_cpu.reshape(
-                1, -1).to(self.device)
+            context_lens = self.prompt_context_lens_cpu.to(self.device)
+            effective_query_lens = self.prompt_effective_query_lens_cpu.to(
+                self.device)

             # Attn metadata
             attn_metadata = PallasMetadata(
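The change above drops the reshape(1, -1) calls, so the prefill context_lens and effective_query_lens are moved to the device as 1-D tensors instead of carrying a leading dimension of 1. A minimal illustration of the shape difference, using plain torch and made-up values rather than the runner's real buffers:

import torch

prompt_context_lens_cpu = torch.tensor([17, 0, 33], dtype=torch.int32)

# Old form: an explicit leading dimension of 1.
print(prompt_context_lens_cpu.reshape(1, -1).shape)  # torch.Size([1, 3])

# New form: the tensor stays 1-D.
print(prompt_context_lens_cpu.shape)                 # torch.Size([3])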
@@ -247,6 +246,15 @@ class TPUModelRunner(ModelRunnerBase):
         padded_batch_size = _get_padded_batch_size(batch_size)
         assert padded_batch_size <= self.max_model_len

+        # Init [0 .. batch_size - 1]
+        req_indices_np = self.arange_np[:padded_batch_size]
+
+        print("_prepare_decode:")
+        print(" batch_size = {}".format(batch_size))
+        print(" padded_batch_size = {}".format(padded_batch_size))
+        print(" req_indices_np.shape = {} val = {}".format(
+            req_indices_np.shape, req_indices_np))
+
         # Input positions
         input_positions_np = self.input_positions_np[:padded_batch_size]
         np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
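The hunk above pads the decode batch to a fixed size and materializes the request indices [0 .. padded_batch_size - 1] by slicing a precomputed arange; the debug prints dump the shapes involved. The sketch below shows that padding-plus-arange pattern with a hypothetical rounding rule (next power of two); the real _get_padded_batch_size may round differently.

import numpy as np

def get_padded_batch_size_sketch(batch_size: int) -> int:
    # Hypothetical rule: round up to the next power of two so the TPU sees
    # a small, fixed set of batch shapes. The real helper may differ.
    padded = 1
    while padded < batch_size:
        padded *= 2
    return padded

batch_size = 5
padded_batch_size = get_padded_batch_size_sketch(batch_size)   # 8

arange_np = np.arange(64, dtype=np.int64)       # precomputed once, reused
req_indices_np = arange_np[:padded_batch_size]  # [0 1 2 3 4 5 6 7]
print(padded_batch_size, req_indices_np)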
@@ -255,29 +263,61 @@ class TPUModelRunner(ModelRunnerBase):
         input_positions_np[batch_size:] = 0
         input_positions_cpu = self.input_positions_cpu[:padded_batch_size]

+        print(" input_positions_cpu.shape = {} data = {}".format(
+            input_positions_cpu.shape, input_positions_cpu))
+
         # Input tokens
+        token_indices_np = (
+            input_positions_np +
+            req_indices_np * self.input_batch.token_ids_cpu.shape[1])
         input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
-        torch.index_select(self.input_batch.token_ids_cpu_tensor,
-                           1,
-                           input_positions_cpu,
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices_np),
                            out=input_tokens_cpu)
-        input_tokens_cpu[:batch_size] = 0
+        input_tokens_cpu[batch_size:] = 0
+
+        print(" token_indices_np.shape = {} val = {}".format(
+            token_indices_np.shape, token_indices_np))
+
+        print(" input_tokens_cpu.shape = {} data = {}".format(
+            input_tokens_cpu.shape, input_tokens_cpu))

         # Slot mapping
+        block_table_indices_np = (
+            req_indices_np * self.max_num_blocks_per_req +
+            input_positions_np // self.block_size)
+
+        print(
+            " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
+            .format(block_table_indices_np.shape, block_table_indices_np,
+                    self.max_num_blocks_per_req))
         block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        block_numbers_cpu = torch.index_select(
-            block_table_cpu, 1, input_positions_cpu // self.block_size)
-        block_numbers_np = block_numbers_cpu.numpy()

+        print(" block_table_cpu.shape = {} data = {}".format(
+            block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
+
+        block_numbers_np = block_table_cpu.flatten(
+        )[block_table_indices_np].numpy()
+
+        print(" block_numbers_np.shape = {} data = {}".format(
+            block_numbers_np.shape, block_numbers_np))
+
         block_offsets_np = input_positions_np % self.block_size

+        print(" block_offsets_np.shape = {} data = {}".format(
+            block_offsets_np.shape, block_offsets_np))
+
         slot_mapping_np = self.slot_mapping_np[:padded_batch_size]
         np.add(block_numbers_np * self.block_size,
                block_offsets_np,
                out=slot_mapping_np)
-        slot_mapping_np[:, batch_size:] = _PAD_SLOT_ID
+        slot_mapping_np[batch_size:] = _PAD_SLOT_ID

-        block_table_cpu = block_table_cpu[:len(decode_req_ids)]
+        print(" slot_mapping_np.shape = {} data = {}".format(
+            slot_mapping_np.shape, slot_mapping_np))
+
+        block_table_cpu = block_table_cpu[:padded_batch_size]

         # Context lens
         context_lens_np = self.decode_context_lens_np[:padded_batch_size]
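The core of this hunk is the switch from a 2-D torch.index_select over positions to flat gathers: each request's next position plus its request index times the row stride addresses a flattened (num_requests, max_model_len) token buffer, and the same idea indexes the flattened block table to build the slot mapping. The standalone sketch below reproduces that arithmetic on small made-up buffers; the names mirror the diff, but the sizes, the token and block-table contents, and the -1 padding value (standing in for _PAD_SLOT_ID) are illustrative only.

import numpy as np
import torch

block_size = 4
max_model_len = 16
max_num_blocks_per_req = max_model_len // block_size

# Two live decode requests, padded to a batch of 4.
batch_size, padded_batch_size = 2, 4
req_indices_np = np.arange(padded_batch_size)

# Position of the next token for each request; padding rows are zeroed.
input_positions_np = np.array([5, 9, 0, 0])

# Gather input token ids: flat index = position + request index * row stride.
token_ids_cpu = torch.arange(padded_batch_size * max_model_len,
                             dtype=torch.int32).reshape(padded_batch_size,
                                                        max_model_len)
token_indices_np = input_positions_np + req_indices_np * max_model_len
input_tokens_cpu = token_ids_cpu.flatten()[torch.from_numpy(token_indices_np)]
input_tokens_cpu[batch_size:] = 0            # zero out the padding rows

# Slot mapping: physical block number * block_size + offset within the block.
block_table_cpu = torch.arange(
    padded_batch_size * max_num_blocks_per_req,
    dtype=torch.int32).reshape(padded_batch_size, max_num_blocks_per_req)
block_table_indices_np = (req_indices_np * max_num_blocks_per_req +
                          input_positions_np // block_size)
block_numbers_np = block_table_cpu.flatten()[
    torch.from_numpy(block_table_indices_np)].numpy()
block_offsets_np = input_positions_np % block_size
slot_mapping_np = block_numbers_np * block_size + block_offsets_np
slot_mapping_np[batch_size:] = -1            # padded slots get a sentinel

print(token_indices_np)   # [ 5 25 32 48]
print(slot_mapping_np)    # [ 5 25 -1 -1]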
@@ -287,14 +327,17 @@ class TPUModelRunner(ModelRunnerBase):
         context_lens_np[batch_size:] = 0

         # Get final tensors
-        input_tokens = input_tokens_cpu.to(self.device)
-        input_positions = input_positions_cpu.to(self.device)
-        slot_mapping = self.slot_mapping_cpu[:padded_batch_size].to(
-            self.device)
+        input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
+        input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
+        slot_mapping = self.slot_mapping_cpu[:padded_batch_size].reshape(
+            -1, 1).to(self.device)
         block_table = block_table_cpu.to(self.device)
         context_lens = self.decode_context_lens_cpu[:padded_batch_size].to(
             self.device)

+        print(" context_lens.shape = {} val = {}".format(
+            context_lens.shape, context_lens))
+
         # Attn metadata
         attn_metadata = PallasMetadata(
             num_prefills=0,
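Finally, the decode inputs are reshaped to (-1, 1) before the device copy, so each padded request contributes a column of exactly one token position per step. A tiny illustration of the shape change, with a made-up padded batch size:

import torch

padded_batch_size = 8
input_tokens_cpu = torch.zeros(padded_batch_size, dtype=torch.int32)

# Before the fix: a flat vector, one token id per padded request.
print(input_tokens_cpu.shape)                 # torch.Size([8])

# After the fix: one column per request, a single new token per step.
print(input_tokens_cpu.reshape(-1, 1).shape)  # torch.Size([8, 1])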