Remove cuda hard-code in compute_causal_conv1d_metadata (#25555)

Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
Icey
2025-09-26 16:19:20 +08:00
committed by GitHub
parent 99b3a504c5
commit dd70437a4f

View File

@ -947,6 +947,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
nums_dict = {} # type: ignore
batch_ptr = None
token_chunk_offset_ptr = None
device = query_start_loc_p.device
for BLOCK_M in [8]: # cover all BLOCK_M values
nums = -(-seqlens // BLOCK_M)
nums_dict[BLOCK_M] = {}
@ -968,11 +969,11 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
device=device)
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device='cuda')
device=device)
else:
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)