mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix device handling in nn.utils.rnn.unpad_sequence
(#98042)
Without this change I get the following error. ``` line 444, in unpad_sequence mask = idx < length RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/98042 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c21cd2213
commit
fa1a8b9f96
@ -440,7 +440,7 @@ def unpad_sequence(
|
||||
padded_sequences.transpose_(0, 1)
|
||||
|
||||
max_length = padded_sequences.shape[1]
|
||||
idx = torch.arange(max_length)
|
||||
idx = torch.arange(max_length, device=lengths.device)
|
||||
|
||||
for seq, length in zip(padded_sequences, lengths):
|
||||
mask = idx < length
|
||||
|
Reference in New Issue
Block a user