Fixes#10536
Reattempt of #61467. Thank you so much to @mskoh52 for your excellent work!
As I was trying to create a more efficient LLM data collator, I realized that `pad_sequence` only supports right padding, even though left padding is a very common format for LLMs, like Llama and Mistral.
The proposed alternative implementation was to use multiple flips, which tends to be 1.5x-2x slower. Instead we can add a [`padding_side` parameter as there is for for Hugging Face tokenizers](9d6c0641c4/src/transformers/tokenization_utils_base.py (L1565)), which requires only a very small change in the C++ code.
Here are the benchmarks of the new implementation!
`float32`:

`bool`:

Code:
```python
from __future__ import annotations
import random
import time
from typing import Literal
import numpy as np
import torch
def pad_sequence_with_flips(
sequences: list[torch.Tensor],
batch_first: bool = False,
padding_value: int | float | bool = 0.0,
padding_side: Literal["left", "right"] | str = "left",
) -> torch.Tensor:
if padding_side == 'right':
padded_sequence = torch._C._nn.pad_sequence([t.flatten() for t in sequences], batch_first=batch_first, padding_value=padding_value)
elif padding_side=='left':
padded_sequence = torch._C._nn.pad_sequence([t.flatten().flip(0) for t in sequences], batch_first=batch_first, padding_value=padding_value) # pyright: ignore[reportArgumentType]
padded_sequence = padded_sequence.flip(int(batch_first))
else:
raise ValueError(f"padding_side should be either 'right' or 'left', but got {padding_side}")
return padded_sequence
sequence_lengths: list[int] = []
flip_left_pad_times: list[float] = []
flip_left_pad_times_std: list[float] = []
left_pad_times: list[float] = []
left_pad_times_std: list[float] = []
RUNS_PER_LOOP: int = 100
for i in range(1, 7):
sequence_length = i * int(1e6) // 6
sequence_lengths.append(sequence_length)
sequences = [torch.randint(0, 2, (random.randint(1, sequence_length),), dtype=torch.bool) for _ in range(64)]
inner_left_pad_times: list[float] = []
inner_right_pad_times: list[float] = []
inner_flip_left_pad_times: list[float] = []
inner_flip_right_pad_times: list[float] = []
for _ in range(RUNS_PER_LOOP):
start = time.perf_counter()
torch._C._nn.pad_sequence(sequences, batch_first=True, padding_value=False, padding_side="left")
end = time.perf_counter()
inner_left_pad_times.append(end - start)
start = time.perf_counter()
pad_sequence_with_flips(sequences, batch_first=True, padding_value=False, padding_side="left")
end = time.perf_counter()
inner_flip_left_pad_times.append(end - start)
left_pad_times.append(sum(inner_left_pad_times) / len(inner_left_pad_times))
left_pad_times_std.append(np.std(inner_left_pad_times))
flip_left_pad_times.append(sum(inner_flip_left_pad_times) / len(inner_flip_left_pad_times))
flip_left_pad_times_std.append(np.std(inner_flip_left_pad_times))
print(f"Sequence Length: {sequence_length}, Left Pad Time: {left_pad_times[-1]}, Left with Flips Pad Time: {flip_left_pad_times[-1]}")
import matplotlib.pyplot as plt
plt.plot(sequence_lengths, left_pad_times, label="new pad_sequence left")
plt.scatter(sequence_lengths, left_pad_times)
plt.errorbar(sequence_lengths, left_pad_times, yerr=left_pad_times_std, linestyle='None', marker='^')
plt.plot(sequence_lengths, flip_left_pad_times, label="old pad_sequence left (2 flips)")
plt.scatter(sequence_lengths, flip_left_pad_times)
plt.errorbar(sequence_lengths, flip_left_pad_times, yerr=flip_left_pad_times_std, linestyle='None', marker='^')
plt.xlabel("Sequence Length")
plt.ylabel("Time (s)")
plt.legend(loc="upper right")
# Sequence Length: 166666, Left Pad Time: 0.06147645162009212, Left with Flips Pad Time: 0.09842291727001794
# Sequence Length: 333333, Left Pad Time: 0.08933195920990329, Left with Flips Pad Time: 0.15597836187991562
# Sequence Length: 500000, Left Pad Time: 0.08863158334006585, Left with Flips Pad Time: 0.15224887342999863
# Sequence Length: 666666, Left Pad Time: 0.10524682551997103, Left with Flips Pad Time: 0.18177212480995877
# Sequence Length: 833333, Left Pad Time: 0.11801802741003485, Left with Flips Pad Time: 0.20821274195001024
# Sequence Length: 1000000, Left Pad Time: 0.131894061660023, Left with Flips Pad Time: 0.23223503091008751
```
Co-authored-by: mskoh52 <mskoh52@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131884
Approved by: https://github.com/ezyang
Skipping importing some packages for now to make this change more
tractable.
For some reason, lintrunner on CI raises errors in all imported `.pyi` files,
even though it doesn't on my local machine. The errors are all from missing
generic types, as the MYPYINDUCTOR config has `disallow_any_generics`
set. I have thus added `disable-error-code` comments to the relevant files,
though I fixed a few that were easy enough.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113830
Approved by: https://github.com/Skylion007
ghstack dependencies: #113722, #113721
Fixes#102768
- Provides proper function declarations in generated `torch/nn/functional.pyi`.
- Moves some functions from manually defined in `functional.pyi.in` to generated code, in order to single-source the signature.
- Includes some of the functions in `torch._C._nn` into its `.pyi.in`, but not exhaustive (only what's already there).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102918
Approved by: https://github.com/drisspg, https://github.com/malfet
Changes:
- #95200
1. Recognize `.py.in` and `.pyi.in` files as Python in VS Code for a better development experience.
2. Fix deep setting merge in `tools/vscode_settings.py`.
- #95267
3. Use `Namedtuple` rather than `namedtuple + __annotations__` for `torch.nn.utils.rnn.PackedSequence_`:
`namedtuple + __annotations__`:
```python
PackedSequence_ = namedtuple('PackedSequence_',
['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
# type annotation for PackedSequence_ to make it compatible with TorchScript
PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Tensor,
'sorted_indices': Optional[torch.Tensor],
'unsorted_indices': Optional[torch.Tensor]}
```
`Namedtuple`: Python 3.6+
```python
class PackedSequence_(NamedTuple):
data: torch.Tensor
batch_sizes: torch.Tensor
sorted_indices: Optional[torch.Tensor]
unsorted_indices: Optional[torch.Tensor]
```
- => this PR: #95268
4. Sort import statements and remove unnecessary imports in `.pyi`, `.pyi.in` files.
5. Format `.pyi`, `.pyi.in` files and remove unnecessary ellipsis `...` in type stubs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95268
Approved by: https://github.com/huydhn
Summary:
*Context:* https://github.com/pytorch/pytorch/issues/53406 added a lint for trailing whitespace at the ends of lines. However, in order to pass FB-internal lints, that PR also had to normalize the trailing newlines in four of the files it touched. This PR adds an OSS lint to normalize trailing newlines.
The changes to the following files (made in 54847d0adb9be71be4979cead3d9d4c02160e4cd) are the only manually-written parts of this PR:
- `.github/workflows/lint.yml`
- `mypy-strict.ini`
- `tools/README.md`
- `tools/test/test_trailing_newlines.py`
- `tools/trailing_newlines.py`
I would have liked to make this just a shell one-liner like the other three similar lints, but nothing I could find quite fit the bill. Specifically, all the answers I tried from the following Stack Overflow questions were far too slow (at least a minute and a half to run on this entire repository):
- [How to detect file ends in newline?](https://stackoverflow.com/q/38746)
- [How do I find files that do not end with a newline/linefeed?](https://stackoverflow.com/q/4631068)
- [How to list all files in the Git index without newline at end of file](https://stackoverflow.com/q/27624800)
- [Linux - check if there is an empty line at the end of a file [duplicate]](https://stackoverflow.com/q/34943632)
- [git ensure newline at end of each file](https://stackoverflow.com/q/57770972)
To avoid giving false positives during the few days after this PR is merged, we should probably only merge it after https://github.com/pytorch/pytorch/issues/54967.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54737
Test Plan:
Running the shell script from the "Ensure correct trailing newlines" step in the `quick-checks` job of `.github/workflows/lint.yml` should print no output and exit in a fraction of a second with a status of 0. That was not the case prior to this PR, as shown by this failing GHA workflow run on an earlier draft of this PR:
- https://github.com/pytorch/pytorch/runs/2197446987?check_suite_focus=true
In contrast, this run (after correcting the trailing newlines in this PR) succeeded:
- https://github.com/pytorch/pytorch/pull/54737/checks?check_run_id=2197553241
To unit-test `tools/trailing_newlines.py` itself (this is run as part of our "Test tools" GitHub Actions workflow):
```
python tools/test/test_trailing_newlines.py
```
Reviewed By: malfet
Differential Revision: D27409736
Pulled By: samestep
fbshipit-source-id: 46f565227046b39f68349bbd5633105b2d2e9b19
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38157
This removes the error prone process of assembling `torch/__init__.pyi`
(and frequently forgetting to expose things), since now we can simply
rely on the true source file to get things done. Most of the old
codegen in gen_pyi.py is now rerouted to various files:
- `torch/_C/__init__.pyi` (the dumping pile of all misc bindings)
- `torch/_C/_nn.pyi` (NN function bindings)
- `torch/_C/_VariableFunctions.pyi` (torch function bindings)
`torch.types` grew a bunch more definitions that previously where
defined in `torch/__init__.pyi`
Some miscellaneous changes
- Fixed a bug where we treat single TensorList argument as implying
varargs are accepted. This is actually only supported on IntList.
This means we can correctly generate a stub for dequantize.
- Add missing manual stub for nonzero
- Switched torch/onnx/operators.py to directly refer to _C module,
since apparently mypy doesn't think that methods prefixed with
underscores get reexported. This may be a recurring theme; maybe
we need to find a better way to solve it.
Because I was really lazy, I dumped namedtuple definitions in both
`torch._C` and `torch._C._VariableFunctions`. This is definitely wrong.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D21497400
Pulled By: ezyang
fbshipit-source-id: 07b126141c82efaca37be27c07255cb2b9b3f064