Alexander Matveev
2025-02-05 15:36:38 +00:00
parent 627efde813
commit 7be649256f
2 changed files with 62 additions and 17 deletions


@@ -504,4 +504,6 @@ def ensure_decodes_first(b: InputBatch):
             break
     # Swap
+    print("Swapping first_prompt_index = {} with last_decode_index = {}".
+          format(first_prompt_index, last_decode_index))
     swap_positions(b, first_prompt_index, last_decode_index)
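
The hunk above instruments ensure_decodes_first, which reorders an InputBatch so that all decode requests precede any prompt requests, swapping each misplaced prompt with the last decode found behind it. A minimal sketch of that reordering idea, using a toy list of dicts in place of vLLM's InputBatch and swap_positions (both carry far more per-request state in the real code):

def ensure_decodes_first_sketch(requests: list) -> None:
    # Two-pointer pass: decodes collect at the front, prompts at the back.
    left, right = 0, len(requests) - 1
    while left < right:
        # Advance past decodes already at the front.
        while left < right and not requests[left]["is_prompt"]:
            left += 1
        # Retreat past prompts already at the back.
        while left < right and requests[right]["is_prompt"]:
            right -= 1
        if left < right:
            # Swap so the decode moves ahead of the prompt.
            requests[left], requests[right] = requests[right], requests[left]

batch = [{"id": 0, "is_prompt": True}, {"id": 1, "is_prompt": False},
         {"id": 2, "is_prompt": False}]
ensure_decodes_first_sketch(batch)
assert [r["is_prompt"] for r in batch] == [False, False, True]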


@@ -218,10 +218,9 @@ class TPUModelRunner(ModelRunnerBase):
         block_table = block_table_cpu.reshape(1, -1).to(
             self.device) if block_table_cpu is not None else None
-        context_lens = self.prompt_context_lens_cpu.reshape(1,
-                                                            -1).to(self.device)
-        effective_query_lens = self.prompt_effective_query_lens_cpu.reshape(
-            1, -1).to(self.device)
+        context_lens = self.prompt_context_lens_cpu.to(self.device)
+        effective_query_lens = self.prompt_effective_query_lens_cpu.to(
+            self.device)

         # Attn metadata
         attn_metadata = PallasMetadata(
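
The change above stops reshaping the prompt metadata to [1, N] and instead moves context_lens and effective_query_lens to the device in their native 1-D [N] layout. A quick shape check of the difference (names mirror the diff; the values are invented):

import torch

prompt_context_lens_cpu = torch.tensor([37, 5, 128])  # hypothetical lengths
old_style = prompt_context_lens_cpu.reshape(1, -1)    # shape [1, 3]
new_style = prompt_context_lens_cpu                   # shape [3]
assert old_style.shape == (1, 3)
assert new_style.shape == (3,)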
@@ -247,6 +246,15 @@ class TPUModelRunner(ModelRunnerBase):
         padded_batch_size = _get_padded_batch_size(batch_size)
         assert padded_batch_size <= self.max_model_len

+        # Init [0 .. batch_size - 1]
+        req_indices_np = self.arange_np[:padded_batch_size]
+
+        print("_prepare_decode:")
+        print(" batch_size = {}".format(batch_size))
+        print(" padded_batch_size = {}".format(padded_batch_size))
+        print(" req_indices_np.shape = {} val = {}".format(
+            req_indices_np.shape, req_indices_np))
+
         # Input positions
         input_positions_np = self.input_positions_np[:padded_batch_size]
         np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
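
The hunk above introduces req_indices_np, a [0 .. padded_batch_size - 1] index vector sliced from a preallocated arange, plus debug prints; batch_size is first rounded up by _get_padded_batch_size so tensor shapes stay stable for the compiled TPU graphs. A hedged stand-in for that step (the next-power-of-two rule below is an assumption, not necessarily vLLM's actual padding policy):

import numpy as np

def get_padded_batch_size(batch_size: int) -> int:
    # Assumed padding rule: round up to the next power of two.
    return 1 << max(batch_size - 1, 0).bit_length()

arange_np = np.arange(64)  # preallocated once, like self.arange_np
batch_size = 5
padded_batch_size = get_padded_batch_size(batch_size)  # 8
req_indices_np = arange_np[:padded_batch_size]
assert padded_batch_size == 8
assert req_indices_np.tolist() == [0, 1, 2, 3, 4, 5, 6, 7]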
@@ -255,29 +263,61 @@ class TPUModelRunner(ModelRunnerBase):
         input_positions_np[batch_size:] = 0
         input_positions_cpu = self.input_positions_cpu[:padded_batch_size]
+        print(" input_positions_cpu.shape = {} data = {}".format(
+            input_positions_cpu.shape, input_positions_cpu))

         # Input tokens
+        token_indices_np = (
+            input_positions_np +
+            req_indices_np * self.input_batch.token_ids_cpu.shape[1])
         input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
-        torch.index_select(self.input_batch.token_ids_cpu_tensor,
-                           1,
-                           input_positions_cpu,
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices_np),
                            out=input_tokens_cpu)
-        input_tokens_cpu[:batch_size] = 0
+        input_tokens_cpu[batch_size:] = 0
+        print(" token_indices_np.shape = {} val = {}".format(
+            token_indices_np.shape, token_indices_np))
+        print(" input_tokens_cpu.shape = {} data = {}".format(
+            input_tokens_cpu.shape, input_tokens_cpu))

         # Slot mapping
+        block_table_indices_np = (
+            req_indices_np * self.max_num_blocks_per_req +
+            input_positions_np // self.block_size)
+        print(
+            " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
+            .format(block_table_indices_np.shape, block_table_indices_np,
+                    self.max_num_blocks_per_req))
         block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        block_numbers_cpu = torch.index_select(
-            block_table_cpu, 1, input_positions_cpu // self.block_size)
-        block_numbers_np = block_numbers_cpu.numpy()
+        print(" block_table_cpu.shape = {} data = {}".format(
+            block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
+        block_numbers_np = block_table_cpu.flatten(
+        )[block_table_indices_np].numpy()
+        print(" block_numbers_np.shape = {} data = {}".format(
+            block_numbers_np.shape, block_numbers_np))
         block_offsets_np = input_positions_np % self.block_size
+        print(" block_offsets_np.shape = {} data = {}".format(
+            block_offsets_np.shape, block_offsets_np))
         slot_mapping_np = self.slot_mapping_np[:padded_batch_size]
         np.add(block_numbers_np * self.block_size,
                block_offsets_np,
                out=slot_mapping_np)
-        slot_mapping_np[:, batch_size:] = _PAD_SLOT_ID
-        block_table_cpu = block_table_cpu[:len(decode_req_ids)]
+        slot_mapping_np[batch_size:] = _PAD_SLOT_ID
+        print(" slot_mapping_np.shape = {} data = {}".format(
+            slot_mapping_np.shape, slot_mapping_np))
+        block_table_cpu = block_table_cpu[:padded_batch_size]

         # Context lens
         context_lens_np = self.decode_context_lens_np[:padded_batch_size]
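
The core of the hunk above replaces per-row torch.index_select calls with gathers over flattened tensors: token_indices_np = position + req_index * row_stride picks each request's next input token out of the flattened [num_reqs, max_len] token table, and block_table_indices_np indexes the flattened block table to build the slot mapping (slot = block_number * block_size + position % block_size). A self-contained toy reproduction of both gathers (sizes and values are invented; variable names follow the diff):

import numpy as np
import torch

block_size = 4
max_num_blocks_per_req = 3
token_ids_cpu_tensor = torch.arange(2 * 8).reshape(2, 8)  # [num_reqs, max_len]
block_table_cpu = torch.tensor([[7, 2, 5], [1, 9, 0]])    # [num_reqs, blocks]

req_indices_np = np.array([0, 1])
input_positions_np = np.array([5, 2])  # next token position per request

# Gather each request's input token from the flattened token table.
token_indices_np = (input_positions_np +
                    req_indices_np * token_ids_cpu_tensor.shape[1])
input_tokens = token_ids_cpu_tensor.flatten()[torch.from_numpy(token_indices_np)]

# Map each position to a physical KV-cache slot via the flattened block table.
block_table_indices_np = (req_indices_np * max_num_blocks_per_req +
                          input_positions_np // block_size)
block_numbers_np = block_table_cpu.flatten()[
    torch.from_numpy(block_table_indices_np)].numpy()
slot_mapping_np = (block_numbers_np * block_size +
                   input_positions_np % block_size)

assert input_tokens.tolist() == [5, 10]
assert slot_mapping_np.tolist() == [9, 6]  # 2*4 + 1 and 1*4 + 2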
@@ -287,14 +327,17 @@ class TPUModelRunner(ModelRunnerBase):
         context_lens_np[batch_size:] = 0

         # Get final tensors
-        input_tokens = input_tokens_cpu.to(self.device)
-        input_positions = input_positions_cpu.to(self.device)
-        slot_mapping = self.slot_mapping_cpu[:padded_batch_size].to(
-            self.device)
+        input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
+        input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
+        slot_mapping = self.slot_mapping_cpu[:padded_batch_size].reshape(
+            -1, 1).to(self.device)
         block_table = block_table_cpu.to(self.device)
         context_lens = self.decode_context_lens_cpu[:padded_batch_size].to(
             self.device)
+        print(" context_lens.shape = {} val = {}".format(
+            context_lens.shape, context_lens))

         # Attn metadata
         attn_metadata = PallasMetadata(
             num_prefills=0,
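
The final tensors are now reshaped to [padded_batch_size, 1] before the device copy: a decode step feeds exactly one token per sequence, and the trailing dimension makes that explicit for the Pallas attention kernel. A toy check of the shape change (CPU stands in for the TPU device):

import torch

padded_batch_size = 8
input_tokens_cpu = torch.zeros(padded_batch_size, dtype=torch.int32)
input_tokens = input_tokens_cpu.reshape(-1, 1)  # one row per sequence
assert input_tokens.shape == (8, 1)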