Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
Fix tensor device and dtype placement in Qwen2VL model (#26219)
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Yuanfeng Li <yuanfengli@meta.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -720,7 +720,7 @@ class Qwen2VisionTransformer(nn.Module):
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
         # compute cu_seqlens
-        grid_thw_ = torch.tensor(grid_thw)
+        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
         cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                              grid_thw_[:, 0]).cumsum(
                                                  dim=0, dtype=torch.int32)
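To make the effect of the change concrete, here is a minimal sketch of the patched cu_seqlens computation as a standalone function. The helper name compute_cu_seqlens is hypothetical (not part of the vLLM code); the body simply mirrors the diff above, building grid_thw on the target device with an explicit integer dtype so the repeat_interleave/cumsum that follow do not hit an implicit device or dtype mismatch.

import torch

def compute_cu_seqlens(grid_thw: list[list[int]],
                       device: torch.device) -> torch.Tensor:
    # Hypothetical standalone helper mirroring the patched lines:
    # place the (t, h, w) grid tensor on the same device as the vision
    # features and force an integer dtype up front.
    grid_thw_ = torch.tensor(grid_thw, device=device, dtype=torch.long)
    # Tokens per frame = h * w, repeated over the temporal dimension t,
    # then accumulated into cumulative sequence lengths.
    cu_seqlens = torch.repeat_interleave(
        grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]
    ).cumsum(dim=0, dtype=torch.int32)
    return cu_seqlens

# Example: two images with (t, h, w) grids of (1, 4, 6) and (1, 2, 2)
print(compute_cu_seqlens([[1, 4, 6], [1, 2, 2]], torch.device("cpu")))
# tensor([24, 28], dtype=torch.int32)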