Add padding_side to pad_sequence with "left" and "right" options ("right" as default) (#131884)

Fixes #10536

Reattempt of #61467. Thank you so much to @mskoh52 for your excellent work!

As I was trying to create a more efficient LLM data collator, I realized that `pad_sequence` only supports right padding, even though left padding is a very common format for LLMs, like Llama and Mistral.

The alternative implementation previously proposed was to use multiple flips, which tends to be 1.5x-2x slower. Instead we can add a [`padding_side` parameter as there is for Hugging Face tokenizers](9d6c0641c4/src/transformers/tokenization_utils_base.py (L1565)), which requires only a very small change in the C++ code.

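With the new parameter, left padding becomes a one-liner. A quick usage example (the expected output matches the new unit tests below):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([4, 5]), torch.tensor([1, 2, 3]), torch.tensor([6])]
print(pad_sequence(seqs, batch_first=True, padding_side="left"))
# tensor([[0, 4, 5],
#         [1, 2, 3],
#         [0, 0, 6]])
```
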
Here are the benchmarks of the new implementation!

`float32`:

![eaaa95ef-9384-45d2-be56-6898bc1d3514](https://github.com/user-attachments/assets/3b0eb309-e5a0-4a4d-97bb-4e3298783dbb)

`bool`:

![892f32da-8d9a-492b-9507-18d3f0a41e8e](https://github.com/user-attachments/assets/6824ea15-7d4e-4b89-95f0-8546635f0c2e)

Code:

```python
from __future__ import annotations

import random
import time
from typing import Literal

import numpy as np
import torch

def pad_sequence_with_flips(
    sequences: list[torch.Tensor],
    batch_first: bool = False,
    padding_value: int | float | bool = 0.0,
    padding_side: Literal["left", "right"] | str = "left",
) -> torch.Tensor:
    if padding_side == 'right':
        padded_sequence = torch._C._nn.pad_sequence([t.flatten() for t in sequences], batch_first=batch_first, padding_value=padding_value)
    elif padding_side == 'left':
        padded_sequence = torch._C._nn.pad_sequence([t.flatten().flip(0) for t in sequences], batch_first=batch_first, padding_value=padding_value)  # pyright: ignore[reportArgumentType]
        padded_sequence = padded_sequence.flip(int(batch_first))
    else:
        raise ValueError(f"padding_side should be either 'right' or 'left', but got {padding_side}")

    return padded_sequence

sequence_lengths: list[int] = []

flip_left_pad_times: list[float] = []
flip_left_pad_times_std: list[float] = []

left_pad_times: list[float] = []
left_pad_times_std: list[float] = []

RUNS_PER_LOOP: int = 100

for i in range(1, 7):
    sequence_length = i * int(1e6) // 6
    sequence_lengths.append(sequence_length)

    sequences = [torch.randint(0, 2, (random.randint(1, sequence_length),), dtype=torch.bool) for _ in range(64)]

    inner_left_pad_times: list[float] = []
    inner_flip_left_pad_times: list[float] = []

    for _ in range(RUNS_PER_LOOP):

        start = time.perf_counter()
        torch._C._nn.pad_sequence(sequences, batch_first=True, padding_value=False, padding_side="left")
        end = time.perf_counter()
        inner_left_pad_times.append(end - start)

        start = time.perf_counter()
        pad_sequence_with_flips(sequences, batch_first=True, padding_value=False, padding_side="left")
        end = time.perf_counter()
        inner_flip_left_pad_times.append(end - start)

    left_pad_times.append(sum(inner_left_pad_times) / len(inner_left_pad_times))
    left_pad_times_std.append(np.std(inner_left_pad_times))

    flip_left_pad_times.append(sum(inner_flip_left_pad_times) / len(inner_flip_left_pad_times))
    flip_left_pad_times_std.append(np.std(inner_flip_left_pad_times))

    print(f"Sequence Length: {sequence_length}, Left Pad Time: {left_pad_times[-1]}, Left with Flips Pad Time: {flip_left_pad_times[-1]}")

import matplotlib.pyplot as plt

plt.plot(sequence_lengths, left_pad_times, label="new pad_sequence left")
plt.scatter(sequence_lengths, left_pad_times)
plt.errorbar(sequence_lengths, left_pad_times, yerr=left_pad_times_std, linestyle='None', marker='^')

plt.plot(sequence_lengths, flip_left_pad_times, label="old pad_sequence left (2 flips)")
plt.scatter(sequence_lengths, flip_left_pad_times)
plt.errorbar(sequence_lengths, flip_left_pad_times, yerr=flip_left_pad_times_std, linestyle='None', marker='^')

plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.legend(loc="upper right")

# Sequence Length: 166666, Left Pad Time: 0.06147645162009212, Left with Flips Pad Time: 0.09842291727001794
# Sequence Length: 333333, Left Pad Time: 0.08933195920990329, Left with Flips Pad Time: 0.15597836187991562
# Sequence Length: 500000, Left Pad Time: 0.08863158334006585, Left with Flips Pad Time: 0.15224887342999863
# Sequence Length: 666666, Left Pad Time: 0.10524682551997103, Left with Flips Pad Time: 0.18177212480995877
# Sequence Length: 833333, Left Pad Time: 0.11801802741003485, Left with Flips Pad Time: 0.20821274195001024
# Sequence Length: 1000000, Left Pad Time: 0.131894061660023, Left with Flips Pad Time: 0.23223503091008751
```

Co-authored-by: mskoh52 <mskoh52@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131884
Approved by: https://github.com/ezyang
Author: Matthew Hoffman
Date: 2024-08-07 15:53:04 +00:00
Committed by: PyTorch MergeBot
commit 258f47fc0b (parent 780310fed7)

8 changed files with 92 additions and 16 deletions

`aten/src/ATen/native/PackedSequence.cpp`:

```diff
@@ -202,9 +202,11 @@ std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor
   return std::make_tuple(output, lengths_t);
 }
 
-Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value) {
+Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value, const c10::string_view padding_side) {
   const int64_t sequences_size = sequences.size();
   TORCH_CHECK(sequences_size > 0, "received an empty list of sequences");
+  TORCH_CHECK(padding_side == "left" || padding_side == "right",
+              "Expected padding_side to be one of left or right, but got ", padding_side, ".");
   IntArrayRef max_size = sequences[0].sizes();
   IntArrayRef trailing_dims = max_size.slice(1);
   int64_t max_len = std::max_element(
@@ -227,11 +229,12 @@ Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value
   for (const auto i : c10::irange(sequences_size)) {
     const Tensor& currseq = sequences[i];
     const int64_t length_i = currseq.size(0);
+    const int64_t start = padding_side == "left" ? max_len - length_i : 0;
     // use index notation to prevent duplicate references to the tensor
     if (batch_first) {
-      out.select(0, i).narrow(0, 0, length_i).copy_(currseq);
+      out.select(0, i).narrow(0, start, length_i).copy_(currseq);
     } else {
-      out.narrow(0, 0, length_i).select(1, i).copy_(currseq);
+      out.narrow(0, start, length_i).select(1, i).copy_(currseq);
     }
   }
   return out;
```

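For readers who don't want to parse the C++ diff, here is a minimal Python sketch of what the kernel now does; `pad_sequence_sketch` is illustrative only, not part of the PR:

```python
import torch

def pad_sequence_sketch(sequences, batch_first=False, padding_value=0.0, padding_side="right"):
    # Output is B x T x * if batch_first, else T x B x *, prefilled with padding_value.
    max_len = max(seq.size(0) for seq in sequences)
    trailing_dims = tuple(sequences[0].shape[1:])
    batch = len(sequences)
    out_shape = (batch, max_len, *trailing_dims) if batch_first else (max_len, batch, *trailing_dims)
    out = sequences[0].new_full(out_shape, padding_value)
    for i, seq in enumerate(sequences):
        length_i = seq.size(0)
        # The entire change for left padding: copy each sequence at an offset
        # instead of at 0, exactly like the new `start` variable in the C++.
        start = max_len - length_i if padding_side == "left" else 0
        if batch_first:
            out[i, start:start + length_i] = seq
        else:
            out[start:start + length_i, i] = seq
    return out
```
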
`aten/src/ATen/native/native_functions.yaml`:

```diff
@@ -14414,7 +14414,7 @@
     CPU, CUDA: _segment_reduce_backward_kernel
   autogen: _segment_reduce_backward.out
 
-- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor
+- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side="right") -> Tensor
   python_module: nn
   variants: function
```

`test/cpp/api/nn_utils.cpp`:

```diff
@@ -6,6 +6,7 @@
 #include <test/cpp/api/support.h>
 
 #include <algorithm>
 #include <iostream>
+#include <random>
 #include <sstream>
 #include <string>
@@ -830,6 +831,15 @@ TEST_F(NNUtilsTest, PadSequence) {
   padded = rnn_utils::pad_sequence({b, a, c});
   ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
 
+  // padding_side = "left", batch_first = true
+  expected = torch::tensor({{0, 4, 5}, {1, 2, 3}, {0, 0, 6}});
+  padded = rnn_utils::pad_sequence({b, a, c}, true, 0, "left");
+  ASSERT_TRUE(padded.allclose(expected));
+
+  // padding_side = "left", batch_first = false
+  padded = rnn_utils::pad_sequence({b, a, c}, false, 0, "left");
+  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
+
   // pad with non-zero value
   expected = torch::tensor({{4, 5, 1}, {1, 2, 3}, {6, 1, 1}});
   padded = rnn_utils::pad_sequence({b, a, c}, true, 1);
@@ -870,5 +880,21 @@
     // batch first = false
     padded = rnn_utils::pad_sequence(sequences);
     ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
+
+    // reset expected_tensors for padding_side
+    expected_tensors.clear();
+    for (const torch::Tensor& seq : sequences) {
+      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
+      expected_tensors.emplace_back(
+          torch::flip(pad(torch::flip(seq, {0}), maxlen * maxlen), {0}));
+    }
+    expected = torch::stack(expected_tensors);
+
+    // padding_side = "left", batch_first = true
+    padded = rnn_utils::pad_sequence(sequences, true, 0, "left");
+    ASSERT_TRUE(padded.allclose(expected));
+
+    // padding_side = "left", batch_first = false
+    padded = rnn_utils::pad_sequence(sequences, false, 0, "left");
+    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
   }
 }
```

`test/nn/test_packed_sequence.py`:

```diff
@@ -2,6 +2,7 @@
 import itertools
 import random
+from typing import List
 
 import torch
 import torch.nn.utils.rnn as rnn_utils
@@ -188,6 +189,23 @@ class PackedSequenceTest(TestCase):
         padded = rnn_utils.pad_sequence([b, a, c])
         self.assertEqual(padded, expected.transpose(0, 1))
 
+        # padding_side = "left", batch_first=True
+        expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]])
+        padded = rnn_utils.pad_sequence(
+            [b, a, c],
+            batch_first=True,
+            padding_side="left",
+        )
+        self.assertEqual(padded, expected)
+
+        # padding_side = "left", batch_first=False
+        padded = rnn_utils.pad_sequence(
+            [b, a, c],
+            batch_first=False,
+            padding_side="left",
+        )
+        self.assertEqual(padded, expected.transpose(0, 1))
+
         # pad with non-zero value
         expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
         padded = rnn_utils.pad_sequence([b, a, c], True, 1)
@@ -201,17 +219,14 @@
         # more dimensions
         maxlen = 9
         for num_dim in (0, 1, 2, 3):
-            sequences = []
+            sequences: List[torch.Tensor] = []
             trailing_dims = [4] * num_dim
             for i in range(1, maxlen + 1):
                 seq_len = i * i
                 sequences.append(torch.rand(seq_len, 5, *trailing_dims))
             random.shuffle(sequences)
-            expected = []
-            for seq in sequences:
-                expected.append(pad(seq, maxlen * maxlen))
             # batch first = true
-            expected = torch.stack(expected)
+            expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences])
             padded = rnn_utils.pad_sequence(sequences, True)
             self.assertEqual(padded, expected)
@@ -219,6 +234,25 @@
             padded = rnn_utils.pad_sequence(sequences)
             self.assertEqual(padded, expected.transpose(0, 1))
 
+            # padding_side = "left", batch_first=True
+            expected = torch.stack(
+                [pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences]
+            )
+            padded = rnn_utils.pad_sequence(
+                sequences,
+                batch_first=True,
+                padding_side="left",
+            )
+            self.assertEqual(padded, expected)
+
+            # padding_side = "left", batch_first=False
+            padded = rnn_utils.pad_sequence(
+                sequences,
+                batch_first=False,
+                padding_side="left",
+            )
+            self.assertEqual(padded, expected.transpose(0, 1))
+
     def test_unpad_sequence(self):
         # single dimensional
         a = torch.tensor([1, 2, 3])
```

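These expected values encode the identity the old workaround exploited: left padding is equivalent to flip, right-pad, flip. A standalone sanity check (variable names here are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.randn(n, 5) for n in (3, 1, 4)]
max_len = max(s.size(0) for s in seqs)
left = pad_sequence(seqs, batch_first=True, padding_side="left")
# Right-pad each flipped sequence along dim 0, then flip back.
via_flips = torch.stack(
    [F.pad(s.flip(0), (0, 0, 0, max_len - s.size(0))).flip(0) for s in seqs]
)
assert torch.equal(left, via_flips)
```
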
`test/test_jit.py`:

```diff
@@ -9706,9 +9706,9 @@ dedent """
     def test_script_pad_sequence_pack_sequence(self):
         from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence
 
-        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0):
-            # type: (List[Tensor], bool, float) -> Tensor
-            return pad_sequence(tensor_list, batch_first, padding_value)
+        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0, padding_side="right"):
+            # type: (List[Tensor], bool, float, str) -> Tensor
+            return pad_sequence(tensor_list, batch_first, padding_value, padding_side)
 
         def pack_sequence_func(tensor_list, enforce_sorted=True):
             # type: (List[Tensor], bool) -> Tensor
@@ -9727,6 +9727,10 @@
                          ([ones3, ones4, ones5], True))
         self.checkScript(pad_sequence_func,
                          ([ones3, ones4, ones5], True, 2.5))
+        self.checkScript(pad_sequence_func,
+                         ([ones3, ones4, ones5], True, 2.5, "left"))
+        self.checkScript(pad_sequence_func,
+                         ([ones3, ones4, ones5], False, 2.5, "left"))
         self.checkScript(pack_sequence_func,
                          ([tensor1, tensor2, tensor3],))
         self.checkScript(pack_sequence_func,
```

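Since `checkScript` passes on these cases, the new argument also works under `torch.jit.script`. A small standalone sketch (not from the PR):

```python
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence

@torch.jit.script
def left_pad(tensor_list: List[torch.Tensor]) -> torch.Tensor:
    # Positional args mirror the scripted call in the test above.
    return pad_sequence(tensor_list, True, 0.0, "left")

print(left_pad([torch.ones(3), torch.ones(5)]).shape)  # torch.Size([2, 5])
```
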
`torch/_C/_nn.pyi.in`:

```diff
@@ -1,7 +1,7 @@
 # ${generated_comment}
 # mypy: disable-error-code="type-arg"
 
-from typing import List, Optional, overload, Sequence, Tuple, Union
+from typing import List, Literal, Optional, overload, Sequence, Tuple, Union
 
 from torch import memory_format, Tensor
 from torch.types import _bool, _device, _dtype, _int, _size
@@ -64,6 +64,7 @@ def pad_sequence(
     sequences: Union[List[Tensor], Tuple[Tensor, ...]],
     batch_first: bool = False,
     padding_value: float = 0.0,
+    padding_side: Union[Literal["left", "right"], str] = "right",
 ) -> Tensor: ...
 def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ...
 def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ...
```

`torch/csrc/api/include/torch/nn/utils/rnn.h`:

```diff
@@ -300,6 +300,8 @@ inline std::tuple<Tensor, Tensor> pad_packed_sequence(
 ///        or in
 ///        ``T x B x *`` otherwise
 ///   padding_value (double, optional): value for padded elements. Default: 0.
+///   padding_side (str, optional): the side to pad the sequences on. Default:
+///   "right".
 ///
 /// Returns:
 ///     Tensor of size ``T x B x *`` if `batch_first` is ``false``.
@@ -307,8 +309,9 @@
 inline Tensor pad_sequence(
     ArrayRef<Tensor> sequences,
     bool batch_first = false,
-    double padding_value = 0) {
-  return at::pad_sequence(sequences, batch_first, padding_value);
+    double padding_value = 0,
+    c10::string_view padding_side = "right") {
+  return at::pad_sequence(sequences, batch_first, padding_value, padding_side);
 }
 
 /// Packs a list of variable length Tensors
```

`torch/nn/utils/rnn.py`:

```diff
@@ -419,6 +419,7 @@ def pad_sequence(
     sequences: Union[Tensor, List[Tensor]],
     batch_first: bool = False,
     padding_value: float = 0.0,
+    padding_side: str = "right",
 ) -> Tensor:
     r"""Pad a list of variable length Tensors with :attr:`padding_value`.
@@ -448,6 +449,8 @@
         batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
             format, ``T x B x *`` otherwise.
         padding_value (float, optional): value for padded elements. Default: 0.
+        padding_side (str, optional): the side to pad the sequences on.
+            Default: "right".
 
     Returns:
         Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
@@ -472,7 +475,9 @@
     # assuming trailing dimensions and type of all the Tensors
     # in sequences are same and fetching those from sequences[0]
-    return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)  # type: ignore[arg-type]
+    return torch._C._nn.pad_sequence(
+        sequences, batch_first, padding_value, padding_side  # type: ignore[arg-type]
+    )
 
 def unpad_sequence(
```
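
Because validation lives in the C++ kernel (the `TORCH_CHECK` in the first hunk), an invalid `padding_side` surfaces in Python as a `RuntimeError`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

try:
    pad_sequence([torch.tensor([1, 2])], batch_first=True, padding_side="top")
except RuntimeError as e:
    print(e)  # Expected padding_side to be one of left or right, but got top.
```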