mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

Convert to markdown: named_tensor.rst, nested.rst, nn.attention.bias.rst, nn.attention.experimental.rst, nn.attention.flex_attention.rst #155028 (#155696)

Fixes #155028. This pull request updates the documentation by transitioning from .rst to .md format. It introduces new Markdown files for the documentation of named_tensor, nested, nn.attention.bias, nn.attention.experimental, and nn.attention.flex_attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155696
Approved by: https://github.com/svekars
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>

Committed by PyTorch MergeBot
parent cdfa33a328
commit ca3cabd24a
@@ -1,9 +1,10 @@
```{eval-rst}
.. currentmodule:: torch
```

(named_tensors-doc)=

# Named Tensors

Named Tensors allow users to give explicit names to tensor dimensions.
In most cases, operations that take dimension parameters will accept
@@ -14,43 +15,42 @@ also be used to rearrange dimensions, for example, to support
"broadcasting by name" rather than "broadcasting by position".

```{warning}
The named tensor API is a prototype feature and subject to change.
```

## Creating named tensors

Factory functions now take a new {attr}`names` argument that associates a name
with each dimension.

```
>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
        [0., 0., 0.]], names=('N', 'C'))
```

Named dimensions, like regular Tensor dimensions, are ordered.
``tensor.names[i]`` is the name of dimension ``i`` of ``tensor``.

The following factory functions support named tensors:

- {func}`torch.empty`
- {func}`torch.rand`
- {func}`torch.randn`
- {func}`torch.ones`
- {func}`torch.tensor`
- {func}`torch.zeros`

## Named dimensions

See {attr}`~Tensor.names` for restrictions on tensor names.

Use {attr}`~Tensor.names` to access the dimension names of a tensor and
{meth}`~Tensor.rename` to rename named dimensions.

```
>>> imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
>>> imgs.names
('N', 'C', 'H', 'W')
@@ -58,20 +58,19 @@ Use :attr:`~Tensor.names` to access the dimension names of a tensor and
>>> renamed_imgs = imgs.rename(H='height', W='width')
>>> renamed_imgs.names
('N', 'C', 'height', 'width')
```

Named tensors can coexist with unnamed tensors; named tensors are instances of
{class}`torch.Tensor`. Unnamed tensors have ``None``-named dimensions. Named
tensors do not require all dimensions to be named.

```
>>> imgs = torch.randn(1, 2, 2, 3, names=(None, 'C', 'H', 'W'))
>>> imgs.names
(None, 'C', 'H', 'W')
```

## Name propagation semantics

Named tensors use names to automatically check that APIs are being called
correctly at runtime. This occurs in a process called *name inference*.
@@ -83,17 +82,16 @@ More formally, name inference consists of the following two steps:

All operations that support named tensors propagate names.

```
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')
```

(match_semantics-doc)=
### match semantics

Two names *match* if they are equal (string equality) or if at least one is ``None``.
Nones are essentially a special "wildcard" name.
@@ -102,62 +100,59 @@ Nones are essentially a special "wildcard" name.
It returns the more *specific* of the two names, if they match. If the names do not match,
then it errors.

```{note}
In practice, when working with named tensors, one should avoid having unnamed
dimensions because their handling can be complicated. It is recommended to lift
all unnamed dimensions to be named dimensions by using {meth}`~Tensor.refine_names`.
```
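
A minimal sketch of that lifting with {meth}`~Tensor.refine_names` (the dimension names here are
illustrative):

```
>>> imgs = torch.randn(1, 2, 2, 3, names=(None, 'C', 'H', 'W'))
>>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')  # lift the unnamed dim to 'N' (example name)
>>> named_imgs.names
('N', 'C', 'H', 'W')
```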

### Basic name inference rules

Let's see how ``match`` and ``unify`` are used in name inference in the case of
adding two one-dim tensors with no broadcasting.

```
x = torch.randn(3, names=('X',))
y = torch.randn(3)
z = torch.randn(3, names=('Z',))
```

**Check names**: check that the names of the two tensors *match*.

For the following examples:

```
>>> # x + y  # match('X', None) is True
>>> # x + z  # match('X', 'Z') is False
>>> # x + x  # match('X', 'X') is True

>>> x + z
Error when attempting to broadcast dims ['X'] and dims ['Z']: dim 'X' and dim 'Z' are at the same position from the right but do not match.
```

**Propagate names**: *unify* the names to select which one to propagate.
In the case of ``x + y``, ``unify('X', None) = 'X'`` because ``'X'`` is more
specific than ``None``.

```
>>> (x + y).names
('X',)
>>> (x + x).names
('X',)
```

For a comprehensive list of name inference rules, see {ref}`name_inference_reference-doc`.
Here are two common operations that may be useful to go over:

- Binary arithmetic ops: {ref}`unifies_names_from_inputs-doc`
- Matrix multiplication ops: {ref}`contracts_away_dims-doc`

## Explicit alignment by names

Use {meth}`~Tensor.align_as` or {meth}`~Tensor.align_to` to align tensor dimensions
by name to a specified ordering. This is useful for performing "broadcasting by names".

```
# This function is agnostic to the dimension ordering of `input`,
# as long as it has a `C` dimension somewhere.
def scale_channels(input, scale):
@@ -173,15 +168,14 @@ by name to a specified ordering. This is useful for performing "broadcasting by
>>> scale_channels(imgs, scale)
>>> scale_channels(more_imgs, scale)
>>> scale_channels(videos, scale)
```
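
The body of ``scale_channels`` elided by the hunk above can be sketched with
{meth}`~Tensor.align_as` along these lines (a hedged reconstruction; the documentation's exact
body may differ):

```
def scale_channels(input, scale):
    # sketch: name the 1-D scale tensor, then broadcast it by name rather than by position
    scale = scale.refine_names('C')
    return input * scale.align_as(input)
```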

## Manipulating dimensions

Use {meth}`~Tensor.align_to` to permute large numbers of dimensions without
mentioning all of them, as is required by {meth}`~Tensor.permute`.

```
>>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
>>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
@@ -189,13 +183,14 @@ mentioning all of them as in required by :meth:`~Tensor.permute`.
# the rest in the same order
>>> tensor.permute(5, 4, 0, 1, 2, 3)
>>> named_tensor.align_to('F', 'E', ...)
```

Use {meth}`~Tensor.flatten` and {meth}`~Tensor.unflatten` to flatten and unflatten
dimensions, respectively. These methods are more verbose than {meth}`~Tensor.view`
and {meth}`~Tensor.reshape`, but have more semantic meaning to someone reading the code.

```
>>> imgs = torch.randn(32, 3, 128, 128)
>>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
@@ -207,18 +202,16 @@ and :meth:`~Tensor.reshape`, but have more semantic meaning to someone reading t
>>> unflattened_named_imgs = named_flat_imgs.unflatten('features', [('C', 3), ('H', 128), ('W', 128)])
>>> unflattened_named_imgs.names
('N', 'C', 'H', 'W')
```

(named_tensors_autograd-doc)=
## Autograd support

Autograd currently supports named tensors in a limited manner: autograd ignores
names on all tensors. Gradient computation is still correct but we lose the
safety that names give us.

```
>>> x = torch.randn(3, names=('D',))
>>> weight = torch.randn(3, names=('D',), requires_grad=True)
>>> loss = (x - weight).abs()
@@ -234,31 +227,30 @@ safety that names give us.
>>> loss.backward(grad_loss)
>>> weight.grad
tensor([-1.8107, -0.6357, 0.0783])
```

## Currently supported operations and subsystems

### Operators

See {ref}`name_inference_reference-doc` for a full list of the supported torch and
tensor operations. We do not yet support the following, which is not covered by that link:

- indexing, advanced indexing.

For ``torch.nn.functional`` operators, we support the following:

- {func}`torch.nn.functional.relu`
- {func}`torch.nn.functional.softmax`
- {func}`torch.nn.functional.log_softmax`
- {func}`torch.nn.functional.tanh`
- {func}`torch.nn.functional.sigmoid`
- {func}`torch.nn.functional.dropout`

### Subsystems

Autograd is supported, see {ref}`named_tensors_autograd-doc`.
Because gradients are currently unnamed, optimizers may work but are untested.

NN modules are currently unsupported. This can lead to the following when calling
@@ -272,32 +264,36 @@ We also do not support the following subsystems, though some may work out
of the box:

- distributions
- serialization ({func}`torch.load`, {func}`torch.save`)
- multiprocessing
- JIT
- distributed
- ONNX

If any of these would help your use case, please
[search if an issue has already been filed](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+named+tensor%22)
and if not, [file one](https://github.com/pytorch/pytorch/issues/new/choose).

## Named tensor API reference

This section contains the documentation for named tensor specific APIs.
For a comprehensive reference for how names are propagated through other PyTorch
operators, see {ref}`name_inference_reference-doc`.

```{eval-rst}
.. class:: Tensor()
   :noindex:

   .. autoattribute:: names

   .. automethod:: rename

   .. automethod:: rename_

   .. automethod:: refine_names

   .. automethod:: align_as

   .. automethod:: align_to

   .. py:method:: flatten(dims, out_dim) -> Tensor
@@ -317,3 +313,4 @@ operators, see :ref:`name_inference_reference-doc`.

.. warning::
    The named tensor API is experimental and subject to change.
```
@@ -1,14 +1,15 @@
# torch.nested

```{eval-rst}
.. automodule:: torch.nested
```

## Introduction

```{warning}
The PyTorch API of nested tensors is in prototype stage and will change in the near future.
```

Nested tensors allow for ragged-shaped data to be contained within and operated upon as a
single tensor. Such data is stored underneath in an efficient packed representation, while exposing
@@ -22,19 +23,17 @@ padding. This is inefficient and error-prone, and nested tensors exist to addres
The API for calling operations on a nested tensor is no different from that of a regular
``torch.Tensor``, allowing for seamless integration with existing models, with the main
difference being {ref}`construction of the inputs <construction>`.

As this is a prototype feature, the set of {ref}`operations supported <supported operations>` is
limited, but growing. We welcome issues, feature requests, and contributions.
More information on contributing can be found
[in this Readme](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md).

(construction)=
## Construction

```{note}
There are two forms of nested tensors present within PyTorch, distinguished by layout as
specified during construction. Layout can be one of ``torch.strided`` or ``torch.jagged``.
@@ -42,6 +41,7 @@ Construction
supports a single ragged dimension, it has better op coverage, receives active development, and
integrates well with ``torch.compile``. These docs adhere to this recommendation and refer to
nested tensors with the ``torch.jagged`` layout as "NJTs" for brevity throughout.
```

Construction is straightforward and involves passing a list of tensors to the
``torch.nested.nested_tensor`` constructor. A nested tensor with the ``torch.jagged`` layout
@@ -49,6 +49,7 @@ Construction is straightforward and involves passing a list of tensors to the
into a packed, contiguous block of memory according to the layout described in the `data_layout`_
section below.

```
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
@@ -57,31 +58,36 @@ tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> print([component for component in nt])
[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])]
```

Each tensor in the list must have the same number of dimensions, but the shapes can otherwise vary
along a single dimension. If the dimensionalities of the input components don't match, the
constructor throws an error.

```
>>> a = torch.randn(50, 128) # 2D tensor
>>> b = torch.randn(2, 50, 128) # 3D tensor
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
...
RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim
```

During construction, dtype, device, and whether gradients are required can be chosen via the
usual keyword arguments.

```
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True)
>>> print([component for component in nt])
[tensor([0., 1., 2.], device='cuda:0',
       grad_fn=<UnbindBackwardAutogradNestedTensor0>), tensor([3., 4., 5., 6., 7.], device='cuda:0',
       grad_fn=<UnbindBackwardAutogradNestedTensor0>)]
```

``torch.nested.as_nested_tensor`` can be used to preserve autograd history from the tensors passed
to the constructor. When this constructor is utilized, gradients will flow through the nested tensor
back into the original components. Note that this constructor still copies the input components into
a packed, contiguous block of memory.

```
>>> a = torch.randn(12, 512, requires_grad=True)
>>> b = torch.randn(23, 512, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@@ -102,6 +108,7 @@ tensor([[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]])
```

The above functions all create contiguous NJTs, where a chunk of memory is allocated to store
a packed form of the underlying components (see the `data_layout`_ section below for more
@@ -111,6 +118,7 @@ It is also possible to create a non-contiguous NJT view over a pre-existing dens
with padding, avoiding the memory allocation and copying. ``torch.nested.narrow()`` is the tool
for accomplishing this.

```
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
@@ -118,26 +126,26 @@ for accomplishing this.
torch.Size([3, j1, 4])
>>> nt.is_contiguous()
False
```

Note that the nested tensor acts as a view over the original padded dense tensor, referencing the
same memory without copying / allocation. Operation support for non-contiguous NJTs is somewhat more
limited, so if you run into support gaps, it's always possible to convert to a contiguous NJT
using ``contiguous()``.

(data_layout)=
## Data Layout and Shape

For efficiency, nested tensors generally pack their tensor components into a contiguous chunk of
memory and maintain additional metadata to specify batch item boundaries. For the ``torch.jagged``
layout, the contiguous chunk of memory is stored in the ``values`` component, with the ``offsets``
component delineating batch item boundaries for the ragged dimension.



It's possible to directly access the underlying NJT components when necessary.

```
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@@ -145,19 +153,23 @@ It's possible to directly access the underlying NJT components when necessary.
torch.Size([82, 128])
>>> nt.offsets()
tensor([ 0, 50, 82])
```

It can also be useful to construct an NJT from the jagged ``values`` and ``offsets``
constituents directly; the ``torch.nested.nested_tensor_from_jagged()`` constructor serves
this purpose.

```
>>> values = torch.randn(82, 128)
>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64)
>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets)
```

An NJT has a well-defined shape with dimensionality 1 greater than that of its components. The
underlying structure of the ragged dimension is represented by a symbolic value (``j1`` in the
example below).

```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@@ -165,6 +177,7 @@ example below).
3
>>> nt.shape
torch.Size([2, j1, 128])
```

NJTs must have the same ragged structure to be compatible with each other. For example, to run a
binary operation involving two NJTs, the ragged structures must match (i.e. they must have the
@@ -172,6 +185,7 @@ same ragged shape symbol in their shapes). In the details, each symbol correspon
``offsets`` tensor, so both NJTs must have the same ``offsets`` tensor to be compatible with
each other.

```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@@ -180,21 +194,20 @@ each other.
False
>>> nt3 = nt1 + nt2
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
```

In the above example, even though the conceptual shapes of the two NJTs are the same, they don't
share a reference to the same ``offsets`` tensor, so their shapes differ, and they are not
compatible. We recognize that this behavior is unintuitive and are working hard to relax this
restriction for the beta release of nested tensors. For a workaround, see the
{ref}`Troubleshooting <ragged_structure_incompatibility>` section of this document.

In addition to the ``offsets`` metadata, NJTs can also compute and cache the minimum and maximum
sequence lengths for its components, which can be useful for invoking particular kernels (e.g. SDPA).
There are currently no public APIs for accessing these, but this will change for the beta release.
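
In the meantime, per-component lengths can be derived from the public ``offsets()`` accessor; a
small sketch (the component sizes are illustrative):

```
>>> nt = torch.nested.nested_tensor([torch.randn(50, 128), torch.randn(32, 128)], layout=torch.jagged)
>>> lengths = nt.offsets().diff()   # per-batch-item lengths along the ragged dim
>>> lengths
tensor([50, 32])
>>> lengths.max(), lengths.min()
(tensor(50), tensor(32))
```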

(supported operations)=
## Supported Operations

This section contains a list of common operations over nested tensors that you may find useful.
It is not comprehensive, as there are on the order of a couple thousand ops within PyTorch. While
@@ -203,15 +216,15 @@ The ideal state for nested tensors is full support of all PyTorch operations tha
for non-nested tensors. To help us accomplish this, please consider:

* Requesting particular ops needed for your use case
  [here](https://github.com/pytorch/pytorch/issues/118107) to help us prioritize.
* Contributing! It's not too hard to add nested tensor support for a given PyTorch op; see
  the [Contributions](contributions) section below for details.

### Viewing nested tensor constituents

``unbind()`` allows you to retrieve a view of the nested tensor's constituents.

```
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
@@ -231,16 +244,17 @@ tensor([[ 3.6858, -3.7030, -4.4525],
        [-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104,  1.4841],
        [ 2.0952,  0.2973,  0.2516],
        [ 0.9035,  1.3623,  0.2026]]))
```

Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which
represents the first entry or constituent of the nested tensor.

### Conversions to / from padded

``torch.nested.to_padded_tensor()`` converts an NJT to a padded dense tensor with the specified
padding value. The ragged dimension will be padded out to the size of the maximum sequence length.

```
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(6, 3)
@@ -259,6 +273,7 @@ tensor([[[ 1.6107, 0.5723, 0.3913],
        [-0.5506,  0.7608,  2.0606],
        [ 1.5658, -1.1934,  0.3041],
        [ 0.1483, -1.1284,  0.6957]]])
```

This can be useful as an escape hatch to work around NJT support gaps, but ideally such
conversions should be avoided when possible for optimal memory usage and performance, as the
@@ -269,6 +284,7 @@ ragged structure to a given dense tensor to produce an NJT. Note that by default
does not copy the underlying data, and thus the output NJT is generally non-contiguous. It may be
useful to explicitly call ``contiguous()`` here if a contiguous NJT is desired.

```
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged)
@@ -277,12 +293,13 @@ torch.Size([3, j1, 4])
>>> nt = nt.contiguous()
>>> nt.shape
torch.Size([3, j2, 4])
```

### Shape manipulations

Nested tensors support a wide array of operations for shape manipulation, including views.

```
>>> a = torch.randn(2, 6)
>>> b = torch.randn(4, 6)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
@@ -298,36 +315,37 @@ torch.Size([2, j1, 12])
torch.Size([2, j1, 2, 6])
>>> nt.transpose(-1, -2).shape
torch.Size([2, 6, j1])
```

### Attention mechanisms

As variable-length sequences are common inputs to attention mechanisms, nested tensors support
important attention operators
[Scaled Dot Product Attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and
[FlexAttention](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention).
See
[here](https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention)
for usage examples of NJT with SDPA and
[here](https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#flexattention-njt)
for usage examples of NJT with FlexAttention.
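
As a rough sketch of the SDPA pattern (the sizes, head count, and the use of identical
``q``/``k``/``v`` are assumptions for illustration, not the tutorial's exact code):

```
>>> import torch.nn.functional as F
>>> # two sequences of different lengths, each with 8 heads of dim 16 (illustrative sizes)
>>> x = torch.nested.nested_tensor(
...     [torch.randn(5, 8, 16), torch.randn(9, 8, 16)], layout=torch.jagged
... )                                # shape: (B=2, j1, H=8, D=16)
>>> q = k = v = x.transpose(1, 2)    # SDPA expects (B, H, seq, D)
>>> out = F.scaled_dot_product_attention(q, k, v)   # output is an NJT with the same ragged structure
```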

(usage_with_torch_compile)=
## Usage with torch.compile

NJTs are designed to be used with ``torch.compile()`` for optimal performance, and we always
recommend utilizing ``torch.compile()`` with NJTs when possible. NJTs work out-of-the-box and
graph-break-free both when passed as inputs to a compiled function or module and when
instantiated in-line within the function.

```{note}
If you're not able to utilize ``torch.compile()`` for your use case, performance and memory
usage may still benefit from the use of NJTs, but it's not as clear-cut whether this will be
the case. It is important that the tensors being operated on are large enough so that the
performance gains are not outweighed by the overhead of Python tensor subclasses.
```

```
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
@@ -344,11 +362,13 @@ torch.Size([2, j1, 3])
>>> output2 = compiled_g(nt.values(), nt.offsets())
>>> output2.shape
torch.Size([2, j1, 3])
```

Note that NJTs support
[Dynamic Shapes](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html)
to avoid unnecessary recompiles with changing ragged structure.

```
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> c = torch.randn(5, 3)
@@ -360,44 +380,39 @@ to avoid unnecessary recompiles with changing ragged structure.
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output1 = compiled_f(nt1)
>>> output2 = compiled_f(nt2)  # NB: No recompile needed even though ragged structure differs
```

If you run into problems or arcane errors when utilizing NJT + ``torch.compile``, please file a
PyTorch issue. Full subclass support within ``torch.compile`` is a long-term effort and there may
be some rough edges at this time.

(troubleshooting)=
## Troubleshooting

This section contains common errors that you may run into when utilizing nested tensors, alongside
the reasons for these errors and suggestions for how to address them.

(unimplemented_op)=
### Unimplemented ops

This error is becoming rarer as nested tensor op support grows, but it's still possible to hit it
today given that there are a couple thousand ops within PyTorch.

```
NotImplementedError: aten.view_as_real.default
```

The error is straightforward; we haven't gotten around to adding op support for this particular op
yet. If you'd like, you can [contribute](contributions) an implementation yourself or simply
[request](https://github.com/pytorch/pytorch/issues/118107) that we add support for this op
in a future PyTorch release.

(ragged_structure_incompatibility)=
### Ragged structure incompatibility

```
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
```

This error occurs when calling an op that operates over multiple NJTs with incompatible ragged
structures. Currently, it is required that input NJTs have the exact same ``offsets`` constituent
@@ -407,6 +422,7 @@ As a workaround for this situation, it is possible to construct NJTs from the ``
``offsets`` components directly. With both NJTs referencing the same ``offsets`` components, they
are considered to have the same ragged structure and are thus compatible.

```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@@ -414,18 +430,19 @@ are considered to have the same ragged structure and are thus compatible.
>>> nt3 = nt1 + nt2
>>> nt3.shape
torch.Size([2, j1, 128])
```

### Data dependent operation within torch.compile

```
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
```

This error occurs when calling an op that performs a data-dependent operation within ``torch.compile``; this
commonly occurs for ops that need to examine the values of the NJT's ``offsets`` to determine the
output shape. For example:

```
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
@@ -433,30 +450,31 @@ output shape. For example:
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
```

In this example, calling ``chunk()`` on the batch dimension of the NJT requires examination of the
NJT's ``offsets`` data to delineate batch item boundaries within the packed ragged dimension. As a
workaround, there are a couple of ``torch.compile`` flags that can be set:

```
>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True
>>> torch._dynamo.config.capture_scalar_outputs = True
```

If, after setting these, you still see data-dependent operator errors, please file an issue with
PyTorch. This area of ``torch.compile()`` is still in heavy development and certain aspects of
NJT support may be incomplete.

(contributions)=
## Contributions

If you'd like to contribute to nested tensor development, one of the most impactful ways to do
so is to add nested tensor support for a currently-unsupported PyTorch op. This process generally
consists of a couple of simple steps:

1. Determine the name of the op to add; this should be something like ``aten.view_as_real.default``.
   The signature for this op can be found in ``aten/src/ATen/native/native_functions.yaml``.
2. Register an op implementation in ``torch/nested/_internal/ops.py``, following the pattern
   established there for other ops. Use the signature from ``native_functions.yaml`` for schema
   validation.
@@ -474,20 +492,32 @@ working implementation:
Within ``torch.compile``, these conversions can be fused to avoid materializing the padded
intermediate.
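
As an illustration of that fallback, here is a hypothetical helper (the name
``njt_op_via_padding`` and the softmax example are assumptions for this sketch, not the
registration pattern used in ``torch/nested/_internal/ops.py``):

```
import torch

def njt_op_via_padding(nt, op):
    """Sketch: apply an op lacking NJT support by round-tripping through a padded dense tensor."""
    lengths = nt.offsets().diff()                        # per-item lengths on the ragged dim
    padded = torch.nested.to_padded_tensor(nt, 0.0)      # (B, max_len, ...) dense tensor
    out = op(padded)                                     # any dense-tensor op
    # view the result back as an NJT with the original ragged structure
    return torch.nested.narrow(out, dim=1, start=0, length=lengths, layout=torch.jagged)

# usage sketch: softmax over the last dim of each component
nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged, dtype=torch.float32
)
result = njt_op_via_padding(nt, lambda t: torch.softmax(t, dim=-1))
```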

(construction_and_conversion)=

## Detailed Docs for Construction and Conversion Functions

```{eval-rst}
.. currentmodule:: torch.nested
```
```{eval-rst}
.. autofunction:: nested_tensor
```
```{eval-rst}
.. autofunction:: nested_tensor_from_jagged
```
```{eval-rst}
.. autofunction:: as_nested_tensor
```
```{eval-rst}
.. autofunction:: to_padded_tensor
```
```{eval-rst}
.. autofunction:: masked_select
```
```{eval-rst}
.. autofunction:: narrow
```
```{eval-rst}
.. seealso::

   `Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile <https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
```
@@ -1,23 +1,25 @@
```{eval-rst}
.. role:: hidden
   :class: hidden-section
```
# torch.nn.attention.bias

```{eval-rst}
.. automodule:: torch.nn.attention.bias
.. currentmodule:: torch.nn.attention.bias
```

## CausalBias

```{eval-rst}
.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: classnoinheritance.rst

   CausalBias

```
```{eval-rst}
.. autosummary::
   :toctree: generated
   :nosignatures:
@@ -25,3 +27,4 @@ CausalBias
   causal_lower_right
   causal_upper_left
   CausalVariant
```
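
A hedged usage sketch for the lower-right causal variant (the shapes and sequence lengths are
made-up example values):

```
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

# query is shorter than key/value, so lower-right causal alignment applies (example sizes)
q = torch.randn(2, 8, 4, 16)    # (batch, heads, q_len, head_dim)
k = torch.randn(2, 8, 10, 16)   # (batch, heads, kv_len, head_dim)
v = torch.randn(2, 8, 10, 16)
attn_bias = causal_lower_right(4, 10)   # bias for q_len=4, kv_len=10
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
```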
@@ -1,7 +1,12 @@
# torch.nn.attention.experimental

```{eval-rst}
.. currentmodule:: torch.nn.attention.experimental
```
```{eval-rst}
.. py:module:: torch.nn.attention.experimental
```

```{warning}
These APIs are experimental and subject to change without notice.
```
@@ -1,27 +1,45 @@
```{eval-rst}
.. role:: hidden
   :class: hidden-section
```

# torch.nn.attention.flex_attention

```{eval-rst}
.. currentmodule:: torch.nn.attention.flex_attention
```
```{eval-rst}
.. py:module:: torch.nn.attention.flex_attention
```
```{eval-rst}
.. autofunction:: flex_attention
```
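
A brief usage sketch (the shapes, the causal ``mask_mod``, and the CUDA/fp16 settings are
illustrative assumptions; see the generated function docs for the authoritative signatures):

```
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # keep only keys at or before the query position
    return q_idx >= kv_idx

B, H, S, D = 2, 8, 1024, 64   # illustrative sizes
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```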

## BlockMask Utilities

```{eval-rst}
.. autofunction:: create_block_mask
```
```{eval-rst}
.. autofunction:: create_mask
```
```{eval-rst}
.. autofunction:: create_nested_block_mask
```
```{eval-rst}
.. autofunction:: and_masks
```
```{eval-rst}
.. autofunction:: or_masks
```
```{eval-rst}
.. autofunction:: noop_mask
```
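
Mask mods can also be composed; a small sketch assuming a causal mod and a sliding-window mod
(the window size of 256 is an arbitrary choice for the example):

```
from torch.nn.attention.flex_attention import and_masks, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= 256   # attend to at most the 256 previous positions (arbitrary window)

# a position is kept only if both mask mods keep it
causal_window = and_masks(causal, sliding_window)
block_mask = create_block_mask(causal_window, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
```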

## BlockMask

```{eval-rst}
.. autoclass:: BlockMask
   :members:
   :undoc-members:
```