Fix tensor device and dtype placement in Qwen2VL model (#26219)

Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Yuanfeng Li <yuanfengli@meta.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
yuafng
2025-10-04 06:41:39 -07:00
committed by GitHub
parent 4570535ec4
commit 86ee949128

View File

@@ -720,7 +720,7 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# compute cu_seqlens
-        grid_thw_ = torch.tensor(grid_thw)
+        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
grid_thw_[:, 0]).cumsum(
dim=0, dtype=torch.int32)