Convert to markdown: named_tensor.rst, nested.rst, nn.attention.bias.rst, nn.attention.experimental.rst, nn.attention.flex_attention.rst #155028 (#155696)

Fixes #155028

This pull request updates the documentation  by transitioning from .rst to .md format. It introduces new Markdown files for the documentation of named_tensor, nested, nn.attention.bias, nn.attention.experimental, and nn.attention.flex_attention

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155696
Approved by: https://github.com/svekars

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
This commit is contained in:
Edgar Romo Montiel
2025-06-14 03:32:00 +00:00
committed by PyTorch MergeBot
parent cdfa33a328
commit ca3cabd24a
5 changed files with 239 additions and 186 deletions

View File

@ -1,9 +1,10 @@
```{eval-rst}
.. currentmodule:: torch
```
.. _named_tensors-doc:
(named_tensors-doc)=
Named Tensors
=============
# Named Tensors
Named Tensors allow users to give explicit names to tensor dimensions.
In most cases, operations that take dimension parameters will accept
@ -14,43 +15,42 @@ also be used to rearrange dimensions, for example, to support
"broadcasting by name" rather than "broadcasting by position".
.. warning::
```{warning}
The named tensor API is a prototype feature and subject to change.
```
Creating named tensors
----------------------
## Creating named tensors
Factory functions now take a new :attr:`names` argument that associates a name
Factory functions now take a new {attr}`names` argument that associates a name
with each dimension.
::
```
>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
[0., 0., 0.]], names=('N', 'C'))
```
Named dimensions, like regular Tensor dimensions, are ordered.
``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``.
The following factory functions support named tensors:
- :func:`torch.empty`
- :func:`torch.rand`
- :func:`torch.randn`
- :func:`torch.ones`
- :func:`torch.tensor`
- :func:`torch.zeros`
- {func}`torch.empty`
- {func}`torch.rand`
- {func}`torch.randn`
- {func}`torch.ones`
- {func}`torch.tensor`
- {func}`torch.zeros`
Named dimensions
----------------
## Named dimensions
See :attr:`~Tensor.names` for restrictions on tensor names.
See {attr}`~Tensor.names` for restrictions on tensor names.
Use :attr:`~Tensor.names` to access the dimension names of a tensor and
:meth:`~Tensor.rename` to rename named dimensions.
::
Use {attr}`~Tensor.names` to access the dimension names of a tensor and
{meth}`~Tensor.rename` to rename named dimensions.
```
>>> imgs = torch.randn(1, 2, 2, 3 , names=('N', 'C', 'H', 'W'))
>>> imgs.names
('N', 'C', 'H', 'W')
@ -58,20 +58,19 @@ Use :attr:`~Tensor.names` to access the dimension names of a tensor and
>>> renamed_imgs = imgs.rename(H='height', W='width')
>>> renamed_imgs.names
('N', 'C', 'height', 'width)
```
Named tensors can coexist with unnamed tensors; named tensors are instances of
:class:`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named
{class}`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named
tensors do not require all dimensions to be named.
::
```
>>> imgs = torch.randn(1, 2, 2, 3 , names=(None, 'C', 'H', 'W'))
>>> imgs.names
(None, 'C', 'H', 'W')
```
Name propagation semantics
--------------------------
## Name propagation semantics
Named tensors use names to automatically check that APIs are being called
correctly at runtime. This occurs in a process called *name inference*.
@ -83,17 +82,16 @@ More formally, name inference consists of the following two steps:
All operations that support named tensors propagate names.
::
```
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')
```
.. _match_semantics-doc:
(match_semantics-doc)=
### match semantics
match semantics
^^^^^^^^^^^^^^^
Two names *match* if they are equal (string equality) or if at least one is ``None``.
Nones are essentially a special "wildcard" name.
@ -102,62 +100,59 @@ Nones are essentially a special "wildcard" name.
It returns the more *specific* of the two names, if they match. If the names do not match,
then it errors.
.. note::
```{note}
In practice, when working with named tensors, one should avoid having unnamed
dimensions because their handling can be complicated. It is recommended to lift
all unnamed dimensions to be named dimensions by using :meth:`~Tensor.refine_names`.
all unnamed dimensions to be named dimensions by using {meth}`~Tensor.refine_names`.
```
Basic name inference rules
^^^^^^^^^^^^^^^^^^^^^^^^^^
### Basic name inference rules
Let's see how ``match`` and ``unify`` are used in name inference in the case of
adding two one-dim tensors with no broadcasting.
::
```
x = torch.randn(3, names=('X',))
y = torch.randn(3)
z = torch.randn(3, names=('Z',))
```
**Check names**: check that the names of the two tensors *match*.
For the following examples:
::
```
>>> # x + y # match('X', None) is True
>>> # x + z # match('X', 'Z') is False
>>> # x + x # match('X', 'X') is True
>>> x + z
Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match.
```
**Propagate names**: *unify* the names to select which one to propagate.
In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more
specific than ``None``.
::
```
>>> (x + y).names
('X',)
>>> (x + x).names
('X',)
```
For a comprehensive list of name inference rules, see :ref:`name_inference_reference-doc`.
For a comprehensive list of name inference rules, see {ref}`name_inference_reference-doc`.
Here are two common operations that may be useful to go over:
- Binary arithmetic ops: :ref:`unifies_names_from_inputs-doc`
- Matrix multiplication ops: :ref:`contracts_away_dims-doc`
- Binary arithmetic ops: {ref}`unifies_names_from_inputs-doc`
- Matrix multiplication ops: {ref}`contracts_away_dims-doc`
Explicit alignment by names
---------------------------
## Explicit alignment by names
Use :meth:`~Tensor.align_as` or :meth:`~Tensor.align_to` to align tensor dimensions
Use {meth}`~Tensor.align_as` or {meth}`~Tensor.align_to` to align tensor dimensions
by name to a specified ordering. This is useful for performing "broadcasting by names".
::
```
# This function is agnostic to the dimension ordering of `input`,
# as long as it has a `C` dimension somewhere.
def scale_channels(input, scale):
@ -173,15 +168,14 @@ by name to a specified ordering. This is useful for performing "broadcasting by
>>> scale_channels(imgs, scale)
>>> scale_channels(more_imgs, scale)
>>> scale_channels(videos, scale)
```
Manipulating dimensions
-----------------------
## Manipulating dimensions
Use :meth:`~Tensor.align_to` to permute large amounts of dimensions without
mentioning all of them as in required by :meth:`~Tensor.permute`.
::
Use {meth}`~Tensor.align_to` to permute large amounts of dimensions without
mentioning all of them as in required by {meth}`~Tensor.permute`.
```
>>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
>>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
@ -189,13 +183,14 @@ mentioning all of them as in required by :meth:`~Tensor.permute`.
# the rest in the same order
>>> tensor.permute(5, 4, 0, 1, 2, 3)
>>> named_tensor.align_to('F', 'E', ...)
```
Use :meth:`~Tensor.flatten` and :meth:`~Tensor.unflatten` to flatten and unflatten
dimensions, respectively. These methods are more verbose than :meth:`~Tensor.view`
and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading the code.
Use {meth}`~Tensor.flatten` and {meth}`~Tensor.unflatten` to flatten and unflatten
dimensions, respectively. These methods are more verbose than {meth}`~Tensor.view`
and {meth}`~Tensor.reshape`, but have more semantic meaning to someone reading the code.
::
```
>>> imgs = torch.randn(32, 3, 128, 128)
>>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
@ -207,18 +202,16 @@ and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading t
>>> unflattened_named_imgs = named_flat_imgs.unflatten('features', [('C', 3), ('H', 128), ('W', 128)])
>>> unflattened_named_imgs.names
('N', 'C', 'H', 'W')
```
.. _named_tensors_autograd-doc:
Autograd support
----------------
(named_tensors_autograd-doc)=
## Autograd support
Autograd currently supports named tensors in a limited manner: autograd ignores
names on all tensors. Gradient computation is still correct but we lose the
safety that names give us.
::
```
>>> x = torch.randn(3, names=('D',))
>>> weight = torch.randn(3, names=('D',), requires_grad=True)
>>> loss = (x - weight).abs()
@ -234,31 +227,30 @@ safety that names give us.
>>> loss.backward(grad_loss)
>>> weight.grad
tensor([-1.8107, -0.6357, 0.0783])
```
Currently supported operations and subsystems
---------------------------------------------
## Currently supported operations and subsystems
Operators
^^^^^^^^^
### Operators
See :ref:`name_inference_reference-doc` for a full list of the supported torch and
See {ref}`name_inference_reference-doc` for a full list of the supported torch and
tensor operations. We do not yet support the following that is not covered by the link:
- indexing, advanced indexing.
For ``torch.nn.functional`` operators, we support the following:
- :func:`torch.nn.functional.relu`
- :func:`torch.nn.functional.softmax`
- :func:`torch.nn.functional.log_softmax`
- :func:`torch.nn.functional.tanh`
- :func:`torch.nn.functional.sigmoid`
- :func:`torch.nn.functional.dropout`
- {func}`torch.nn.functional.relu`
- {func}`torch.nn.functional.softmax`
- {func}`torch.nn.functional.log_softmax`
- {func}`torch.nn.functional.tanh`
- {func}`torch.nn.functional.sigmoid`
- {func}`torch.nn.functional.dropout`
Subsystems
^^^^^^^^^^
### Subsystems
Autograd is supported, see :ref:`named_tensors_autograd-doc`.
Autograd is supported, see {ref}`named_tensors_autograd-doc`.
Because gradients are currently unnamed, optimizers may work but are untested.
NN modules are currently unsupported. This can lead to the following when calling
@ -272,32 +264,36 @@ We also do not support the following subsystems, though some may work out
of the box:
- distributions
- serialization (:func:`torch.load`, :func:`torch.save`)
- serialization ({func}`torch.load`, {func}`torch.save`)
- multiprocessing
- JIT
- distributed
- ONNX
If any of these would help your use case, please
`search if an issue has already been filed <https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22>`_
and if not, `file one <https://github.com/pytorch/pytorch/issues/new/choose>`_.
[search if an issue has already been filed](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22)
and if not, [file one](https://github.com/pytorch/pytorch/issues/new/choose).
Named tensor API reference
--------------------------
## Named tensor API reference
In this section please find the documentation for named tensor specific APIs.
For a comprehensive reference for how names are propagated through other PyTorch
operators, see :ref:`name_inference_reference-doc`.
operators, see {ref}`name_inference_reference-doc`.
```{eval-rst}
.. class:: Tensor()
:noindex:
.. autoattribute:: names
.. automethod:: rename
.. automethod:: rename_
.. automethod:: refine_names
.. automethod:: align_as
.. automethod:: align_to
.. py:method:: flatten(dims, out_dim) -> Tensor
@ -317,3 +313,4 @@ operators, see :ref:`name_inference_reference-doc`.
.. warning::
The named tensor API is experimental and subject to change.
```

View File

@ -1,14 +1,15 @@
torch.nested
============
# torch.nested
```{eval-rst}
.. automodule:: torch.nested
```
Introduction
++++++++++++
## Introduction
.. warning::
```{warning}
The PyTorch API of nested tensors is in prototype stage and will change in the near future.
```
Nested tensors allow for ragged-shaped data to be contained within and operated upon as a
single tensor. Such data is stored underneath in an efficient packed representation, while exposing
@ -22,19 +23,17 @@ padding. This is inefficient and error-prone, and nested tensors exist to addres
The API for calling operations on a nested tensor is no different from that of a regular
``torch.Tensor``, allowing for seamless integration with existing models, with the main
difference being :ref:`construction of the inputs <construction>`.
difference being {ref}`construction of the inputs <construction>`.
As this is a prototype feature, the set of :ref:`operations supported <supported operations>` is
As this is a prototype feature, the set of {ref}`operations supported <supported operations>` is
limited, but growing. We welcome issues, feature requests, and contributions.
More information on contributing can be found
`in this Readme <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md>`_.
[in this Readme](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md).
.. _construction:
(construction)=
## Construction
Construction
++++++++++++
.. note::
```{note}
There are two forms of nested tensors present within PyTorch, distinguished by layout as
specified during construction. Layout can be one of ``torch.strided`` or ``torch.jagged``.
@ -42,6 +41,7 @@ Construction
supports a single ragged dimension, it has better op coverage, receives active development, and
integrates well with ``torch.compile``. These docs adhere to this recommendation and refer to
nested tensors with the ``torch.jagged`` layout as "NJTs" for brevity throughout.
```
Construction is straightforward and involves passing a list of tensors to the
``torch.nested.nested_tensor`` constructor. A nested tensor with the ``torch.jagged`` layout
@ -49,6 +49,7 @@ Construction is straightforward and involves passing a list of tensors to the
into a packed, contiguous block of memory according to the layout described in the `data_layout`_
section below.
```
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
@ -57,31 +58,36 @@ tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> print([component for component in nt])
[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])]
```
Each tensor in the list must have the same number of dimensions, but the shapes can otherwise vary
along a single dimension. If the dimensionalities of the input components don't match, the
constructor throws an error.
```
>>> a = torch.randn(50, 128) # 2D tensor
>>> b = torch.randn(2, 50, 128) # 3D tensor
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
...
RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim
```
During construction, dtype, device, and whether gradients are required can be chosen via the
usual keyword arguments.
```
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True)
>>> print([component for component in nt])
[tensor([0., 1., 2.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>), tensor([3., 4., 5., 6., 7.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>)]
```
``torch.nested.as_nested_tensor`` can be used to preserve autograd history from the tensors passed
to the constructor. When this constructor is utilized, gradients will flow through the nested tensor
back into the original components. Note that this constructor still copies the input components into
a packed, contiguous block of memory.
```
>>> a = torch.randn(12, 512, requires_grad=True)
>>> b = torch.randn(23, 512, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@ -102,6 +108,7 @@ tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
```
The above functions all create contiguous NJTs, where a chunk of memory is allocated to store
a packed form of the underlying components (see the `data_layout`_ section below for more
@ -111,6 +118,7 @@ It is also possible to create a non-contiguous NJT view over a pre-existing dens
with padding, avoiding the memory allocation and copying. ``torch.nested.narrow()`` is the tool
for accomplishing this.
```
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
@ -118,26 +126,26 @@ for accomplishing this.
torch.Size([3, j1, 4])
>>> nt.is_contiguous()
False
```
Note that the nested tensor acts as a view over the original padded dense tensor, referencing the
same memory without copying / allocation. Operation support for non-contiguous NJTs is somewhat more
limited, so if you run into support gaps, it's always possible to convert to a contiguous NJT
using ``contiguous()``.
.. _data_layout:
Data Layout and Shape
+++++++++++++++++++++
(data_layout)=
## Data Layout and Shape
For efficiency, nested tensors generally pack their tensor components into a contiguous chunk of
memory and maintain additional metadata to specify batch item boundaries. For the ``torch.jagged``
layout, the contiguous chunk of memory is stored in the ``values`` component, with the ``offsets``
component delineating batch item boundaries for the ragged dimension.
.. image:: _static/img/nested/njt_visual.png
![image](_static/img/nested/njt_visual.png)
It's possible to directly access the underlying NJT components when necessary.
```
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@ -145,19 +153,23 @@ It's possible to directly access the underlying NJT components when necessary.
torch.Size([82, 128])
>>> nt.offsets()
tensor([ 0, 50, 82])
```
It can also be useful to construct an NJT from the jagged ``values`` and ``offsets``
constituents directly; the ``torch.nested.nested_tensor_from_jagged()`` constructor serves
this purpose.
```
>>> values = torch.randn(82, 128)
>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64)
>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets)
```
An NJT has a well-defined shape with dimensionality 1 greater than that of its components. The
underlying structure of the ragged dimension is represented by a symbolic value (``j1`` in the
example below).
```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@ -165,6 +177,7 @@ example below).
3
>>> nt.shape
torch.Size([2, j1, 128])
```
NJTs must have the same ragged structure to be compatible with each other. For example, to run a
binary operation involving two NJTs, the ragged structures must match (i.e. they must have the
@ -172,6 +185,7 @@ same ragged shape symbol in their shapes). In the details, each symbol correspon
``offsets`` tensor, so both NJTs must have the same ``offsets`` tensor to be compatible with
each other.
```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@ -180,21 +194,20 @@ each other.
False
>>> nt3 = nt1 + nt2
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
```
In the above example, even though the conceptual shapes of the two NJTs are the same, they don't
share a reference to the same ``offsets`` tensor, so their shapes differ, and they are not
compatible. We recognize that this behavior is unintuitive and are working hard to relax this
restriction for the beta release of nested tensors. For a workaround, see the
:ref:`Troubleshooting <ragged_structure_incompatibility>` section of this document.
{ref}`Troubleshooting <ragged_structure_incompatibility>` section of this document.
In addition to the ``offsets`` metadata, NJTs can also compute and cache the minimum and maximum
sequence lengths for its components, which can be useful for invoking particular kernels (e.g. SDPA).
There are currently no public APIs for accessing these, but this will change for the beta release.
.. _supported operations:
Supported Operations
++++++++++++++++++++
(supported operations)=
## Supported Operations
This section contains a list of common operations over nested tensors that you may find useful.
It is not comprehensive, as there are on the order of a couple thousand ops within PyTorch. While
@ -203,15 +216,15 @@ The ideal state for nested tensors is full support of all PyTorch operations tha
for non-nested tensors. To help us accomplish this, please consider:
* Requesting particular ops needed for your use case
`here <https://github.com/pytorch/pytorch/issues/118107>`__ to help us prioritize.
[here](https://github.com/pytorch/pytorch/issues/118107) to help us prioritize.
* Contributing! It's not too hard to add nested tensor support for a given PyTorch op; see
the `Contributions <contributions>`__ section below for details.
the [Contributions](contributions) section below for details.
Viewing nested tensor constituents
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
### Viewing nested tensor constituents
``unbind()`` allows you to retrieve a view of the nested tensor's constituents.
```
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
@ -231,16 +244,17 @@ tensor([[ 3.6858, -3.7030, -4.4525],
[-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104, 1.4841],
[ 2.0952, 0.2973, 0.2516],
[ 0.9035, 1.3623, 0.2026]]))
```
Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which
represents the first entry or constituent of the nested tensor.
Conversions to / from padded
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#### Conversions to / from padded
``torch.nested.to_padded_tensor()`` converts an NJT to a padded dense tensor with the specified
padding value. The ragged dimension will be padded out to the size of the maximum sequence length.
```
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(6, 3)
@ -259,6 +273,7 @@ tensor([[[ 1.6107, 0.5723, 0.3913],
[-0.5506, 0.7608, 2.0606],
[ 1.5658, -1.1934, 0.3041],
[ 0.1483, -1.1284, 0.6957]]])
```
This can be useful as an escape hatch to work around NJT support gaps, but ideally such
conversions should be avoided when possible for optimal memory usage and performance, as the
@ -269,6 +284,7 @@ ragged structure to a given dense tensor to produce an NJT. Note that by default
does not copy the underlying data, and thus the output NJT is generally non-contiguous. It may be
useful to explicitly call ``contiguous()`` here if a contiguous NJT is desired.
```
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged)
@ -277,12 +293,13 @@ torch.Size([3, j1, 4])
>>> nt = nt.contiguous()
>>> nt.shape
torch.Size([3, j2, 4])
```
Shape manipulations
^^^^^^^^^^^^^^^^^^^
### Shape manipulations
Nested tensors support a wide array of operations for shape manipulation, including views.
```
>>> a = torch.randn(2, 6)
>>> b = torch.randn(4, 6)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
@ -298,36 +315,37 @@ torch.Size([2, j1, 12])
torch.Size([2, j1, 2, 6])
>>> nt.transpose(-1, -2).shape
torch.Size([2, 6, j1])
```
Attention mechanisms
^^^^^^^^^^^^^^^^^^^^
### Attention mechanisms
As variable-length sequences are common inputs to attention mechanisms, nested tensors support
important attention operators
`Scaled Dot Product Attention (SDPA) <https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_ and
`FlexAttention <https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention>`_.
[Scaled Dot Product Attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and
[FlexAttention](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention).
See
`here <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention>`__
[here](https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention)
for usage examples of NJT with SDPA and
`here <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#flexattention-njt>`__
[here](https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#flexattention-njt)
for usage examples of NJT with FlexAttention.
.. _usage_with_torch_compile:
(usage_with_torch_compile)=
## Usage with torch.compile
Usage with torch.compile
++++++++++++++++++++++++
NJTs are designed to be used with ``torch.compile()`` for optimal performance, and we always
recommend utilizing ``torch.compile()`` with NJTs when possible. NJTs work out-of-the-box and
graph-break-free both when passed as inputs to a compiled function or module OR when
instantiated in-line within the function.
.. note::
```{note}
If you're not able to utilize ``torch.compile()`` for your use case, performance and memory
usage may still benefit from the use of NJTs, but it's not as clear-cut whether this will be
the case. It is important that the tensors being operated on are large enough so the
performance gains are not outweighed by the overhead of python tensor subclasses.
```
```
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
@ -344,11 +362,13 @@ torch.Size([2, j1, 3])
>>> output2 = compiled_g(nt.values(), nt.offsets())
>>> output2.shape
torch.Size([2, j1, 3])
```
Note that NJTs support
`Dynamic Shapes <https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html>`_
[Dynamic Shapes](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html)
to avoid unnecessary recompiles with changing ragged structure.
```
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> c = torch.randn(5, 3)
@ -360,44 +380,39 @@ to avoid unnecessary recompiles with changing ragged structure.
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output1 = compiled_f(nt1)
>>> output2 = compiled_f(nt2) # NB: No recompile needed even though ragged structure differs
```
If you run into problems or arcane errors when utilizing NJT + ``torch.compile``, please file a
PyTorch issue. Full subclass support within ``torch.compile`` is a long-term effort and there may
be some rough edges at this time.
.. _troubleshooting:
Troubleshooting
+++++++++++++++
(troubleshooting)=
## Troubleshooting
This section contains common errors that you may run into when utilizing nested tensors, alongside
the reason for these errors and suggestions for how to address them.
.. _unimplemented_op:
Unimplemented ops
^^^^^^^^^^^^^^^^^
(unimplemented_op)=
### Unimplemented ops
This error is becoming rarer as nested tensor op support grows, but it's still possible to hit it
today given that there are a couple thousand ops within PyTorch.
::
```
NotImplementedError: aten.view_as_real.default
```
The error is straightforward; we haven't gotten around to adding op support for this particular op
yet. If you'd like, you can `contribute <contributions>`__ an implementation yourself OR simply
`request <https://github.com/pytorch/pytorch/issues/118107>`_ that we add support for this op
yet. If you'd like, you can [contribute](contributions) an implementation yourself OR simply
[request](https://github.com/pytorch/pytorch/issues/118107) that we add support for this op
in a future PyTorch release.
.. _ragged_structure_incompatibility:
Ragged structure incompatibility
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::
(ragged_structure_incompatibility)=
### Ragged structure incompatibility
```
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
```
This error occurs when calling an op that operates over multiple NJTs with incompatible ragged
structures. Currently, it is required that input NJTs have the exact same ``offsets`` constituent
@ -407,6 +422,7 @@ As a workaround for this situation, it is possible to construct NJTs from the ``
``offsets`` components directly. With both NJTs referencing the same ``offsets`` components, they
are considered to have the same ragged structure and are thus compatible.
```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@ -414,18 +430,19 @@ are considered to have the same ragged structure and are thus compatible.
>>> nt3 = nt1 + nt2
>>> nt3.shape
torch.Size([2, j1, 128])
```
Data dependent operation within torch.compile
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::
### Data dependent operation within torch.compile
```
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
```
This error occurs when calling an op that does data-dependent operation within torch.compile; this
commonly occurs for ops that need to examine the values of the NJT's ``offsets`` to determine the
output shape. For example:
```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@ -433,30 +450,31 @@ output shape. For example:
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
```
In this example, calling ``chunk()`` on the batch dimension of the NJT requires examination of the
NJT's ``offsets`` data to delineate batch item boundaries within the packed ragged dimension. As a
workaround, there are a couple torch.compile flags that can be set:
```
>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True
>>> torch._dynamo.config.capture_scalar_outputs = True
```
If, after setting these, you still see data-dependent operator errors, please file an issue with
PyTorch. This area of ``torch.compile()`` is still in heavy development and certain aspects of
NJT support may be incomplete.
.. _contributions:
Contributions
+++++++++++++
(contributions)=
## Contributions
If you'd like to contribute to nested tensor development, one of the most impactful ways to do
so is to add nested tensor support for a currently-unsupported PyTorch op. This process generally
consists of a couple simple steps:
#. Determine the name of the op to add; this should be something like ``aten.view_as_real.default``.
1. Determine the name of the op to add; this should be something like ``aten.view_as_real.default``.
The signature for this op can be found in ``aten/src/ATen/native/native_functions.yaml``.
#. Register an op implementation in ``torch/nested/_internal/ops.py``, following the pattern
2. Register an op implementation in ``torch/nested/_internal/ops.py``, following the pattern
established there for other ops. Use the signature from ``native_functions.yaml`` for schema
validation.
@ -474,20 +492,32 @@ working implementation:
Within ``torch.compile``, these conversions can be fused to avoid materializing the padded
intermediate.
.. _construction_and_conversion:
Detailed Docs for Construction and Conversion Functions
+++++++++++++++++++++++++++++++++++++++++++++++++++++++
(construction_and_conversion)=
## Detailed Docs for Construction and Conversion Functions
```{eval-rst}
.. currentmodule:: torch.nested
```
```{eval-rst}
.. autofunction:: nested_tensor
```
```{eval-rst}
.. autofunction:: nested_tensor_from_jagged
```
```{eval-rst}
.. autofunction:: as_nested_tensor
```
```{eval-rst}
.. autofunction:: to_padded_tensor
```
```{eval-rst}
.. autofunction:: masked_select
```
```{eval-rst}
.. autofunction:: narrow
```
```{eval-rst}
.. seealso::
`Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile <https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
```

View File

@ -1,23 +1,25 @@
```{eval-rst}
.. role:: hidden
:class: hidden-section
```
# torch.nn.attention.bias
torch.nn.attention.bias
========================
```{eval-rst}
.. automodule:: torch.nn.attention.bias
.. currentmodule:: torch.nn.attention.bias
```
CausalBias
==========
## CausalBias
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classnoinheritance.rst
CausalBias
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
@ -25,3 +27,4 @@ CausalBias
causal_lower_right
causal_upper_left
CausalVariant
```

View File

@ -1,7 +1,12 @@
torch.nn.attention.experimental
===============================
.. currentmodule:: torch.nn.attention.experimental
.. py:module:: torch.nn.attention.experimental
# torch.nn.attention.experimental
.. warning::
```{eval-rst}
.. currentmodule:: torch.nn.attention.experimental
```
```{eval-rst}
.. py:module:: torch.nn.attention.experimental
```
```{warning}
These APIs are experimental and subject to change without notice.
```

View File

@ -1,27 +1,45 @@
```{eval-rst}
.. role:: hidden
:class: hidden-section
```
======================================
torch.nn.attention.flex_attention
======================================
# torch.nn.attention.flex_attention
```{eval-rst}
.. currentmodule:: torch.nn.attention.flex_attention
```
```{eval-rst}
.. py:module:: torch.nn.attention.flex_attention
```
```{eval-rst}
.. autofunction:: flex_attention
```
BlockMask Utilities
-------------------
## BlockMask Utilities
```{eval-rst}
.. autofunction:: create_block_mask
```
```{eval-rst}
.. autofunction:: create_mask
```
```{eval-rst}
.. autofunction:: create_nested_block_mask
```
```{eval-rst}
.. autofunction:: and_masks
```
```{eval-rst}
.. autofunction:: or_masks
```
```{eval-rst}
.. autofunction:: noop_mask
```
BlockMask
---------
## BlockMask
```{eval-rst}
.. autoclass:: BlockMask
:members:
:undoc-members:
```