Add `padding_side` to `pad_sequence` with "left" and "right" options ("right" as default) (#131884)
Fixes #10536
Reattempt of #61467. Thank you so much to @mskoh52 for your excellent work!
As I was trying to create a more efficient LLM data collator, I realized that `pad_sequence` only supports right padding, even though left padding is a very common format for LLMs, like Llama and Mistral.
The proposed alternative implementation was to use multiple flips, which tends to be 1.5x-2x slower. Instead, we can add a [`padding_side` parameter, as there is for Hugging Face tokenizers](9d6c0641c4/src/transformers/tokenization_utils_base.py#L1565), which requires only a very small change in the C++ code.
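
For illustration, here is a minimal sketch of the resulting API (assuming a build that includes this PR):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three variable-length sequences, as in a typical LLM batch.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])

# Default, unchanged behavior: pad on the right.
pad_sequence([a, b, c], batch_first=True)
# tensor([[1, 2, 3],
#         [4, 5, 0],
#         [6, 0, 0]])

# New: pad on the left, the usual layout for decoder-only LLMs.
pad_sequence([a, b, c], batch_first=True, padding_side="left")
# tensor([[1, 2, 3],
#         [0, 4, 5],
#         [0, 0, 6]])
```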
Here are the benchmarks of the new implementation!

`float32`: (benchmark plot omitted)

`bool`: (benchmark plot omitted)
Code:
```python
from __future__ import annotations

import random
import time
from typing import Literal

import numpy as np
import torch


def pad_sequence_with_flips(
    sequences: list[torch.Tensor],
    batch_first: bool = False,
    padding_value: int | float | bool = 0.0,
    padding_side: Literal["left", "right"] | str = "left",
) -> torch.Tensor:
    if padding_side == "right":
        padded_sequence = torch._C._nn.pad_sequence(
            [t.flatten() for t in sequences],
            batch_first=batch_first,
            padding_value=padding_value,
        )
    elif padding_side == "left":
        padded_sequence = torch._C._nn.pad_sequence(
            [t.flatten().flip(0) for t in sequences],  # pyright: ignore[reportArgumentType]
            batch_first=batch_first,
            padding_value=padding_value,
        )
        # undo the first flip along the time dimension
        padded_sequence = padded_sequence.flip(int(batch_first))
    else:
        raise ValueError(
            f"padding_side should be either 'right' or 'left', but got {padding_side}"
        )
    return padded_sequence


sequence_lengths: list[int] = []
flip_left_pad_times: list[float] = []
flip_left_pad_times_std: list[float] = []
left_pad_times: list[float] = []
left_pad_times_std: list[float] = []

RUNS_PER_LOOP: int = 100

for i in range(1, 7):
    sequence_length = i * int(1e6) // 6
    sequence_lengths.append(sequence_length)
    sequences = [
        torch.randint(0, 2, (random.randint(1, sequence_length),), dtype=torch.bool)
        for _ in range(64)
    ]
    inner_left_pad_times: list[float] = []
    inner_flip_left_pad_times: list[float] = []
    for _ in range(RUNS_PER_LOOP):
        start = time.perf_counter()
        torch._C._nn.pad_sequence(
            sequences, batch_first=True, padding_value=False, padding_side="left"
        )
        end = time.perf_counter()
        inner_left_pad_times.append(end - start)

        start = time.perf_counter()
        pad_sequence_with_flips(
            sequences, batch_first=True, padding_value=False, padding_side="left"
        )
        end = time.perf_counter()
        inner_flip_left_pad_times.append(end - start)
    left_pad_times.append(sum(inner_left_pad_times) / len(inner_left_pad_times))
    left_pad_times_std.append(float(np.std(inner_left_pad_times)))
    flip_left_pad_times.append(
        sum(inner_flip_left_pad_times) / len(inner_flip_left_pad_times)
    )
    flip_left_pad_times_std.append(float(np.std(inner_flip_left_pad_times)))
    print(
        f"Sequence Length: {sequence_length}, "
        f"Left Pad Time: {left_pad_times[-1]}, "
        f"Left with Flips Pad Time: {flip_left_pad_times[-1]}"
    )

import matplotlib.pyplot as plt

plt.plot(sequence_lengths, left_pad_times, label="new pad_sequence left")
plt.scatter(sequence_lengths, left_pad_times)
plt.errorbar(
    sequence_lengths, left_pad_times,
    yerr=left_pad_times_std, linestyle="None", marker="^",
)
plt.plot(sequence_lengths, flip_left_pad_times, label="old pad_sequence left (2 flips)")
plt.scatter(sequence_lengths, flip_left_pad_times)
plt.errorbar(
    sequence_lengths, flip_left_pad_times,
    yerr=flip_left_pad_times_std, linestyle="None", marker="^",
)
plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.legend(loc="upper right")

# Sequence Length: 166666, Left Pad Time: 0.06147645162009212, Left with Flips Pad Time: 0.09842291727001794
# Sequence Length: 333333, Left Pad Time: 0.08933195920990329, Left with Flips Pad Time: 0.15597836187991562
# Sequence Length: 500000, Left Pad Time: 0.08863158334006585, Left with Flips Pad Time: 0.15224887342999863
# Sequence Length: 666666, Left Pad Time: 0.10524682551997103, Left with Flips Pad Time: 0.18177212480995877
# Sequence Length: 833333, Left Pad Time: 0.11801802741003485, Left with Flips Pad Time: 0.20821274195001024
# Sequence Length: 1000000, Left Pad Time: 0.131894061660023, Left with Flips Pad Time: 0.23223503091008751
```
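
Dividing the flip-based timings above by the native ones gives a rough sense of the gap, about a 1.6x-1.8x slowdown for the flip approach:

```python
# Slowdown of the flip-based implementation relative to the native
# left padding, taken from the timings logged above (rounded).
flips = [0.0984, 0.1560, 0.1522, 0.1818, 0.2082, 0.2322]
native = [0.0615, 0.0893, 0.0886, 0.1052, 0.1180, 0.1319]
print([round(f / n, 2) for f, n in zip(flips, native)])
# [1.6, 1.75, 1.72, 1.73, 1.76, 1.76]
```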
Co-authored-by: mskoh52 <mskoh52@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131884
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 780310fed7
Commit: 258f47fc0b
```diff
@@ -202,9 +202,11 @@ std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor
   return std::make_tuple(output, lengths_t);
 }
 
-Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value) {
+Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value, const c10::string_view padding_side) {
   const int64_t sequences_size = sequences.size();
   TORCH_CHECK(sequences_size > 0, "received an empty list of sequences");
+  TORCH_CHECK(padding_side == "left" || padding_side == "right",
+              "Expected padding_side to be one of left or right, but got ", padding_side, ".");
   IntArrayRef max_size = sequences[0].sizes();
   IntArrayRef trailing_dims = max_size.slice(1);
   int64_t max_len = std::max_element(
@@ -227,11 +229,12 @@ Tensor pad_sequence(TensorList sequences, bool batch_first, double padding_value
   for (const auto i : c10::irange(sequences_size)) {
     const Tensor& currseq = sequences[i];
     const int64_t length_i = currseq.size(0);
+    const int64_t start = padding_side == "left" ? max_len - length_i : 0;
     // use index notation to prevent duplicate references to the tensor
     if (batch_first) {
-      out.select(0, i).narrow(0, 0, length_i).copy_(currseq);
+      out.select(0, i).narrow(0, start, length_i).copy_(currseq);
     } else {
-      out.narrow(0, 0, length_i).select(1, i).copy_(currseq);
+      out.narrow(0, start, length_i).select(1, i).copy_(currseq);
     }
   }
   return out;
```
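As a reading aid, here is a Python sketch (not part of the PR) of the copy loop above: the only change for left padding is writing each sequence at offset `max_len - length_i` instead of 0.

```python
import torch

# Python sketch of the C++ kernel above: compute the per-sequence
# write offset, then copy each sequence into a pre-filled output.
def pad_sequence_sketch(sequences, batch_first=False, padding_value=0.0, padding_side="right"):
    max_len = max(seq.size(0) for seq in sequences)
    trailing_dims = tuple(sequences[0].shape[1:])
    out_dims = (len(sequences), max_len) if batch_first else (max_len, len(sequences))
    out = sequences[0].new_full(out_dims + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length_i = seq.size(0)
        start = max_len - length_i if padding_side == "left" else 0
        if batch_first:
            out.select(0, i).narrow(0, start, length_i).copy_(seq)
        else:
            out.narrow(0, start, length_i).select(1, i).copy_(seq)
    return out
```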
```diff
@@ -14414,7 +14414,7 @@
     CPU, CUDA: _segment_reduce_backward_kernel
   autogen: _segment_reduce_backward.out
 
-- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor
+- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side="right") -> Tensor
   python_module: nn
   variants: function
```
```diff
@@ -6,6 +6,7 @@
 #include <test/cpp/api/support.h>
 
 #include <algorithm>
+#include <iostream>
 #include <random>
 #include <sstream>
 #include <string>
@@ -830,6 +831,15 @@ TEST_F(NNUtilsTest, PadSequence) {
   padded = rnn_utils::pad_sequence({b, a, c});
   ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
 
+  // padding_side = "left", batch_first = true
+  expected = torch::tensor({{0, 4, 5}, {1, 2, 3}, {0, 0, 6}});
+  padded = rnn_utils::pad_sequence({b, a, c}, true, 0, "left");
+  ASSERT_TRUE(padded.allclose(expected));
+
+  // padding_side = "left", batch_first = false
+  padded = rnn_utils::pad_sequence({b, a, c}, false, 0, "left");
+  ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
+
   // pad with non-zero value
   expected = torch::tensor({{4, 5, 1}, {1, 2, 3}, {6, 1, 1}});
   padded = rnn_utils::pad_sequence({b, a, c}, true, 1);
@@ -870,5 +880,21 @@ TEST_F(NNUtilsTest, PadSequence) {
     // batch first = false
     padded = rnn_utils::pad_sequence(sequences);
     ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
+
+    // reset expected_tensors for padding_side
+    expected_tensors.clear();
+    for (const torch::Tensor& seq : sequences) {
+      // NOLINTNEXTLINE(performance-inefficient-vector-operation)
+      expected_tensors.emplace_back(
+          torch::flip(pad(torch::flip(seq, {0}), maxlen * maxlen), {0}));
+    }
+    expected = torch::stack(expected_tensors);
+    // padding_side = "left", batch_first = true
+    padded = rnn_utils::pad_sequence(sequences, true, 0, "left");
+    ASSERT_TRUE(padded.allclose(expected));
+
+    // padding_side = "left", batch_first = false
+    padded = rnn_utils::pad_sequence(sequences, false, 0, "left");
+    ASSERT_TRUE(padded.allclose(expected.transpose(0, 1)));
   }
 }
```
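The expected tensors above encode the identity the tests rely on: left padding is equivalent to flip, right pad, flip. A quick Python check of that identity (assuming a build with this PR):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.arange(1, n + 1) for n in (3, 2, 1)]
left = pad_sequence(seqs, batch_first=True, padding_side="left")
# flip each sequence, right-pad, then flip back along the time dimension
via_flips = pad_sequence([s.flip(0) for s in seqs], batch_first=True).flip(1)
assert torch.equal(left, via_flips)
```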
```diff
@@ -2,6 +2,7 @@
 
 import itertools
 import random
+from typing import List
 
 import torch
 import torch.nn.utils.rnn as rnn_utils
@@ -188,6 +189,23 @@ class PackedSequenceTest(TestCase):
         padded = rnn_utils.pad_sequence([b, a, c])
         self.assertEqual(padded, expected.transpose(0, 1))
 
+        # padding_side = "left", batch_first=True
+        expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]])
+        padded = rnn_utils.pad_sequence(
+            [b, a, c],
+            batch_first=True,
+            padding_side="left",
+        )
+        self.assertEqual(padded, expected)
+
+        # padding_side = "left", batch_first=False
+        padded = rnn_utils.pad_sequence(
+            [b, a, c],
+            batch_first=False,
+            padding_side="left",
+        )
+        self.assertEqual(padded, expected.transpose(0, 1))
+
         # pad with non-zero value
         expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
         padded = rnn_utils.pad_sequence([b, a, c], True, 1)
@@ -201,17 +219,14 @@ class PackedSequenceTest(TestCase):
         # more dimensions
         maxlen = 9
         for num_dim in (0, 1, 2, 3):
-            sequences = []
+            sequences: List[torch.Tensor] = []
             trailing_dims = [4] * num_dim
             for i in range(1, maxlen + 1):
                 seq_len = i * i
                 sequences.append(torch.rand(seq_len, 5, *trailing_dims))
             random.shuffle(sequences)
-            expected = []
-            for seq in sequences:
-                expected.append(pad(seq, maxlen * maxlen))
             # batch first = true
-            expected = torch.stack(expected)
+            expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences])
             padded = rnn_utils.pad_sequence(sequences, True)
             self.assertEqual(padded, expected)
 
@@ -219,6 +234,25 @@ class PackedSequenceTest(TestCase):
             padded = rnn_utils.pad_sequence(sequences)
             self.assertEqual(padded, expected.transpose(0, 1))
 
+            # padding_side = "left", batch_first=True
+            expected = torch.stack(
+                [pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences]
+            )
+            padded = rnn_utils.pad_sequence(
+                sequences,
+                batch_first=True,
+                padding_side="left",
+            )
+            self.assertEqual(padded, expected)
+
+            # padding_side = "left", batch_first=False
+            padded = rnn_utils.pad_sequence(
+                sequences,
+                batch_first=False,
+                padding_side="left",
+            )
+            self.assertEqual(padded, expected.transpose(0, 1))
+
     def test_unpad_sequence(self):
         # single dimensional
         a = torch.tensor([1, 2, 3])
```
```diff
@@ -9706,9 +9706,9 @@
     def test_script_pad_sequence_pack_sequence(self):
         from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence
 
-        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0):
-            # type: (List[Tensor], bool, float) -> Tensor
-            return pad_sequence(tensor_list, batch_first, padding_value)
+        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0, padding_side="right"):
+            # type: (List[Tensor], bool, float, str) -> Tensor
+            return pad_sequence(tensor_list, batch_first, padding_value, padding_side)
 
         def pack_sequence_func(tensor_list, enforce_sorted=True):
             # type: (List[Tensor], bool) -> Tensor
@@ -9727,6 +9727,10 @@
                          ([ones3, ones4, ones5], True))
         self.checkScript(pad_sequence_func,
                          ([ones3, ones4, ones5], True, 2.5))
+        self.checkScript(pad_sequence_func,
+                         ([ones3, ones4, ones5], True, 2.5, "left"))
+        self.checkScript(pad_sequence_func,
+                         ([ones3, ones4, ones5], False, 2.5, "left"))
         self.checkScript(pack_sequence_func,
                          ([tensor1, tensor2, tensor3],))
         self.checkScript(pack_sequence_func,
```
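A minimal sketch of what these checkScript cases verify, namely that the new argument round-trips through TorchScript (assuming a build with this PR):

```python
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence

@torch.jit.script
def left_pad(tensor_list: List[torch.Tensor]) -> torch.Tensor:
    # padding_side is plumbed through the TorchScript schema as a str
    return pad_sequence(tensor_list, batch_first=True, padding_value=2.5, padding_side="left")

print(left_pad([torch.ones(3), torch.ones(5)]))
```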
```diff
@@ -1,7 +1,7 @@
 # ${generated_comment}
 # mypy: disable-error-code="type-arg"
 
-from typing import List, Optional, overload, Sequence, Tuple, Union
+from typing import List, Literal, Optional, overload, Sequence, Tuple, Union
 
 from torch import memory_format, Tensor
 from torch.types import _bool, _device, _dtype, _int, _size
@@ -64,6 +64,7 @@ def pad_sequence(
     sequences: Union[List[Tensor], Tuple[Tensor, ...]],
     batch_first: bool = False,
     padding_value: float = 0.0,
+    padding_side: Union[Literal["left", "right"], str] = "right",
 ) -> Tensor: ...
 def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ...
 def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ...
```
```diff
@@ -300,6 +300,8 @@ inline std::tuple<Tensor, Tensor> pad_packed_sequence(
 ///                or in
 ///                ``T x B x *`` otherwise
 ///   padding_value (double, optional): value for padded elements. Default: 0.
+///   padding_side (str, optional): the side to pad the sequences on. Default:
+///   "right".
 ///
 /// Returns:
 ///     Tensor of size ``T x B x *`` if `batch_first` is ``false``.
@@ -307,8 +309,9 @@ inline std::tuple<Tensor, Tensor> pad_packed_sequence(
 inline Tensor pad_sequence(
     ArrayRef<Tensor> sequences,
     bool batch_first = false,
-    double padding_value = 0) {
-  return at::pad_sequence(sequences, batch_first, padding_value);
+    double padding_value = 0,
+    c10::string_view padding_side = "right") {
+  return at::pad_sequence(sequences, batch_first, padding_value, padding_side);
 }
 
 /// Packs a list of variable length Tensors
```
```diff
@@ -419,6 +419,7 @@ def pad_sequence(
     sequences: Union[Tensor, List[Tensor]],
     batch_first: bool = False,
     padding_value: float = 0.0,
+    padding_side: str = "right",
 ) -> Tensor:
     r"""Pad a list of variable length Tensors with :attr:`padding_value`.
 
@@ -448,6 +449,8 @@ def pad_sequence(
         batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
             format, ``T x B x *`` otherwise.
         padding_value (float, optional): value for padded elements. Default: 0.
+        padding_side (str, optional): the side to pad the sequences on.
+            Default: "right".
 
     Returns:
         Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
@@ -472,7 +475,9 @@ def pad_sequence(
 
     # assuming trailing dimensions and type of all the Tensors
     # in sequences are same and fetching those from sequences[0]
-    return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)  # type: ignore[arg-type]
+    return torch._C._nn.pad_sequence(
+        sequences, batch_first, padding_value, padding_side  # type: ignore[arg-type]
+    )
 
 
 def unpad_sequence(
```
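And a short sketch of the documented default ``T x B x *`` layout combined with left padding (assuming a build with this PR):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

a, b = torch.tensor([1, 2, 3]), torch.tensor([4])
# batch_first=False: output is T x B, and padding fills the leading time steps
print(pad_sequence([a, b], padding_side="left"))
# tensor([[1, 0],
#         [2, 0],
#         [3, 4]])
```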