Update OSS nested tensor docs to focus on NJT (#145402)
Updated nested tensor docs to be NJT-centric (instead of NST-centric). They now include:

* High-level description of NST vs. NJT + a recommendation to use NJT
* General NJT construction / usage
* torch.compile() integration w/ dynamic shapes
* Common errors and how to fix them
* Contribution guide
* Data layout / shape information (with diagram)
* Links to more extensive tutorials involving Transformers / SDPA / FlexAttention

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145402
Approved by: https://github.com/soulitzer
committed by PyTorch MergeBot
parent 392dc177a9
commit b2a0feac85
docs/source/_static/img/nested/njt_visual.png (new binary file, 54 KiB; not shown)

@@ -10,20 +10,23 @@ Introduction

The PyTorch API of nested tensors is in prototype stage and will change in the near future.

Nested tensors allow for ragged-shaped data to be contained within and operated upon as a
single tensor. Such data is stored underneath in an efficient packed representation, while exposing
a standard PyTorch tensor interface for applying operations.

A common application of nested tensors is for expressing batches of variable-length sequential data
present in various domains, such as varying sentence lengths, image sizes, and audio / video clip
lengths. Traditionally, such data has been handled by padding sequences to the maximum length
within a batch, performing computation on the padded form, and subsequently masking to remove
padding. This is inefficient and error-prone, and nested tensors exist to address these problems.

The API for calling operations on a nested tensor is no different from that of a regular
``torch.Tensor``, allowing for seamless integration with existing models, with the main
difference being :ref:`construction of the inputs <construction>`.

As this is a prototype feature, the set of :ref:`operations supported <supported operations>` is
limited, but growing. We welcome issues, feature requests, and contributions.
More information on contributing can be found
`in this Readme <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md>`_.

.. _construction:

@@ -31,192 +34,456 @@ limited. However, we welcome issues, feature requests and contributions. More in

Construction
++++++++++++

.. note::

    There are two forms of nested tensors present within PyTorch, distinguished by layout as
    specified during construction. Layout can be one of ``torch.strided`` or ``torch.jagged``.
    We recommend utilizing the ``torch.jagged`` layout whenever possible. While it currently only
    supports a single ragged dimension, it has better op coverage, is under active development, and
    integrates well with ``torch.compile``. These docs adhere to this recommendation and refer to
    nested tensors with the ``torch.jagged`` layout as "NJTs" for brevity throughout.

Construction is straightforward and involves passing a list of tensors to the
``torch.nested.nested_tensor`` constructor. A nested tensor with the ``torch.jagged`` layout
(AKA an "NJT") supports a single ragged dimension. This constructor will copy the input tensors
into a packed, contiguous block of memory according to the layout described in the `data_layout`_
section below.

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> print([component for component in nt])
[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])]

Each tensor in the list must have the same number of dimensions, but the shapes can otherwise vary
along a single dimension. If the dimensionalities of the input components don't match, the
constructor throws an error.

>>> a = torch.randn(50, 128) # 2D tensor
>>> b = torch.randn(2, 50, 128) # 3D tensor
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
...
RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim

During construction, dtype, device, and whether gradients are required can be chosen via the
usual keyword arguments.

>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True)
>>> print([component for component in nt])
[tensor([0., 1., 2.], device='cuda:0',
       grad_fn=<UnbindBackwardAutogradNestedTensor0>), tensor([3., 4., 5., 6., 7.], device='cuda:0',
       grad_fn=<UnbindBackwardAutogradNestedTensor0>)]

``torch.nested.as_nested_tensor`` can be used to preserve autograd history from the tensors passed
to the constructor. When this constructor is utilized, gradients will flow through the nested tensor
back into the original components. Note that this constructor still copies the input components into
a packed, contiguous block of memory.

>>> a = torch.randn(12, 512, requires_grad=True)
>>> b = torch.randn(23, 512, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.sum().backward()
>>> a.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]])
>>> b.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]])

The above functions all create contiguous NJTs, where a chunk of memory is allocated to store
a packed form of the underlying components (see the `data_layout`_ section below for more
details).

It is also possible to create a non-contiguous NJT view over a pre-existing dense tensor
with padding, avoiding the memory allocation and copying. ``torch.nested.narrow()`` is the tool
for accomplishing this.

>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt.is_contiguous()
False

Note that the nested tensor acts as a view over the original padded dense tensor, referencing the
same memory without copying / allocation. Operation support for non-contiguous NJTs is somewhat more
limited, so if you run into support gaps, it's always possible to convert to a contiguous NJT
using ``contiguous()``.

.. _data_layout:

Data Layout and Shape
+++++++++++++++++++++

For efficiency, nested tensors generally pack their tensor components into a contiguous chunk of
memory and maintain additional metadata to specify batch item boundaries. For the ``torch.jagged``
layout, the contiguous chunk of memory is stored in the ``values`` component, with the ``offsets``
component delineating batch item boundaries for the ragged dimension.

.. image:: _static/img/nested/njt_visual.png

It's possible to directly access the underlying NJT components when necessary.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.values().shape  # note the "packing" of the ragged dimension; no padding needed
torch.Size([82, 128])
>>> nt.offsets()
tensor([ 0, 50, 82])

It can also be useful to construct an NJT from the jagged ``values`` and ``offsets``
constituents directly; the ``torch.nested.nested_tensor_from_jagged()`` constructor serves
this purpose.

>>> values = torch.randn(82, 128)
>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64)
>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets)

An NJT has a well-defined shape with dimensionality 1 greater than that of its components. The
underlying structure of the ragged dimension is represented by a symbolic value (``j1`` in the
example below).

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.dim()
3
>>> nt.shape
torch.Size([2, j1, 128])

NJTs must have the same ragged structure to be compatible with each other. For example, to run a
binary operation involving two NJTs, the ragged structures must match (i.e. they must have the
same ragged shape symbol in their shapes). Under the hood, each symbol corresponds to a specific
``offsets`` tensor, so both NJTs must share the same ``offsets`` tensor to be compatible with
each other.

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt1.offsets() is nt2.offsets()
False
>>> nt3 = nt1 + nt2
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)

In the above example, even though the conceptual shapes of the two NJTs are the same, they don't
share a reference to the same ``offsets`` tensor, so their shapes differ, and they are not
compatible. We recognize that this behavior is unintuitive and are working hard to relax this
restriction for the beta release of nested tensors. For a workaround, see the
:ref:`Troubleshooting <ragged_structure_incompatibility>` section of this document.

In addition to the ``offsets`` metadata, NJTs can also compute and cache the minimum and maximum
sequence lengths for their components, which can be useful for invoking particular kernels (e.g. SDPA).
There are currently no public APIs for accessing these, but this will change for the beta release.
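
Until such accessors exist, a rough illustrative sketch for obtaining per-component lengths is to
difference the ``offsets`` directly; note that this recomputes the lengths rather than reading any
cached metadata, and the tensor sizes below are purely illustrative.

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> lengths = nt.offsets().diff()  # length of each batch item along the ragged dimension
>>> lengths
tensor([50, 32])
>>> lengths.min(), lengths.max()
(tensor(32), tensor(50))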

.. _supported operations:

Supported Operations
++++++++++++++++++++

This section contains a list of common operations over nested tensors that you may find useful.
It is not comprehensive, as there are on the order of a couple thousand ops within PyTorch. While
a sizeable subset of these are supported for nested tensors today, full support is a large task.
The ideal state for nested tensors is full support of all PyTorch operations that are available
for non-nested tensors. To help us accomplish this, please consider:

* Requesting particular ops needed for your use case
  `here <https://github.com/pytorch/pytorch/issues/118107>`__ to help us prioritize.
* Contributing! It's not too hard to add nested tensor support for a given PyTorch op; see
  the `Contributions <contributions>`__ section below for details.

Viewing nested tensor constituents
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``unbind()`` allows you to retrieve a view of the nested tensor's constituents.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.unbind()
(tensor([[-0.9916, -0.3363, -0.2799],
        [-2.3520, -0.5896, -0.4374]]), tensor([[-2.0969, -1.0104,  1.4841],
        [ 2.0952,  0.2973,  0.2516],
        [ 0.9035,  1.3623,  0.2026]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[-2.9747, -1.0089, -0.8396],
        [-7.0561, -1.7688, -1.3122]])
>>> nt.unbind()
(tensor([[-2.9747, -1.0089, -0.8396],
        [-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104,  1.4841],
        [ 2.0952,  0.2973,  0.2516],
        [ 0.9035,  1.3623,  0.2026]]))

Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which
represents the first entry or constituent of the nested tensor.

Conversions to / from padded
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``torch.nested.to_padded_tensor()`` converts an NJT to a padded dense tensor with the specified
padding value. The ragged dimension will be padded out to the size of the maximum sequence length.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(6, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> padded = torch.nested.to_padded_tensor(nt, padding=4.2)
>>> padded
tensor([[[ 1.6107,  0.5723,  0.3913],
         [ 0.0700, -0.4954,  1.8663],
         [ 4.2000,  4.2000,  4.2000],
         [ 4.2000,  4.2000,  4.2000],
         [ 4.2000,  4.2000,  4.2000],
         [ 4.2000,  4.2000,  4.2000]],
        [[-0.0479, -0.7610, -0.3484],
         [ 1.1345,  1.0556,  0.3634],
         [-1.7122, -0.5921,  0.0540],
         [-0.5506,  0.7608,  2.0606],
         [ 1.5658, -1.1934,  0.3041],
         [ 0.1483, -1.1284,  0.6957]]])

This can be useful as an escape hatch to work around NJT support gaps, but ideally such
conversions should be avoided when possible for optimal memory usage and performance, as the
more efficient nested tensor layout does not materialize padding.

The reverse conversion can be accomplished using ``torch.nested.narrow()``, which applies
ragged structure to a given dense tensor to produce an NJT. Note that by default, this operation
does not copy the underlying data, and thus the output NJT is generally non-contiguous. It may be
useful to explicitly call ``contiguous()`` here if a contiguous NJT is desired.

>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt = nt.contiguous()
>>> nt.shape
torch.Size([3, j2, 4])

Shape manipulations
^^^^^^^^^^^^^^^^^^^

Nested tensors support a wide array of operations for shape manipulation, including views.

>>> a = torch.randn(2, 6)
>>> b = torch.randn(4, 6)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.shape
torch.Size([2, j1, 6])
>>> nt.unsqueeze(-1).shape
torch.Size([2, j1, 6, 1])
>>> nt.unflatten(-1, [2, 3]).shape
torch.Size([2, j1, 2, 3])
>>> torch.cat([nt, nt], dim=2).shape
torch.Size([2, j1, 12])
>>> torch.stack([nt, nt], dim=2).shape
torch.Size([2, j1, 2, 6])
>>> nt.transpose(-1, -2).shape
torch.Size([2, 6, j1])

Attention mechanisms
^^^^^^^^^^^^^^^^^^^^

As variable-length sequences are common inputs to attention mechanisms, nested tensors support
the important attention operators
`Scaled Dot Product Attention (SDPA) <https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_ and
`FlexAttention <https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention>`_.
See
`here <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention>`__
for usage examples of NJT with SDPA and
`here <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#flexattention-njt>`__
for usage examples of NJT with FlexAttention.
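
The following is a minimal sketch of passing NJTs to SDPA, assuming the layout convention used in
the linked tutorial: heads are placed before the ragged sequence dimension, giving
``(batch, heads, seq_len, head_dim)`` inputs. The sizes and names (``n_heads``, ``head_dim``) are
illustrative assumptions, and the fused kernels may require a CUDA device and reduced-precision
dtypes; consult the tutorial for full requirements.

>>> import torch
>>> import torch.nn.functional as F
>>> n_heads, head_dim = 4, 16
>>> seqs = [torch.randn(5, n_heads * head_dim), torch.randn(9, n_heads * head_dim)]
>>> nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
>>> # split the embedding into heads, then move heads ahead of the ragged dim
>>> q = k = v = nt.unflatten(-1, [n_heads, head_dim]).transpose(1, 2)
>>> out = F.scaled_dot_product_attention(q, k, v)  # NJT of shape (2, 4, j1, 16)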

.. _usage_with_torch_compile:

Usage with torch.compile
++++++++++++++++++++++++

NJTs are designed to be used with ``torch.compile()`` for optimal performance, and we always
recommend utilizing ``torch.compile()`` with NJTs when possible. NJTs work out-of-the-box and
graph-break-free both when passed as inputs to a compiled function or module and when
instantiated in-line within the function.

.. note::
    If you're not able to utilize ``torch.compile()`` for your use case, performance and memory
    usage may still benefit from the use of NJTs, but it's not as clear-cut whether this will be
    the case. It is important that the tensors being operated on are large enough so the
    performance gains are not outweighed by the overhead of Python tensor subclasses.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
>>> output.shape
torch.Size([2, j1, 3])
>>> def g(values, offsets): return torch.nested.nested_tensor_from_jagged(values, offsets) * 2.
...
>>> compiled_g = torch.compile(g, fullgraph=True)
>>> output2 = compiled_g(nt.values(), nt.offsets())
>>> output2.shape
torch.Size([2, j1, 3])

Note that NJTs support
`Dynamic Shapes <https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html>`_
to avoid unnecessary recompiles with changing ragged structure.

>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> c = torch.randn(5, 3)
>>> d = torch.randn(6, 3)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt2 = torch.nested.nested_tensor([c, d], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output1 = compiled_f(nt1)
>>> output2 = compiled_f(nt2)  # NB: No recompile needed even though ragged structure differs

If you run into problems or arcane errors when utilizing NJT + ``torch.compile``, please file a
PyTorch issue. Full subclass support within ``torch.compile`` is a long-term effort and there may
be some rough edges at this time.

.. _troubleshooting:

Troubleshooting
+++++++++++++++

This section contains common errors that you may run into when utilizing nested tensors, alongside
the reason for these errors and suggestions for how to address them.

.. _unimplemented_op:

Unimplemented ops
^^^^^^^^^^^^^^^^^

This error is becoming rarer as nested tensor op support grows, but it's still possible to hit it
today given that there are a couple thousand ops within PyTorch.

::

    NotImplementedError: aten.view_as_real.default

The error is straightforward; we haven't gotten around to adding op support for this particular op
yet. If you'd like, you can `contribute <contributions>`__ an implementation yourself OR simply
`request <https://github.com/pytorch/pytorch/issues/118107>`_ that we add support for this op
in a future PyTorch release.
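
Until support lands, one possible workaround is to round-trip through a padded dense tensor using
the conversions described above. In the sketch below, ``some_unsupported_op`` is a hypothetical
stand-in for whichever op lacks NJT support and is assumed to preserve the padded tensor's shape.

>>> nt = torch.nested.nested_tensor(
...     [torch.randn(2, 8), torch.randn(5, 8)], layout=torch.jagged)
>>> seq_lens = nt.offsets().diff()
>>> padded = torch.nested.to_padded_tensor(nt, padding=0.0)
>>> result = some_unsupported_op(padded)  # hypothetical op with dense tensor support
>>> out = torch.nested.narrow(result, dim=1, start=0, length=seq_lens, layout=torch.jagged)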

.. _ragged_structure_incompatibility:

Ragged structure incompatibility
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)

This error occurs when calling an op that operates over multiple NJTs with incompatible ragged
structures. Currently, it is required that input NJTs have the exact same ``offsets`` constituent
in order to have the same ragged structure symbol (e.g. ``j1``).

As a workaround for this situation, it is possible to construct NJTs from the ``values`` and
``offsets`` components directly. With both NJTs referencing the same ``offsets`` components, they
are considered to have the same ragged structure and are thus compatible.

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor_from_jagged(values=torch.randn(82, 128), offsets=nt1.offsets())
>>> nt3 = nt1 + nt2
>>> nt3.shape
torch.Size([2, j1, 128])

Data dependent operation within torch.compile
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

::

    torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True

This error occurs when calling an op that performs a data-dependent operation within torch.compile;
this commonly occurs for ops that need to examine the values of the NJT's ``offsets`` to determine the
output shape. For example:

>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> def f(nt): return nt.chunk(2, dim=0)[0]
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)

In this example, calling ``chunk()`` on the batch dimension of the NJT requires examination of the
NJT's ``offsets`` data to delineate batch item boundaries within the packed ragged dimension. As a
workaround, there are a couple of torch.compile flags that can be set:

>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True
>>> torch._dynamo.config.capture_scalar_outputs = True

If, after setting these, you still see data-dependent operator errors, please file an issue with
PyTorch. This area of ``torch.compile()`` is still in heavy development and certain aspects of
NJT support may be incomplete.

.. _contributions:

Contributions
+++++++++++++

If you'd like to contribute to nested tensor development, one of the most impactful ways to do
so is to add nested tensor support for a currently-unsupported PyTorch op. This process generally
consists of a couple of simple steps:

#. Determine the name of the op to add; this should be something like ``aten.view_as_real.default``.
   The signature for this op can be found in ``aten/src/ATen/native/native_functions.yaml``.
#. Register an op implementation in ``torch/nested/_internal/ops.py``, following the pattern
   established there for other ops. Use the signature from ``native_functions.yaml`` for schema
   validation.

The most common way to implement an op is to unwrap the NJT into its constituents, redispatch the
op on the underlying ``values`` buffer, and propagate the relevant NJT metadata (including
``offsets``) to a new output NJT. If the output of the op is expected to have a different shape
from the input, new ``offsets``, etc. metadata must be computed.
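
As a concrete illustration of this pattern (a simplified user-level sketch, not the actual
registration code in ``torch/nested/_internal/ops.py``), an elementwise op can reuse the input's
``offsets`` unchanged because it preserves the shape of every component:

>>> def njt_sin(nt):
...     # elementwise: the output has the same ragged structure, so reuse nt.offsets()
...     return torch.nested.nested_tensor_from_jagged(values=nt.values().sin(), offsets=nt.offsets())
...
>>> nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged)
>>> torch.equal(njt_sin(nt).values(), nt.values().sin())
True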

When an op is applied over the batch or ragged dimension, these tricks can help quickly get a
working implementation (see the sketch after this list):

* For *non-batchwise* operations, an ``unbind()``-based fallback should work.
* For operations on the ragged dimension, consider converting to padded dense with a properly-selected
  padding value that won't negatively bias the output, running the op, and converting back to NJT.
  Within ``torch.compile``, these conversions can be fused to avoid materializing the padded
  intermediate.
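
The following is a minimal sketch of the ``unbind()``-based fallback idea at the user level; the
helper name ``apply_per_component`` is made up for illustration and is not part of the actual
registration machinery:

>>> def apply_per_component(op, nt):
...     # run the op on each dense component, then re-nest the results (this copies data)
...     return torch.nested.nested_tensor([op(c) for c in nt.unbind()], layout=torch.jagged)
...
>>> nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged)
>>> out = apply_per_component(torch.nn.functional.relu, nt)
>>> [c.shape for c in out.unbind()]
[torch.Size([2, 3]), torch.Size([4, 3])]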

.. _construction_and_conversion:

Detailed Docs for Construction and Conversion Functions
+++++++++++++++++++++++++++++++++++++++++++++++++++++++

.. currentmodule:: torch.nested

.. autofunction:: nested_tensor
.. autofunction:: nested_tensor_from_jagged
.. autofunction:: as_nested_tensor
.. autofunction:: to_padded_tensor

Supported operations
++++++++++++++++++++

In this section, we summarize the operations that are currently supported on
NestedTensor and any constraints they have.

.. csv-table::
   :header: "PyTorch operation", "Constraints"
   :widths: 30, 55
   :delim: ;

   :func:`torch.matmul`; "Supports matrix multiplication between two (>= 3d) nested tensors where
   the last two dimensions are matrix dimensions and the leading (batch) dimensions have the same size
   (i.e. no broadcasting support for batch dimensions yet)."
   :func:`torch.bmm`; "Supports batch matrix multiplication of two 3-d nested tensors."
   :func:`torch.nn.Linear`; "Supports 3-d nested input and a dense 2-d weight matrix."
   :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except dim=0."
   :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
   :func:`torch.Tensor.masked_fill`; "Behavior is the same as on regular tensors."
   :func:`torch.relu`; "Behavior is the same as on regular tensors."
   :func:`torch.gelu`; "Behavior is the same as on regular tensors."
   :func:`torch.silu`; "Behavior is the same as on regular tensors."
   :func:`torch.abs`; "Behavior is the same as on regular tensors."
   :func:`torch.sgn`; "Behavior is the same as on regular tensors."
   :func:`torch.logical_not`; "Behavior is the same as on regular tensors."
   :func:`torch.neg`; "Behavior is the same as on regular tensors."
   :func:`torch.sub`; "Supports elementwise subtraction of two nested tensors."
   :func:`torch.add`; "Supports elementwise addition of two nested tensors. Supports addition of a scalar to a nested tensor."
   :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors. Supports multiplication of a nested tensor by a scalar."
   :func:`torch.select`; "Supports selecting along all dimensions."
   :func:`torch.clone`; "Behavior is the same as on regular tensors."
   :func:`torch.detach`; "Behavior is the same as on regular tensors."
   :func:`torch.unbind`; "Supports unbinding along ``dim=0`` only."
   :func:`torch.reshape`; "Supports reshaping with size of ``dim=0`` preserved (i.e. number of tensors nested cannot be changed).
   Unlike regular tensors, a size of ``-1`` here means that the existing size is inherited.
   In particular, the only valid size for an irregular dimension is ``-1``.
   Size inference is not implemented yet and hence for new dimensions the size cannot be ``-1``."
   :func:`torch.Tensor.reshape_as`; "Similar constraint as for ``reshape``."
   :func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``."
   :func:`torch.Tensor.view`; "Rules for the new shape are similar to that of ``reshape``."
   :func:`torch.empty_like`; "Behavior is analogous to that of regular tensors; returns a new empty nested tensor (i.e. with uninitialized values) matching the nested structure of the input."
   :func:`torch.randn_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with values randomly initialized according to a standard normal distribution matching the nested structure of the input."
   :func:`torch.zeros_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with all zero values matching the nested structure of the input."
   :func:`torch.nn.LayerNorm`; "The ``normalized_shape`` argument is restricted to not extend into the irregular dimensions of the NestedTensor."

.. autofunction:: masked_select
.. autofunction:: narrow

@@ -477,25 +477,23 @@ def masked_select(tensor: Tensor, mask: Tensor) -> Tensor:

    Example::

        >>> tensor = torch.randn(3, 3)
        >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]])
        >>> nt = torch.nested.masked_select(tensor, mask)
        >>> nt.shape
        torch.Size([3, j4])
        >>> # Length of each item in the batch:
        >>> nt.offsets().diff()
        tensor([1, 2, 1])

        >>> tensor = torch.randn(6, 5)
        >>> mask = torch.tensor([False])
        >>> nt = torch.nested.masked_select(tensor, mask)
        >>> nt.shape
        torch.Size([6, j5])
        >>> # Length of each item in the batch:
        >>> nt.offsets().diff()
        tensor([0, 0, 0, 0, 0, 0])
    """
    if tensor.layout != torch.strided:
        raise RuntimeError(