Fix device handling in nn.utils.rnn.unpad_sequence (#98042)

Without this change, calling `unpad_sequence` with CUDA inputs fails with the following error.
```
line 444, in unpad_sequence
    mask = idx < length
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98042
Approved by: https://github.com/mikaylagawarecki
Author: Anton Bushuiev
Date: 2023-03-31 16:00:46 +00:00
Committed by: PyTorch MergeBot
Parent: 1c21cd2213
Commit: fa1a8b9f96


```diff
@@ -440,7 +440,7 @@ def unpad_sequence(
         padded_sequences.transpose_(0, 1)
     max_length = padded_sequences.shape[1]
-    idx = torch.arange(max_length)
+    idx = torch.arange(max_length, device=lengths.device)
     for seq, length in zip(padded_sequences, lengths):
         mask = idx < length
```
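
As a quick sanity check of the code path touched here, the pad/unpad round trip can be sketched as below. This is a CPU-only sketch, so it exercises the `idx < length` masking logic but does not reproduce the cross-device error itself, which requires `lengths` (or the padded tensor) to live on a CUDA device.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

# Two variable-length sequences, padded into one batch tensor.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = pad_sequence(seqs, batch_first=True)   # shape (2, 3), zero-padded
lengths = torch.tensor([3, 2])

# unpad_sequence builds idx = torch.arange(max_length) internally and
# masks each row with idx < length; with the fix, idx is created on
# lengths.device so the comparison works for CUDA inputs too.
out = unpad_sequence(padded, lengths, batch_first=True)

# The original sequences are recovered exactly.
assert all(torch.equal(a, b) for a, b in zip(out, seqs))
```

On a CUDA machine, moving `padded` and `lengths` to `cuda:0` before the call is what previously triggered the `Expected all tensors to be on the same device` error, since `idx` was always created on the CPU.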