pytorch/test/nn/test_packed_sequence.py
Matthew Hoffman 258f47fc0b Add padding_side to pad_sequence with "left" and "right" options ("right" as default) (#131884)
Fixes #10536

Reattempt of #61467. Thank you so much to @mskoh52 for your excellent work!

As I was trying to create a more efficient LLM data collator, I realized that `pad_sequence` only supports right padding, even though left padding is a very common format for LLMs, like Llama and Mistral.

The proposed alternative implementation was to use multiple flips, which tends to be 1.5x-2x slower. Instead, we can add a [`padding_side` parameter as there is for Hugging Face tokenizers](9d6c0641c4/src/transformers/tokenization_utils_base.py (L1565)), which requires only a very small change in the C++ code.
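
For illustration, here is a minimal sketch of the resulting call. The sample tensors are only illustrative; the expected outputs mirror the single-dimensional cases exercised in `test_pad_sequence` in the file shown further down.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([4, 5]), torch.tensor([1, 2, 3]), torch.tensor([6])]

# Default behavior is unchanged: pad on the right.
pad_sequence(seqs, batch_first=True)
# tensor([[4, 5, 0],
#         [1, 2, 3],
#         [6, 0, 0]])

# New option: pad on the left, as LLM data collators typically do.
pad_sequence(seqs, batch_first=True, padding_side="left")
# tensor([[0, 4, 5],
#         [1, 2, 3],
#         [0, 0, 6]])
```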

Here are the benchmarks of the new implementation!

`float32`:

![eaaa95ef-9384-45d2-be56-6898bc1d3514](https://github.com/user-attachments/assets/3b0eb309-e5a0-4a4d-97bb-4e3298783dbb)

`bool`:

![892f32da-8d9a-492b-9507-18d3f0a41e8e](https://github.com/user-attachments/assets/6824ea15-7d4e-4b89-95f0-8546635f0c2e)

Code:

```python
from __future__ import annotations

import random
import time
from typing import Literal

import numpy as np
import torch

def pad_sequence_with_flips(
    sequences: list[torch.Tensor],
    batch_first: bool = False,
    padding_value: int | float | bool = 0.0,
    padding_side: Literal["left", "right"] | str = "left",
) -> torch.Tensor:
    if padding_side == 'right':
        padded_sequence = torch._C._nn.pad_sequence([t.flatten() for t in sequences], batch_first=batch_first, padding_value=padding_value)
    elif padding_side == 'left':
        padded_sequence = torch._C._nn.pad_sequence([t.flatten().flip(0) for t in sequences], batch_first=batch_first, padding_value=padding_value)  # pyright: ignore[reportArgumentType]
        padded_sequence = padded_sequence.flip(int(batch_first))
    else:
        raise ValueError(f"padding_side should be either 'right' or 'left', but got {padding_side}")

    return padded_sequence

sequence_lengths: list[int] = []

flip_left_pad_times: list[float] = []
flip_left_pad_times_std: list[float] = []

left_pad_times: list[float] = []
left_pad_times_std: list[float] = []

RUNS_PER_LOOP: int = 100

for i in range(1, 7):
    sequence_length = i * int(1e6) // 6
    sequence_lengths.append(sequence_length)

    sequences = [torch.randint(0, 2, (random.randint(1, sequence_length),), dtype=torch.bool) for _ in range(64)]

    inner_left_pad_times: list[float] = []
    inner_right_pad_times: list[float] = []

    inner_flip_left_pad_times: list[float] = []
    inner_flip_right_pad_times: list[float] = []

    for _ in range(RUNS_PER_LOOP):

        start = time.perf_counter()
        torch._C._nn.pad_sequence(sequences, batch_first=True, padding_value=False, padding_side="left")
        end = time.perf_counter()
        inner_left_pad_times.append(end - start)

        start = time.perf_counter()
        pad_sequence_with_flips(sequences, batch_first=True, padding_value=False, padding_side="left")
        end = time.perf_counter()
        inner_flip_left_pad_times.append(end - start)

    left_pad_times.append(sum(inner_left_pad_times) / len(inner_left_pad_times))
    left_pad_times_std.append(np.std(inner_left_pad_times))

    flip_left_pad_times.append(sum(inner_flip_left_pad_times) / len(inner_flip_left_pad_times))
    flip_left_pad_times_std.append(np.std(inner_flip_left_pad_times))

    print(f"Sequence Length: {sequence_length}, Left Pad Time: {left_pad_times[-1]}, Left with Flips Pad Time: {flip_left_pad_times[-1]}")

import matplotlib.pyplot as plt

plt.plot(sequence_lengths, left_pad_times, label="new pad_sequence left")
plt.scatter(sequence_lengths, left_pad_times)
plt.errorbar(sequence_lengths, left_pad_times, yerr=left_pad_times_std, linestyle='None', marker='^')

plt.plot(sequence_lengths, flip_left_pad_times, label="old pad_sequence left (2 flips)")
plt.scatter(sequence_lengths, flip_left_pad_times)
plt.errorbar(sequence_lengths, flip_left_pad_times, yerr=flip_left_pad_times_std, linestyle='None', marker='^')

plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.legend(loc="upper right")

# Sequence Length: 166666, Left Pad Time: 0.06147645162009212, Left with Flips Pad Time: 0.09842291727001794
# Sequence Length: 333333, Left Pad Time: 0.08933195920990329, Left with Flips Pad Time: 0.15597836187991562
# Sequence Length: 500000, Left Pad Time: 0.08863158334006585, Left with Flips Pad Time: 0.15224887342999863
# Sequence Length: 666666, Left Pad Time: 0.10524682551997103, Left with Flips Pad Time: 0.18177212480995877
# Sequence Length: 833333, Left Pad Time: 0.11801802741003485, Left with Flips Pad Time: 0.20821274195001024
# Sequence Length: 1000000, Left Pad Time: 0.131894061660023, Left with Flips Pad Time: 0.23223503091008751
```

Co-authored-by: mskoh52 <mskoh52@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131884
Approved by: https://github.com/ezyang
2024-08-07 15:53:07 +00:00

# Owner(s): ["module: nn"]
import itertools
import random
from typing import List

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.testing._internal.common_utils import run_tests, TestCase

class PackedSequenceTest(TestCase):
    _type_by_name = {
        "torch.DoubleTensor": (torch.DoubleTensor, "double"),
        "torch.FloatTensor": (torch.FloatTensor, "float"),
        # We leave out `'torch.HalfTensor': (torch.HalfTensor, 'half'),`
        # because of an error in `pad_packed_sequence`
        # > AttributeError: 'torch.HalfTensor' object has no attribute 'fill_'
        "torch.LongTensor": (torch.LongTensor, "long"),
        "torch.IntTensor": (torch.IntTensor, "int"),
        "torch.ShortTensor": (torch.ShortTensor, "short"),
        "torch.CharTensor": (torch.CharTensor, "char"),
        "torch.ByteTensor": (torch.ByteTensor, "byte"),
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = 5
        self.max_length = 6

    def _ordered_sequence(self, tensor_type):
        """Create ordered list of random sequences"""
        seqs = [
            tensor_type(random.randint(1, self.max_length))
            for _ in range(self.batch_size)
        ]
        if tensor_type == torch.ByteTensor:
            seqs = [s.random_(0, 256) for s in seqs]
        else:
            seqs = [s.random_(-128, 128) for s in seqs]
        ordered = sorted(seqs, key=len, reverse=True)
        return ordered

    def _padded_sequence(self, tensor_type):
        """Create Tensor of random padded sequences"""
        ordered = self._ordered_sequence(tensor_type)
        lengths = [len(i) for i in ordered]
        padded_tensor = rnn_utils.pad_sequence(ordered)
        return padded_tensor, lengths

    def test_type_casts(self):
        """Test type casting of `PackedSequence` against type casting of tensor"""
        for input_type, _ in self._type_by_name.values():
            for expected_type_str, (_, cast_str) in self._type_by_name.items():
                for enforce_sorted in [True, False]:
                    padded, lengths = self._padded_sequence(input_type)
                    packed = rnn_utils.pack_padded_sequence(
                        padded, lengths, enforce_sorted=enforce_sorted
                    )
                    # Apply cast to `PackedSequence` instance and unpack
                    masked = getattr(packed, cast_str)()
                    unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
                    self.assertEqual(unpacked.type(), expected_type_str)

    def test_wrong_order(self):
        a = torch.ones(25, 300)
        b = torch.ones(22, 300)
        b_a = rnn_utils.pad_sequence([b, a])
        self.assertRaises(
            RuntimeError,
            lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True),
        )

    def test_pad_sequence_with_tensor_sequences(self):
        seq_tuple_input = torch.nn.utils.rnn.pad_sequence(
            (torch.tensor([[7, 6]]), torch.tensor([[-7, -1]]))
        )
        seq_tensor_input = torch.nn.utils.rnn.pad_sequence(
            torch.tensor([[[7, 6]], [[-7, -1]]])
        )
        self.assertEqual(seq_tuple_input, seq_tensor_input)
        self.assertEqual(seq_tuple_input.shape, torch.Size([1, 2, 2]))

    def test_pad_sequence_with_non_iterable_sequences(self):
        msg = r"Expected iterable for input sequences, but got arg of type"
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.utils.rnn.pad_sequence(5)

    def test_total_length(self):
        padded, lengths = self._padded_sequence(torch.FloatTensor)
        max_length = max(lengths)
        packed = rnn_utils.pack_padded_sequence(padded, lengths)
        # test ValueError if total_length < max_length
        for total_length in (-1, 0, max_length - 1):
            for batch_first in (True, False):

                def err_fn():
                    rnn_utils.pad_packed_sequence(
                        packed, batch_first=batch_first, total_length=total_length
                    )

                self.assertRaisesRegex(
                    ValueError,
                    r"Expected total_length to be at least the "
                    r"length of the longest sequence in input",
                    err_fn,
                )
        # test that pad_packed_sequence returns results of correct length
        for batch_first in (True, False):
            no_extra_pad, _ = rnn_utils.pad_packed_sequence(
                packed, batch_first=batch_first
            )
            for total_length_delta in (0, 1, 8):
                total_length = max_length + total_length_delta
                unpacked, lengths_out = rnn_utils.pad_packed_sequence(
                    packed, batch_first=batch_first, total_length=total_length
                )
                self.assertEqual(lengths, lengths_out)
                self.assertEqual(unpacked.size(1 if batch_first else 0), total_length)
                if total_length_delta == 0:
                    ref_output = no_extra_pad
                elif batch_first:
                    extra_pad = no_extra_pad.new_zeros(
                        self.batch_size, total_length_delta
                    )
                    ref_output = torch.cat([no_extra_pad, extra_pad], 1)
                else:
                    extra_pad = no_extra_pad.new_zeros(
                        total_length_delta, self.batch_size
                    )
                    ref_output = torch.cat([no_extra_pad, extra_pad], 0)
                self.assertEqual(unpacked, ref_output)

    def test_to(self):
        for enforce_sorted in (True, False):
            padded, lengths = self._padded_sequence(torch.IntTensor)
            a = rnn_utils.pack_padded_sequence(
                padded, lengths, enforce_sorted=enforce_sorted
            ).cpu()

            self.assertIs(a, a.to("cpu"))
            self.assertIs(a, a.cpu())
            self.assertIs(a, a.to("cpu", dtype=torch.int32))
            self.assertEqual(a.long(), a.to(torch.int64))

            if torch.cuda.is_available():
                for cuda in [
                    "cuda",
                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
                ]:
                    b = a.cuda(device=cuda)
                    self.assertIs(b, b.to(cuda))
                    self.assertIs(b, b.cuda())
                    self.assertEqual(a, b.to("cpu"))
                    self.assertEqual(b, a.to(cuda))
                    self.assertEqual(a, b.to("cpu", dtype=torch.int32))
                    self.assertIs(b, b.to(dtype=torch.int32))
                    self.assertEqual(b.long(), b.to(dtype=torch.int64))

    def test_to_memory_format(self):
        m = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, bias=True)
        m = m.to(memory_format=torch.channels_last)
        for param in m.parameters():
            if param.dim() == 4:
                self.assertTrue(param.is_contiguous(memory_format=torch.channels_last))

    def test_pad_sequence(self):
        def pad(tensor, length):
            return torch.cat(
                [
                    tensor.data,
                    tensor.data.new(
                        length - tensor.size(0), *tensor.size()[1:]
                    ).zero_(),
                ]
            )

        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])

        # batch_first = true
        expected = torch.tensor([[4, 5, 0], [1, 2, 3], [6, 0, 0]])
        padded = rnn_utils.pad_sequence([b, a, c], True)
        self.assertEqual(padded, expected)

        # batch_first = false
        padded = rnn_utils.pad_sequence([b, a, c])
        self.assertEqual(padded, expected.transpose(0, 1))

        # padding_side = "left", batch_first=True
        expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]])
        padded = rnn_utils.pad_sequence(
            [b, a, c],
            batch_first=True,
            padding_side="left",
        )
        self.assertEqual(padded, expected)

        # padding_side = "left", batch_first=False
        padded = rnn_utils.pad_sequence(
            [b, a, c],
            batch_first=False,
            padding_side="left",
        )
        self.assertEqual(padded, expected.transpose(0, 1))

        # pad with non-zero value
        expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
        padded = rnn_utils.pad_sequence([b, a, c], True, 1)
        self.assertEqual(padded, expected)

        # Test pad sorted sequence
        expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
        padded = rnn_utils.pad_sequence([a, b, c], True)
        self.assertEqual(padded, expected)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences: List[torch.Tensor] = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)

            # batch first = true
            expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences])
            padded = rnn_utils.pad_sequence(sequences, True)
            self.assertEqual(padded, expected)

            # batch first = false
            padded = rnn_utils.pad_sequence(sequences)
            self.assertEqual(padded, expected.transpose(0, 1))

            # padding_side = "left", batch_first=True
            expected = torch.stack(
                [pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences]
            )
            padded = rnn_utils.pad_sequence(
                sequences,
                batch_first=True,
                padding_side="left",
            )
            self.assertEqual(padded, expected)

            # padding_side = "left", batch_first=False
            padded = rnn_utils.pad_sequence(
                sequences,
                batch_first=False,
                padding_side="left",
            )
            self.assertEqual(padded, expected.transpose(0, 1))

    def test_unpad_sequence(self):
        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        sequences = [a, b, c]
        lengths = torch.as_tensor([v.size(0) for v in sequences])
        for batch_first in [True, False]:
            padded_sequences = rnn_utils.pad_sequence(
                sequences, batch_first=batch_first
            )
            unpadded_sequences = rnn_utils.unpad_sequence(
                padded_sequences, lengths, batch_first=batch_first
            )
            self.assertEqual(sequences, unpadded_sequences)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)
            lengths = torch.as_tensor([v.size(0) for v in sequences])
            padded_sequences = rnn_utils.pad_sequence(
                sequences, batch_first=batch_first
            )
            unpadded_sequences = rnn_utils.unpad_sequence(
                padded_sequences, lengths, batch_first=batch_first
            )
            self.assertEqual(sequences, unpadded_sequences)

    def test_pack_sequence(self):
        def _compatibility_test(sequences, lengths, batch_first, enforce_sorted=False):
            padded = rnn_utils.pad_sequence(sequences, batch_first)
            packed = rnn_utils.pack_sequence(sequences, enforce_sorted)
            unpacked = rnn_utils.pad_packed_sequence(packed, batch_first)
            self.assertEqual(padded, unpacked[0])
            pack_padded = rnn_utils.pack_padded_sequence(
                padded, lengths, batch_first, enforce_sorted
            )
            self.assertEqual(packed, pack_padded)

        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)
        expected = torch.tensor([1, 4, 6, 2, 5, 3])
        self.assertEqual(packed.batch_sizes, [3, 2, 1])
        self.assertEqual(packed.data.data, expected)
        self.assertEqual(packed.sorted_indices, [0, 1, 2])
        self.assertEqual(packed.unsorted_indices, [0, 1, 2])

        packed_unsorted = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
        self.assertEqual(packed_unsorted.batch_sizes, [3, 2, 1])
        self.assertEqual(packed_unsorted.data.data, expected)
        self.assertEqual(packed_unsorted.sorted_indices, [2, 0, 1])
        self.assertEqual(packed_unsorted.unsorted_indices, [1, 2, 0])

        # single dimensional, enforce_sorted = True
        packed_enforce_sorted = rnn_utils.pack_sequence([a, b, c], enforce_sorted=True)
        self.assertEqual(packed_enforce_sorted.batch_sizes, [3, 2, 1])
        self.assertEqual(packed_enforce_sorted.data.data, expected)
        self.assertTrue(packed_enforce_sorted.sorted_indices is None)
        self.assertTrue(packed_enforce_sorted.unsorted_indices is None)

        with self.assertRaisesRegex(RuntimeError, "must be sorted in decreasing order"):
            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)

        with self.assertRaisesRegex(
            RuntimeError, "You can pass `enforce_sorted=False`"
        ):
            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            lengths = []
            trailing_dims = [4] * num_dim
            for i in range(maxlen, 0, -1):
                seq_len = i * i
                lengths.append(seq_len)
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            unsorted_sequences = [s.clone() for s in sequences]
            random.shuffle(unsorted_sequences)
            unsorted_sequences_lengths = [t.size(0) for t in unsorted_sequences]

            # compatibility with other utilities
            for batch_first in (True, False):
                for enforce_sorted in (True, False):
                    _compatibility_test(sequences, lengths, batch_first, enforce_sorted)
                _compatibility_test(
                    unsorted_sequences, unsorted_sequences_lengths, batch_first
                )

    def test_unpack_sequence(self):
        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        sequences = [a, b, c]

        packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
        self.assertEqual(sequences, unpacked_sequences)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)
            packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
            unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
            self.assertEqual(sequences, unpacked_sequences)

    def test_pack_padded_sequence(self):
        def generate_test_case(sorted_lengths, should_shuffle):
            def pad(tensor, length):
                return torch.cat(
                    [
                        tensor,
                        tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_(),
                    ]
                )

            max_length = sorted_lengths[0]
            batch_sizes = [
                sum(map(bool, filter(lambda x: x >= i, sorted_lengths)))
                for i in range(1, max_length + 1)
            ]
            offset = 0
            padded = torch.cat(
                [
                    pad(
                        i * 100 + torch.arange(1.0, 5 * l + 1).view(l, 1, 5), max_length
                    )
                    for i, l in enumerate(sorted_lengths, 1)
                ],
                1,
            )
            expected_data = [
                [
                    torch.arange(1.0, 6) + (i + 1) * 100 + 5 * n
                    for i in range(batch_size)
                ]
                for n, batch_size in enumerate(batch_sizes)
            ]
            expected_data = list(itertools.chain.from_iterable(expected_data))
            expected_data = torch.stack(expected_data, dim=0)

            if should_shuffle:
                # Shuffle the padded sequence to create an unsorted sequence
                permutation = list(range(len(sorted_lengths)))
                random.shuffle(permutation)

                unsorted_indices = torch.tensor(permutation)
                padded = padded.index_select(1, unsorted_indices)
                lengths = torch.tensor(sorted_lengths).index_select(0, unsorted_indices)
            else:
                unsorted_indices = None
                lengths = sorted_lengths

            return (
                padded.requires_grad_(),
                lengths,
                expected_data,
                batch_sizes,
                unsorted_indices,
            )

        test_cases = [
            # sorted_lengths, should_shuffle
            [[10, 8, 4, 2, 2, 2, 1], False],
            [[11, 10, 8, 6, 4, 3, 1], False],
            [[11, 10, 8, 6, 4, 3, 1], True],
        ]

        for test_case, batch_first in itertools.product(test_cases, (True, False)):
            sorted_lengths, should_shuffle = test_case
            (
                padded,
                lengths,
                expected_data,
                batch_sizes,
                unsorted_indices,
            ) = generate_test_case(sorted_lengths, should_shuffle)

            src = padded
            if batch_first:
                src = src.transpose(0, 1)

            # check output
            packed = rnn_utils.pack_padded_sequence(
                src, lengths, batch_first=batch_first, enforce_sorted=not should_shuffle
            )
            self.assertEqual(packed.data.data, expected_data)
            self.assertEqual(packed.batch_sizes, batch_sizes)
            self.assertEqual(packed.unsorted_indices, unsorted_indices)

            # test inverse
            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(
                packed, batch_first=batch_first
            )
            self.assertEqual(unpacked, src)
            self.assertEqual(unpacked_len, lengths)

            # check grad
            if padded.grad is not None:
                padded.grad.data.zero_()
            grad_output = unpacked.data.clone().normal_()
            unpacked.backward(grad_output)
            if batch_first:
                grad_output.transpose_(0, 1)
            for i, l in enumerate(lengths):
                self.assertEqual(padded.grad.data[:l, i], grad_output[:l, i])
                if l < 10:
                    self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)

        # test error messages
        with self.assertRaisesRegex(
            RuntimeError, "You can pass `enforce_sorted=False`"
        ):
            packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
            packed = rnn_utils.pack_padded_sequence(torch.randn(0, 0), [])
        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
            packed = rnn_utils.pack_padded_sequence(
                torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
            )

if __name__ == "__main__":
    run_tests()