Update OSS nested tensor docs to focus on NJT (#145402)

Updated nested tensor docs to be NJT-centric (instead of NST-centric). They now include:
* High-level description of NST vs. NJT + a recommendation to use NJT
* General NJT construction / usage
* torch.compile() integration w/ dynamic shapes
* Common errors and how to fix them
* Contribution guide
* Data layout / shape information (with diagram)
* Links to more extensive tutorials involving Transformers / SDPA / FlexAttention

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145402
Approved by: https://github.com/soulitzer
Joel Schlosser
2025-01-24 15:24:32 -05:00
committed by PyTorch MergeBot
parent 392dc177a9
commit b2a0feac85
3 changed files with 438 additions and 173 deletions

(New binary image file; 54 KiB; not shown.)

@@ -10,20 +10,23 @@ Introduction
The PyTorch API of nested tensors is in prototype stage and will change in the near future.
Nested tensors allow for ragged-shaped data to be contained within and operated upon as a
single tensor. Such data is stored underneath in an efficient packed representation, while exposing
a standard PyTorch tensor interface for applying operations.
A common application of nested tensors is for expressing batches of variable-length sequential data
present in various domains, such as varying sentence lengths, image sizes, and audio / video clip
lengths. Traditionally, such data has been handled by padding sequences to the maximum length
within a batch, performing computation on the padded form, and subsequently masking to remove
the padding. This is inefficient and error-prone, and nested tensors exist to address these problems.
The API for calling operations on a nested tensor is no different from that of a regular
``torch.Tensor``, allowing for seamless integration with existing models, with the main
difference being :ref:`construction of the inputs <construction>`.
As this is a prototype feature, the set of :ref:`operations supported <supported operations>` is
limited, but growing. We welcome issues, feature requests, and contributions.
More information on contributing can be found
`in this Readme <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/nested/README.md>`_.
.. _construction:
@@ -31,192 +34,456 @@ limited. However, we welcome issues, feature requests and contributions. More in
Construction
++++++++++++
.. note::
There are two forms of nested tensors present within PyTorch, distinguished by layout as
specified during construction. Layout can be one of ``torch.strided`` or ``torch.jagged``.
We recommend utilizing the ``torch.jagged`` layout whenever possible. While it currently only
supports a single ragged dimension, it has better op coverage, receives active development, and
integrates well with ``torch.compile``. These docs adhere to this recommendation and refer to
nested tensors with the ``torch.jagged`` layout as "NJTs" for brevity throughout.
Construction is straightforward and involves passing a list of tensors to the
``torch.nested.nested_tensor`` constructor. A nested tensor with the ``torch.jagged`` layout
(AKA an "NJT") supports a single ragged dimension. This constructor will copy the input tensors
into a packed, contiguous block of memory according to the layout described in the `data_layout`_
section below.
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> print([component for component in nt])
[tensor([0, 1, 2]), tensor([3, 4, 5, 6, 7])]
Each tensor in the list must have the same number of dimensions, but the shapes can otherwise vary
along a single dimension. If the dimensionalities of the input components don't match, the
constructor throws an error.
>>> a = torch.randn(50, 128) # 2D tensor
>>> b = torch.randn(2, 50, 128) # 3D tensor
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
...
RuntimeError: When constructing a nested tensor, all tensors in list must have the same dim
During construction, dtype, device, and whether gradients are required can be chosen via the
usual keyword arguments.
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32, device="cuda", requires_grad=True)
>>> print([component for component in nt])
[tensor([0., 1., 2.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>), tensor([3., 4., 5., 6., 7.], device='cuda:0',
grad_fn=<UnbindBackwardAutogradNestedTensor0>)]
``torch.nested.as_nested_tensor`` can be used to preserve autograd history from the tensors passed
to the constructor. When this constructor is utilized, gradients will flow through the nested tensor
back into the original components. Note that this constructor still copies the input components into
a packed, contiguous block of memory.
>>> a = torch.randn(12, 512, requires_grad=True)
>>> b = torch.randn(23, 512, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.sum().backward()
>>> a.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
>>> b.grad
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
The above functions all create contiguous NJTs, where a chunk of memory is allocated to store
a packed form of the underlying components (see the `data_layout`_ section below for more
details).
It is also possible to create a non-contiguous NJT view over a pre-existing dense tensor
with padding, avoiding the memory allocation and copying. ``torch.nested.narrow()`` is the tool
for accomplishing this.
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt.is_contiguous()
False
Note that the nested tensor acts as a view over the original padded dense tensor, referencing the
same memory without copying / allocation. Operation support for non-contiguous NJTs is somewhat more
limited, so if you run into support gaps, it's always possible to convert to a contiguous NJT
using ``contiguous()``.
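For illustration, a minimal sketch of that conversion, reusing ``nt`` from the ``narrow()`` example above:
>>> nt = nt.contiguous()  # copies the viewed data into a packed buffer
>>> nt.is_contiguous()
True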
.. _data_layout:
Data Layout and Shape
+++++++++++++++++++++
For efficiency, nested tensors generally pack their tensor components into a contiguous chunk of
memory and maintain additional metadata to specify batch item boundaries. For the ``torch.jagged``
layout, the contiguous chunk of memory is stored in the ``values`` component, with the ``offsets``
component delineating batch item boundaries for the ragged dimension.
.. image:: _static/img/nested/njt_visual.png
It's possible to directly access the underlying NJT components when necessary.
>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.values().shape # note the "packing" of the ragged dimension; no padding needed
torch.Size([82, 128])
>>> nt.offsets()
tensor([ 0, 50, 82])
It can also be useful to construct an NJT from the jagged ``values`` and ``offsets``
constituents directly; the ``torch.nested.nested_tensor_from_jagged()`` constructor serves
this purpose.
>>> values = torch.randn(82, 128)
>>> offsets = torch.tensor([0, 50, 82], dtype=torch.int64)
>>> nt = torch.nested.nested_tensor_from_jagged(values=values, offsets=offsets)
An NJT has a well-defined shape with dimensionality 1 greater than that of its components. The
underlying structure of the ragged dimension is represented by a symbolic value (``j1`` in the
example below).
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt.dim()
3
>>> nt.shape
torch.Size([2, j1, 128])
NJTs must have the same ragged structure to be compatible with each other. For example, to run a
binary operation involving two NJTs, the ragged structures must match (i.e. they must have the
same ragged shape symbol in their shapes). Under the hood, each symbol corresponds to a specific
``offsets`` tensor, so both NJTs must reference the same ``offsets`` tensor to be compatible with
each other.
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt1.offsets() is nt2.offsets()
False
>>> nt3 = nt1 + nt2
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
In the above example, even though the conceptual shapes of the two NJTs are the same, they don't
share a reference to the same ``offsets`` tensor, so their shapes differ, and they are not
compatible. We recognize that this behavior is unintuitive and are working hard to relax this
restriction for the beta release of nested tensors. For a workaround, see the
:ref:`Troubleshooting <ragged_structure_incompatibility>` section of this document.
In addition to the ``offsets`` metadata, an NJT can also compute and cache the minimum and maximum
sequence lengths for its components, which can be useful for invoking particular kernels (e.g. SDPA).
There are currently no public APIs for accessing these, but this will change for the beta release.
.. _supported operations:
Supported Operations
++++++++++++++++++++
This section contains a list of common operations over nested tensors that you may find useful.
It is not comprehensive, as there are on the order of a couple thousand ops within PyTorch. While
a sizeable subset of these are supported for nested tensors today, full support is a large task.
The ideal state for nested tensors is full support of all PyTorch operations that are available
for non-nested tensors. To help us accomplish this, please consider:
* Requesting particular ops needed for your use case
`here <https://github.com/pytorch/pytorch/issues/118107>`__ to help us prioritize.
* Contributing! It's not too hard to add nested tensor support for a given PyTorch op; see
the `Contributions <contributions>`__ section below for details.
Viewing nested tensor constituents
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``unbind()`` allows you to retrieve a view of the nested tensor's constituents.
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.unbind()
(tensor([[-0.9916, -0.3363, -0.2799],
[-2.3520, -0.5896, -0.4374]]), tensor([[-2.0969, -1.0104, 1.4841],
[ 2.0952, 0.2973, 0.2516],
[ 0.9035, 1.3623, 0.2026]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[-2.9747, -1.0089, -0.8396],
[-7.0561, -1.7688, -1.3122]])
>>> nt.unbind()
(tensor([[-2.9747, -1.0089, -0.8396],
[-7.0561, -1.7688, -1.3122]]), tensor([[-2.0969, -1.0104, 1.4841],
[ 2.0952, 0.2973, 0.2516],
[ 0.9035, 1.3623, 0.2026]]))
Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which
represents the first entry or constituent of the nested tensor.
Conversions to / from padded
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``torch.nested.to_padded_tensor()`` converts an NJT to a padded dense tensor with the specified
padding value. The ragged dimension will be padded out to the size of the maximum sequence length.
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(6, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> padded = torch.nested.to_padded_tensor(nt, padding=4.2)
>>> padded
tensor([[[ 1.6107, 0.5723, 0.3913],
[ 0.0700, -0.4954, 1.8663],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000],
[ 4.2000, 4.2000, 4.2000]],
[[-0.0479, -0.7610, -0.3484],
[ 1.1345, 1.0556, 0.3634],
[-1.7122, -0.5921, 0.0540],
[-0.5506, 0.7608, 2.0606],
[ 1.5658, -1.1934, 0.3041],
[ 0.1483, -1.1284, 0.6957]]])
This can be useful as an escape hatch to work around NJT support gaps, but ideally such
conversions should be avoided when possible for optimal memory usage and performance, as the
more efficient nested tensor layout does not materialize padding.
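Continuing the example above, the storage difference can be observed directly (a rough illustration; the packed NJT stores only the real elements, while the padded form also materializes the padding):
>>> nt.values().numel()   # packed storage: (2 + 6) * 3 elements
24
>>> padded.numel()        # padded storage: 2 * 6 * 3 elements, padding included
36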
The reverse conversion can be accomplished using ``torch.nested.narrow()``, which applies
ragged structure to a given dense tensor to produce an NJT. Note that by default, this operation
does not copy the underlying data, and thus the output NJT is generally non-contiguous. It may be
useful to explicitly call ``contiguous()`` here if a contiguous NJT is desired.
>>> padded = torch.randn(3, 5, 4)
>>> seq_lens = torch.tensor([3, 2, 5], dtype=torch.int64)
>>> nt = torch.nested.narrow(padded, dim=1, start=0, length=seq_lens, layout=torch.jagged)
>>> nt.shape
torch.Size([3, j1, 4])
>>> nt = nt.contiguous()
>>> nt.shape
torch.Size([3, j2, 4])
Shape manipulations
^^^^^^^^^^^^^^^^^^^
Nested tensors support a wide array of operations for shape manipulation, including views.
>>> a = torch.randn(2, 6)
>>> b = torch.randn(4, 6)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt.shape
torch.Size([2, j1, 6])
>>> nt.unsqueeze(-1).shape
torch.Size([2, j1, 6, 1])
>>> nt.unflatten(-1, [2, 3]).shape
torch.Size([2, j1, 2, 3])
>>> torch.cat([nt, nt], dim=2).shape
torch.Size([2, j1, 12])
>>> torch.stack([nt, nt], dim=2).shape
torch.Size([2, j1, 2, 6])
>>> nt.transpose(-1, -2).shape
torch.Size([2, 6, j1])
Attention mechanisms
^^^^^^^^^^^^^^^^^^^^
As variable-length sequences are common inputs to attention mechanisms, nested tensors support the
important attention operators
`Scaled Dot Product Attention (SDPA) <https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_ and
`FlexAttention <https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention>`_.
See
`here <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#multiheadattention>`__
for usage examples of NJT with SDPA and
`here <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html#flexattention-njt>`__
for usage examples of NJT with FlexAttention.
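Below is a rough self-attention sketch of the SDPA pattern; the sizes, projection modules, and the ``heads()`` helper are illustrative rather than part of any API, and depending on your build this path may require a CUDA device with a supported SDPA backend. A single NJT is projected into query / key / value so that all three share the same ragged structure, the heads are split out, and the head dimension is moved ahead of the ragged dimension before calling SDPA.
>>> import torch.nn.functional as F
>>> E, H, D = 128, 8, 16  # embedding dim E = H * D (illustrative sizes)
>>> x = torch.nested.nested_tensor(
...     [torch.randn(5, E), torch.randn(9, E)], layout=torch.jagged)  # shape (2, j1, 128)
>>> proj_q, proj_k, proj_v = (torch.nn.Linear(E, E) for _ in range(3))
>>> def heads(t):
...     # (B, j1, E) -> (B, H, j1, D): split heads and move them ahead of the ragged dim
...     return t.unflatten(-1, [H, D]).transpose(1, 2)
...
>>> q, k, v = heads(proj_q(x)), heads(proj_k(x)), heads(proj_v(x))  # all share x's ragged structure
>>> out = F.scaled_dot_product_attention(q, k, v)  # shape (2, 8, j1, 16); no padding is materialized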
.. _usage_with_torch_compile:
Usage with torch.compile
++++++++++++++++++++++++
NJTs are designed to be used with ``torch.compile()`` for optimal performance, and we always
recommend utilizing ``torch.compile()`` with NJTs when possible. NJTs work out-of-the-box and
graph-break-free both when passed as inputs to a compiled function or module OR when
instantiated in-line within the function.
.. note::
If you're not able to utilize ``torch.compile()`` for your use case, performance and memory
usage may still benefit from the use of NJTs, but it's not as clear-cut whether this will be
the case. It is important that the tensors being operated on are large enough so the
performance gains are not outweighed by the overhead of python tensor subclasses.
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
>>> output.shape
torch.Size([2, j1, 3])
>>> def g(values, offsets): return torch.nested.nested_tensor_from_jagged(values, offsets) * 2.
...
>>> compiled_g = torch.compile(g, fullgraph=True)
>>> output2 = compiled_g(nt.values(), nt.offsets())
>>> output2.shape
torch.Size([2, j1, 3])
Note that NJTs support
`Dynamic Shapes <https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html>`_
to avoid unnecessary recompiles with changing ragged structure.
>>> a = torch.randn(2, 3)
>>> b = torch.randn(4, 3)
>>> c = torch.randn(5, 3)
>>> d = torch.randn(6, 3)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged)
>>> nt2 = torch.nested.nested_tensor([c, d], layout=torch.jagged)
>>> def f(x): return x.sin() + 1
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output1 = compiled_f(nt1)
>>> output2 = compiled_f(nt2) # NB: No recompile needed even though ragged structure differs
If you run into problems or arcane errors when utilizing NJT + ``torch.compile``, please file a
PyTorch issue. Full subclass support within ``torch.compile`` is a long-term effort and there may
be some rough edges at this time.
.. _troubleshooting:
Troubleshooting
+++++++++++++++
This section contains common errors that you may run into when utilizing nested tensors, alongside
the reason for these errors and suggestions for how to address them.
.. _unimplemented_op:
Unimplemented ops
^^^^^^^^^^^^^^^^^
This error is becoming rarer as nested tensor op support grows, but it's still possible to hit it
today given that there are a couple thousand ops within PyTorch.
::
NotImplementedError: aten.view_as_real.default
The error is straightforward; we haven't gotten around to adding support for this particular op
yet. If you'd like, you can `contribute <contributions>`__ an implementation yourself OR simply
`request <https://github.com/pytorch/pytorch/issues/118107>`_ that we add support for this op
in a future PyTorch release.
.. _ragged_structure_incompatibility:
Ragged structure incompatibility
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::
RuntimeError: cannot call binary pointwise function add.Tensor with inputs of shapes (2, j2, 128) and (2, j3, 128)
This error occurs when calling an op that operates over multiple NJTs with incompatible ragged
structures. Currently, it is required that input NJTs have the exact same ``offsets`` constituent
in order to share the same symbolic ragged structure (e.g. ``j1``).
As a workaround for this situation, it is possible to construct NJTs from the ``values`` and
``offsets`` components directly. With both NJTs referencing the same ``offsets`` components, they
are considered to have the same ragged structure and are thus compatible.
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt1 = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> nt2 = torch.nested.nested_tensor_from_jagged(values=torch.randn(82, 128), offsets=nt1.offsets())
>>> nt3 = nt1 + nt2
>>> nt3.shape
torch.Size([2, j1, 128])
Data dependent operation within torch.compile
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
::
torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default; to enable, set torch._dynamo.config.capture_scalar_outputs = True
This error occurs when calling an op that performs a data-dependent operation within ``torch.compile``; this
commonly occurs for ops that need to examine the values of the NJT's ``offsets`` to determine the
output shape. For example:
>>> a = torch.randn(50, 128)
>>> b = torch.randn(32, 128)
>>> nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, dtype=torch.float32)
>>> def f(nt): return nt.chunk(2, dim=0)[0]
...
>>> compiled_f = torch.compile(f, fullgraph=True)
>>> output = compiled_f(nt)
In this example, calling ``chunk()`` on the batch dimension of the NJT requires examination of the
NJT's ``offsets`` data to delineate batch item boundaries within the packed ragged dimension. As a
workaround, there are a couple torch.compile flags that can be set:
>>> torch._dynamo.config.capture_dynamic_output_shape_ops = True
>>> torch._dynamo.config.capture_scalar_outputs = True
If, after setting these, you still see data-dependent operator errors, please file an issue with
PyTorch. This area of ``torch.compile()`` is still in heavy development and certain aspects of
NJT support may be incomplete.
.. _contributions:
Contributions
+++++++++++++
If you'd like to contribute to nested tensor development, one of the most impactful ways to do
so is to add nested tensor support for a currently-unsupported PyTorch op. This process generally
consists of a couple simple steps:
#. Determine the name of the op to add; this should be something like ``aten.view_as_real.default``.
The signature for this op can be found in ``aten/src/ATen/native/native_functions.yaml``.
#. Register an op implementation in ``torch/nested/_internal/ops.py``, following the pattern
established there for other ops. Use the signature from ``native_functions.yaml`` for schema
validation.
The most common way to implement an op is to unwrap the NJT into its constituents, redispatch the
op on the underlying ``values`` buffer, and propagate the relevant NJT metadata (including
``offsets``) to a new output NJT. If the output of the op is expected to have a different shape
from the input, new ``offsets``, etc. metadata must be computed.
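As a hedged sketch of this pattern using only public APIs (``njt_sin`` below is a hypothetical stand-in for what a registered ``aten.sin.default`` implementation does; the actual registration mechanics live in ``torch/nested/_internal/ops.py``):
>>> def njt_sin(nt):
...     # pointwise op: apply it to the packed values and reuse the input's ragged metadata
...     return torch.nested.nested_tensor_from_jagged(
...         values=nt.values().sin(), offsets=nt.offsets())
...
>>> nt = torch.nested.nested_tensor(
...     [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged)
>>> torch.equal(njt_sin(nt).values(), nt.sin().values())
True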
When an op is applied over the batch or ragged dimension, these tricks can help quickly get a
working implementation:
* For *non-batchwise* operations, an ``unbind()``-based fallback should work.
* For operations on the ragged dimension, consider converting to padded dense with a properly-selected
padding value that won't negatively bias the output, running the op, and converting back to NJT.
Within ``torch.compile``, these conversions can be fused to avoid materializing the padded
intermediate.
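For example, a hedged sketch of the padded round-trip trick for a softmax over the ragged dimension (illustrative only; ``-inf`` is chosen as the padding value so that padded positions receive zero weight, and the result is a non-contiguous NJT view that can be made contiguous with ``contiguous()`` if needed):
>>> nt = torch.nested.nested_tensor(
...     [torch.randn(2, 4), torch.randn(5, 4)], layout=torch.jagged)
>>> padded = torch.nested.to_padded_tensor(nt, padding=float("-inf"))
>>> softmaxed = torch.nn.functional.softmax(padded, dim=1)  # softmax over the (padded) ragged dim
>>> result = torch.nested.narrow(
...     softmaxed, dim=1, start=0, length=nt.offsets().diff(), layout=torch.jagged)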
.. _construction_and_conversion:
Detailed Docs for Construction and Conversion Functions
+++++++++++++++++++++++++++++++++++++++++++++++++++++++
.. currentmodule:: torch.nested
.. autofunction:: nested_tensor
.. autofunction:: nested_tensor_from_jagged
.. autofunction:: as_nested_tensor
.. autofunction:: to_padded_tensor
.. autofunction:: masked_select
.. autofunction:: narrow


@@ -477,25 +477,23 @@ def masked_select(tensor: Tensor, mask: Tensor) -> Tensor:
Example::
>>> tensor = torch.randn(3, 3)
>>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]])
>>> nt = torch.nested.masked_select(tensor, mask)
>>> nt.shape
torch.Size([3, j4])
>>> # Length of each item in the batch:
>>> nt.offsets().diff()
tensor([1, 2, 1])
>>> tensor = torch.randn(6, 5)
>>> mask = torch.tensor([False])
>>> nt = torch.nested.masked_select(tensor, mask)
>>> nt.shape
torch.Size([6, j5])
>>> # Length of each item in the batch:
>>> nt.offsets().diff()
tensor([0, 0, 0, 0, 0, 0])
"""
if tensor.layout != torch.strided:
raise RuntimeError(