Remove cuda hard-code in compute_causal_conv1d_metadata (#25555)
Signed-off-by: Icey <1790571317@qq.com>
@@ -947,6 +947,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
     nums_dict = {}  # type: ignore
     batch_ptr = None
     token_chunk_offset_ptr = None
+    device = query_start_loc_p.device
     for BLOCK_M in [8]:  # cover all BLOCK_M values
         nums = -(-seqlens // BLOCK_M)
         nums_dict[BLOCK_M] = {}
@@ -968,11 +969,11 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
             batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                    PAD_SLOT_ID,
                                    dtype=torch.int32,
-                                   device='cuda')
+                                   device=device)
             token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                                 PAD_SLOT_ID,
                                                 dtype=torch.int32,
-                                                device='cuda')
+                                                device=device)
         else:
             if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
                 batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
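
For context, the whole change is to derive the allocation device from the input tensor instead of hard-coding 'cuda', so the metadata buffers follow the input onto whatever backend it lives on. Below is a minimal, self-contained sketch of that pattern; it is not the vLLM source, and the wrapper function name plus the concrete values of PAD_SLOT_ID and MAX_NUM_PROGRAMS are assumptions for illustration.

    import torch

    PAD_SLOT_ID = -1          # sentinel for unused slots (value assumed)
    MAX_NUM_PROGRAMS = 1024   # cap on launched programs (value assumed)

    def make_metadata_buffers(query_start_loc_p: torch.Tensor):
        # Take the device from the input tensor, so the same code path
        # serves CUDA, ROCm, XPU, or CPU builds without modification.
        device = query_start_loc_p.device
        batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                               PAD_SLOT_ID,
                               dtype=torch.int32,
                               device=device)
        token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
                                            PAD_SLOT_ID,
                                            dtype=torch.int32,
                                            device=device)
        return batch_ptr, token_chunk_offset_ptr

    # Usage: the buffers land wherever the input lives.
    qsl = torch.tensor([0, 4, 9], dtype=torch.int32)  # a CPU tensor here
    batch_ptr, token_chunk_offset_ptr = make_metadata_buffers(qsl)
    assert batch_ptr.device == qsl.device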