pytorch/docs/cpp/source/notes/tensor_indexing.rst
Will Feng (FAIAR) 86f3305859 Improve C++ API autograd and indexing docs (#35777)
Summary:
This PR adds docs for the following components:
1. Tensor autograd APIs (such as `is_leaf` / `backward` / `detach` / `detach_` / `retain_grad` / `grad` / `register_hook` / `remove_hook`)
2. Autograd APIs: `torch::autograd::backward` / `grad` / `Function` / `AutogradContext`, `torch::NoGradGuard` / `torch::AutoGradMode`
3. Tensor indexing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35777

Differential Revision: D20810616

Pulled By: yf225

fbshipit-source-id: 60526ec0c5b051021901d89bc3b56861c68758e8
2020-04-02 09:33:11 -07:00


Tensor Indexing API
===================

Indexing a tensor in the PyTorch C++ API works very similarly to the Python API.
All index types such as ``None`` / ``...`` / integer / boolean / slice / tensor
are available in the C++ API, which makes translating Python indexing code to C++
straightforward. The main difference is that, instead of the ``[]``-operator syntax
used in Python, the C++ API provides two indexing methods:

- ``torch::Tensor::index`` (`link <https://pytorch.org/cppdocs/api/classat_1_1_tensor.html#_CPPv4NK2at6Tensor5indexE8ArrayRefIN2at8indexing11TensorIndexEE>`__), for reading values (getter)
- ``torch::Tensor::index_put_`` (`link <https://pytorch.org/cppdocs/api/classat_1_1_tensor.html#_CPPv4N2at6Tensor10index_put_E8ArrayRefIN2at8indexing11TensorIndexEERK6Tensor>`__), for writing values in place (setter)

Note that index types such as ``None`` / ``Ellipsis`` / ``Slice`` live in the
``torch::indexing`` namespace. It's recommended to put ``using namespace torch::indexing``
before any indexing code so that those index types can be used without the namespace prefix.

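
As a minimal sketch, assuming a LibTorch build that provides ``torch/torch.h``, here is the getter and the setter side by side. The helpers ``second_row`` and ``set_first_column`` are hypothetical names used only for illustration.

.. code-block:: cpp

  #include <torch/torch.h>

  // Lets None, Ellipsis, Slice, ... be used without the torch::indexing:: prefix.
  using namespace torch::indexing;

  // Hypothetical helper (getter): C++ counterpart of the Python expression `t[1]`.
  torch::Tensor second_row(const torch::Tensor& t) {
    return t.index({1});
  }

  // Hypothetical helper (setter): C++ counterpart of the Python statement
  // `t[:, 0] = -1`. index_put_ modifies `t` in place.
  void set_first_column(torch::Tensor& t) {
    t.index_put_({Slice(), 0}, -1);
  }

For ``t = torch::arange(12).reshape({3, 4})``, ``second_row(t)`` holds ``[4, 5, 6, 7]``. After ``set_first_column(t)``, every element in column 0 of ``t`` is ``-1``.
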

Here are some examples of translating Python indexing code to C++:

Getter
------

+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| Python                                                   | C++ (assuming ``using namespace torch::indexing``)                                   |
+==========================================================+======================================================================================+
| ``tensor[None]``                                         | ``tensor.index({None})``                                                             |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[Ellipsis, ...]``                                | ``tensor.index({Ellipsis, "..."})``                                                  |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1, 2]``                                         | ``tensor.index({1, 2})``                                                             |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[True, False]``                                  | ``tensor.index({true, false})``                                                      |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1::2]``                                         | ``tensor.index({Slice(1, None, 2)})``                                                |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[torch.tensor([1, 2])]``                         | ``tensor.index({torch::tensor({1, 2})})``                                            |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])]``     | ``tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})``         |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+

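
To make several of the getter rows concrete, here is a sketch that applies them to a small 2-D tensor. The function name ``getter_examples`` is hypothetical. The results in the comments assume the input ``torch::arange(12).reshape({4, 3})``, i.e. ``[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]``.

.. code-block:: cpp

  #include <torch/torch.h>
  #include <vector>

  using namespace torch::indexing;

  // Hypothetical helper: applies several getter rows from the table above.
  // Results noted below are for t = torch::arange(12).reshape({4, 3}).
  std::vector<torch::Tensor> getter_examples(const torch::Tensor& t) {
    return {
        t.index({None}),                   // t[None]   -> shape [1, 4, 3]
        t.index({1, 2}),                   // t[1, 2]   -> 0-dim tensor holding 5
        t.index({Slice(1, None, 2)}),      // t[1::2]   -> [[3, 4, 5], [9, 10, 11]]
        t.index({torch::tensor({1, 2})}),  // t[torch.tensor([1, 2])] -> [[3, 4, 5], [6, 7, 8]]
        t.index({"...", 0}),               // t[..., 0] -> [0, 3, 6, 9]
    };
  }
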
Setter
------

+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| Python                                                   | C++ (assuming ``using namespace torch::indexing``)                                   |
+==========================================================+======================================================================================+
| ``tensor[None] = 1``                                     | ``tensor.index_put_({None}, 1)``                                                     |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[Ellipsis, ...] = 1``                            | ``tensor.index_put_({Ellipsis, "..."}, 1)``                                          |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1, 2] = 1``                                     | ``tensor.index_put_({1, 2}, 1)``                                                     |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[True, False] = 1``                              | ``tensor.index_put_({true, false}, 1)``                                              |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1::2] = 1``                                     | ``tensor.index_put_({Slice(1, None, 2)}, 1)``                                        |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[torch.tensor([1, 2])] = 1``                     | ``tensor.index_put_({torch::tensor({1, 2})}, 1)``                                    |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])] = 1`` | ``tensor.index_put_({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}, 1)`` |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+

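
The setter rows compose in the same way. The sketch below starts from a 4x3 ``int64`` tensor of zeros and applies a few ``index_put_`` calls in order. The last call shows that the right-hand side may also be a tensor rather than a scalar. The function name ``setter_examples`` is hypothetical.

.. code-block:: cpp

  #include <torch/torch.h>

  using namespace torch::indexing;

  // Hypothetical helper: applies several setter rows from the table above.
  // Each index_put_ call modifies `t` in place, in the order written.
  torch::Tensor setter_examples() {
    torch::Tensor t = torch::zeros({4, 3}, torch::kLong);
    t.index_put_({Slice(1, None, 2)}, 1);          // t[1::2] = 1     (rows 1 and 3)
    t.index_put_({"...", 0}, -1);                  // t[..., 0] = -1  (column 0)
    t.index_put_({0, 1}, 5);                       // t[0, 1] = 5     (a single element)
    t.index_put_({2}, torch::tensor({7, 8, 9}));   // t[2] = torch.tensor([7, 8, 9])
    // t is now [[-1, 5, 0], [-1, 1, 1], [7, 8, 9], [-1, 1, 1]]
    return t;
  }
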
Translating between Python/C++ index types
------------------------------------------

The one-to-one translation between Python and C++ index types is as follows:

+--------------------------+------------------------------------------------------------------------+
| Python                   | C++ (assuming ``using namespace torch::indexing``)                     |
+==========================+========================================================================+
| ``None``                 | ``None``                                                               |
+--------------------------+------------------------------------------------------------------------+
| ``Ellipsis``             | ``Ellipsis``                                                           |
+--------------------------+------------------------------------------------------------------------+
| ``...``                  | ``"..."``                                                              |
+--------------------------+------------------------------------------------------------------------+
| ``123``                  | ``123``                                                                |
+--------------------------+------------------------------------------------------------------------+
| ``True``                 | ``true``                                                               |
+--------------------------+------------------------------------------------------------------------+
| ``False``                | ``false``                                                              |
+--------------------------+------------------------------------------------------------------------+
| ``:`` or ``::``          | ``Slice()`` or ``Slice(None, None)`` or ``Slice(None, None, None)``    |
+--------------------------+------------------------------------------------------------------------+
| ``1:`` or ``1::``        | ``Slice(1, None)`` or ``Slice(1, None, None)``                         |
+--------------------------+------------------------------------------------------------------------+
| ``:3`` or ``:3:``        | ``Slice(None, 3)`` or ``Slice(None, 3, None)``                         |
+--------------------------+------------------------------------------------------------------------+
| ``::2``                  | ``Slice(None, None, 2)``                                               |
+--------------------------+------------------------------------------------------------------------+
| ``1:3``                  | ``Slice(1, 3)``                                                        |
+--------------------------+------------------------------------------------------------------------+
| ``1::2``                 | ``Slice(1, None, 2)``                                                  |
+--------------------------+------------------------------------------------------------------------+
| ``:3:2``                 | ``Slice(None, 3, 2)``                                                  |
+--------------------------+------------------------------------------------------------------------+
| ``1:3:2``                | ``Slice(1, 3, 2)``                                                     |
+--------------------------+------------------------------------------------------------------------+
| ``torch.tensor([1, 2])`` | ``torch::tensor({1, 2})``                                              |
+--------------------------+------------------------------------------------------------------------+
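
The ``Slice`` rows can be checked against a 1-D tensor. This sketch applies each Python slice form from the table once, plus ``Ellipsis``. The name ``slice_examples`` is hypothetical, and the comments assume ``v = torch::arange(6)``, i.e. ``[0, 1, 2, 3, 4, 5]``.

.. code-block:: cpp

  #include <torch/torch.h>
  #include <vector>

  using namespace torch::indexing;

  // Hypothetical helper: one call per slice form in the table above.
  // Results noted below are for v = torch::arange(6).
  std::vector<torch::Tensor> slice_examples(const torch::Tensor& v) {
    return {
        v.index({Slice()}),               // v[:]     -> [0, 1, 2, 3, 4, 5]
        v.index({Slice(1, None)}),        // v[1:]    -> [1, 2, 3, 4, 5]
        v.index({Slice(None, 3)}),        // v[:3]    -> [0, 1, 2]
        v.index({Slice(None, None, 2)}),  // v[::2]   -> [0, 2, 4]
        v.index({Slice(1, 3)}),           // v[1:3]   -> [1, 2]
        v.index({Slice(1, None, 2)}),     // v[1::2]  -> [1, 3, 5]
        v.index({Slice(None, 3, 2)}),     // v[:3:2]  -> [0, 2]
        v.index({Slice(1, 3, 2)}),        // v[1:3:2] -> [1]
        v.index({Ellipsis, 0}),           // v[Ellipsis, 0] -> 0-dim tensor holding 0
    };
  }

Equivalent spellings in the table, such as ``Slice(1, None)`` and ``Slice(1, None, None)``, give the same result because omitted ``Slice`` arguments default to ``None``.
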