https://github.com/pytorch/pytorch/pull/129001#discussion_r1645126801 is the motivation for the whole stack of PRs. In `torch/__init__.py`, `torch._C.Type` shadows `from typing import Type`, and there is no type stub for `torch._C.Type` in `torch/_C/__init__.pyi`. So we need to use `from typing import Type as _Type`. After enabling [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585) in the `.pyi` type stub files, we can use `type` instead of `typing.Type` or `from typing import Type as _Type`.
------
- [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`.
- [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None`.
Note that in `.pyi` stub files, we do not need `from __future__ import annotations`. So this PR does not violate issue #117449:
- #117449
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150727
Approved by: https://github.com/aorenste
ghstack dependencies: #150726
The default value of `rot90()` in the schema registry is `[0,1]` because we split the function schema by `", "`. There should be no space after `,` in `[0,1]`.
5c9d5272e4/aten/src/ATen/native/native_functions.yaml (L6120-L6126)
Then the the default value is formatted to `(0,1)` in `pyi` files. This PR manually adds an extra whitespace when rerendering the default value to a string.
```python
", ".join(string.split(","))
```
```python
# before
def rot90(input: Tensor, k: _int = 1, dims: _size = (0,1)) -> Tensor: ...
# after
def rot90(input: Tensor, k: _int = 1, dims: _size = (0, 1)) -> Tensor: ...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129884
Approved by: https://github.com/ezyang
Given the following code/dynamo graph:
```
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
_print = torch.ops.aten._print('moo')
res = l_x_ + l_x_; l_x_ = None
_print_1 = torch.ops.aten._print('moo')
return (res,)
```
AOTAutograd will trace the following program, threading tokens from the inputs, through the effectful operator calls (torch.ops.aten._print), and as an output:
```
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2, 3]"):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)
```
However when we get to inductor, since we want the inductor generated code to not have any token inputs/outputs for better readability, we want to modify the aten graph by removing the tokens from inputs, and creating them through `torch.ops.aten._make_dep_token`, and sinking them through the `torch.ops.aten._sink_tokens` operators.
This has to be done *after* the partitioner, otherwise the partitioner will add the make_token/sink_token operators to the backwards graph.
```
class <lambda>(torch.nn.Module):
def forward(self, arg1_1: "f32[2, 3]"):
_make_dep_token_default: "f32[0]" = torch.ops.aten._make_dep_token.default()
with_effects = torch._higher_order_ops.effects.with_effects(_make_dep_token_default, torch.ops.aten._print.default, 'moo'); _make_dep_token_default = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
_sink_tokens_default = torch.ops.aten._sink_tokens.default((getitem_2,)); getitem_2 = None
return (add,)
```
When doing inductor lowering, we convert `with_effects` calls to an `EffectfulKernel`, which just a `FallbackKernel` but with a pointer to previous effectful operator's call. During scheduling, we will create a `StarDep` between the EffectfulKernel and its previous EffectfulKernel so that they don't get reordered. The inductor generated python code looks like:
```
def call(args):
arg1_1, = args
args.clear()
assert_size_stride(arg1_1, (2, 3), (3, 1))
# Source Nodes: [_print], Original ATen: []
buf2 = aten._print.default('moo')
# Source Nodes: [_print_1], Original ATen: []
buf3 = aten._print.default('moo')
buf4 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
cpp_fused_add_0(arg1_1, buf4)
del arg1_1
return (buf4, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122347
Approved by: https://github.com/bdhirsh
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
* `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
* This ops is implemented on the Python side using torch.library so we can return a subclass instance
* `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
* The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
* `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
* `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
* Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)
With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.
Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
* `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
* This ops is implemented on the Python side using torch.library so we can return a subclass instance
* `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
* The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
* `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
* `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
* Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)
With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.
Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
# Motivation
This PR intends to extend `cuda_lazy_init` to `device_lazy_init` which is a device-agnostic API that can support any backend. And change `maybe_initialize_cuda` to `maybe_initialize_device` to support lazy initialization for CUDA while maintaining scalability.
# Design
We maintain a flag for each backend to manage the lazy initialization state separately.
# Additional Context
No need more UTs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118846
Approved by: https://github.com/malfet
All single element list types are `Tensor[]` so they will always be Tuple.
I don't know of any way to easily access the pyi type and compare that to a real run so no testing here :(
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118238
Approved by: https://github.com/ezyang
Using mypy in code that depends on pytorch, I noticed that the type annotation doesn't allow a device ordinal.
`error: Argument "device" to "to_empty" of "Module" has incompatible type "int"; expected "str | device" [arg-type]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113647
Approved by: https://github.com/albanD
Implements a simple content-addressable store for storages (with tensors implemented as cheap references on top), enabling incremental serialization of tensors to disk, which I intend to use in the accuracy repro extractor. Check the comment at the top of torch/utils/_content_store.py for more details on the intended use case.
One major piece of this PR is implementing the content hash for tensors. For our prospective use case, we may need to repeatedly hash up to 80 GB of tensor data every time we snapshot (and we may snapshot multiple times). Using a conventional cryptographic hash and hashing each snapshot would likely take on order of minutes, which seemed too slow to me. So instead, I implemented a crappy hash function that can be run on GPU. It is at least somewhat theoretically grounded: using random parameters generated by Philox, we use the standard shift-multiply and xor sum universal hash family. The hash function is a bit dorky though; instead of properly doing 160-bit math, it just runs 32-bit hash five times and cats them together. By the way, this sets the first precedent for kernel in PyTorch library which MUST be torch.compile'd to be run (in fact, this kernel does not run in eager mode because of the use of xor_sum, which doesn't actually exist in ATen.)
I had to add a few more primitives to inductor, namely randint (over the entire int range) and xor_sum. Fortunately, these primitives are natively supported by Triton/C++, and so they were very easy to plumb through. xor_sum is exposed as a prim, while randint special cases on when low/high span the entire 32-bit signed integer range.
Thanks to Jeff Johnson for letting me bounce ideas of him on a Saturday morning lol.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99809
Approved by: https://github.com/voznesenskym
Changes:
1. Use class inheritance for `torch/return_types.pyi`:
Before:
```python
max = NamedTuple("max", [("values", Tensor), ("indices", Tensor)])
```
After:
```python
class max(NamedTuple):
values: Tensor
indices: Tensor
```
------
2. Add missing spaces in generated type annotations.
1. Always has a space after `,`.
2. If an argument is annotated, then there need spaces around `=` when it has a default value.
```diff
- def func(..., out: Optional[Tensor]=None, ...) -> Tensor:
+ def func(..., out: Optional[Tensor] = None, ...) -> Tensor:
```
3. If an argument is not annotated, then there should be no spaces around `=` when it has a default value.
```python
def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...
```
------
3. ~Remove redundant import alias in `torch/nn/functional.pyi`:~ (Reverted)
UPDATE: `mypy` needs the alias to work.
Before:
```python
from .. import conv1d as conv1d
from .. import conv2d as conv2d
from .. import conv3d as conv3d
from .. import conv_transpose1d as conv_transpose1d
from .. import conv_transpose2d as conv_transpose2d
from .. import conv_transpose3d as conv_transpose3d
from .. import conv_tbc as conv_tbc
from .. import avg_pool1d as avg_pool1d
from .. import relu_ as relu_
from .. import selu_ as selu_
from .. import celu_ as celu_
from .. import rrelu_ as rrelu_
from .. import pixel_shuffle as pixel_shuffle
from .. import pixel_unshuffle as pixel_unshuffle
from .. import channel_shuffle as channel_shuffle
from .. import native_channel_shuffle as native_channel_shuffle
from .. import pdist as pdist
from .. import cosine_similarity as cosine_similarity
```
After:
```python
from .. import (
conv1d,
conv2d,
conv3d,
conv_transpose1d,
conv_transpose2d,
conv_transpose3d,
conv_tbc,
avg_pool1d,
relu_,
selu_,
celu_,
rrelu_,
pixel_shuffle,
pixel_unshuffle,
channel_shuffle,
native_channel_shuffle,
pdist,
cosine_similarity,
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95877
Approved by: https://github.com/ezyang
Fixes#91694Fixes#92615
Several transpositions were missing for backward graph in case of `batch_first=True`. The #91694 is not reproduced with `batch_first=False`.
After fixing transpose issue, I finally thought that now I can use LSTM freely in my project. And then I got horrific results on train. Seems related to #92615.
After that I decided to fix LSTM's backward step completely. I collected all my findings in this thread — seems like I succeeded
Funny enough, backward tests were completely disabled before and were not passing:
```python
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
```
UPD: forward pass of multi-layer version also was wrong due to the incorrect `initState, initCell` slices. Tests were passing because states were inited with zeros. *Accidentally* fixed this too
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95137
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
# Summary
This PR creates _flash_attention_backward and _scaled_dot_product_flash_attention_backward native functions and registers them to the respective derivatives.yaml.
The goal is to replicate the torch.autograd.Function defined in the FlashAttention repo [here](33e0860c9c/flash_attn/flash_attn_interface.py (L126)) natively in PyTorch. One thing that we don't have access to is ctx.save_for_backward in native PyTorch so in order to save these variables I extended the returned objects from the forward functions.
### MetaFunctions
I also updated the FlashAttention meta functions to mirror the real outputs now. As well I added a meta registration for backwards. I have an XLMR training script and while eager training now works with FlashAttention compiling this module fails with the inductor error down below.
### Questions?
Performance issues vs mem efficient when using torch.nn.mha_forward
TorchCompile -> See purposed solution below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92917
Approved by: https://github.com/cpuhrsch