add slice/select/diagonal_scatter variants as primitive ops (#64430)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64430

The functionalization pass needs `{view}_scatter` versions of the slice/select/diagonal ops in order to correctly propagate mutations from a view to its base. On top of that, the implementations need to be primitive w.r.t. autograd, because they look something like `...slice().copy_()`, and the functionalization pass can't use views + mutations inside of it's own alias-removal machinery!

I added some basic tests that I tried to base off of existing tests for views (particularly around testing the derivative formulas), but I'm wondering if I should add something more comprehensive.

Also, as_strided fits into this category - the functionalization pass will need an `as_strided_scatter` op that's primitive w.r.t. autograd. I didn't add it for now, because it'll involve duplicating a bunch of logic from the current `as_strided_backward()` function, and also writing a derivative formula that I wasn't sure how to write :)

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31942092

Pulled By: bdhirsh

fbshipit-source-id: c702a57c2748a7c771c14e4bcc3e996b48fcc4c8
This commit is contained in:
Brian Hirsh
2021-10-28 10:43:11 -07:00
committed by Facebook GitHub Bot
parent 665c148e42
commit 03f3a0331b
7 changed files with 312 additions and 10 deletions

View File

@ -1169,6 +1169,13 @@ diagonal(offset=0, dim1=0, dim2=1) -> Tensor
See :func:`torch.diagonal`
""")
add_docstr_all('diagonal_scatter',
r"""
diagonal(src, offset=0, dim1=0, dim2=1) -> Tensor
See :func:`torch.diagonal_scatter`
""")
add_docstr_all('fill_diagonal_',
r"""
fill_diagonal_(fill_value, wrap=False) -> Tensor
@ -3352,18 +3359,21 @@ add_docstr_all('select',
r"""
select(dim, index) -> Tensor
Slices the :attr:`self` tensor along the selected dimension at the given index.
This function returns a view of the original tensor with the given dimension removed.
See :func:`torch.select`
""")
Args:
dim (int): the dimension to slice
index (int): the index to select with
add_docstr_all('select_scatter',
r"""
select_scatter(src, dim, index) -> Tensor
.. note::
See :func:`torch.select_scatter`
""")
:meth:`select` is equivalent to slicing. For example,
``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and
``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``.
add_docstr_all('slice_scatter',
r"""
slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor
See :func:`torch.slice_scatter`
""")
add_docstr_all('set_',