mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
add slice/select/diagonal_scatter variants as primitive ops (#64430)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64430 The functionalization pass needs `{view}_scatter` versions of the slice/select/diagonal ops in order to correctly propagate mutations from a view to its base. On top of that, the implementations need to be primitive w.r.t. autograd, because they look something like `...slice().copy_()`, and the functionalization pass can't use views + mutations inside of it's own alias-removal machinery! I added some basic tests that I tried to base off of existing tests for views (particularly around testing the derivative formulas), but I'm wondering if I should add something more comprehensive. Also, as_strided fits into this category - the functionalization pass will need an `as_strided_scatter` op that's primitive w.r.t. autograd. I didn't add it for now, because it'll involve duplicating a bunch of logic from the current `as_strided_backward()` function, and also writing a derivative formula that I wasn't sure how to write :) Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D31942092 Pulled By: bdhirsh fbshipit-source-id: c702a57c2748a7c771c14e4bcc3e996b48fcc4c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
665c148e42
commit
03f3a0331b
@ -1169,6 +1169,13 @@ diagonal(offset=0, dim1=0, dim2=1) -> Tensor
|
||||
See :func:`torch.diagonal`
|
||||
""")
|
||||
|
||||
add_docstr_all('diagonal_scatter',
|
||||
r"""
|
||||
diagonal(src, offset=0, dim1=0, dim2=1) -> Tensor
|
||||
|
||||
See :func:`torch.diagonal_scatter`
|
||||
""")
|
||||
|
||||
add_docstr_all('fill_diagonal_',
|
||||
r"""
|
||||
fill_diagonal_(fill_value, wrap=False) -> Tensor
|
||||
@ -3352,18 +3359,21 @@ add_docstr_all('select',
|
||||
r"""
|
||||
select(dim, index) -> Tensor
|
||||
|
||||
Slices the :attr:`self` tensor along the selected dimension at the given index.
|
||||
This function returns a view of the original tensor with the given dimension removed.
|
||||
See :func:`torch.select`
|
||||
""")
|
||||
|
||||
Args:
|
||||
dim (int): the dimension to slice
|
||||
index (int): the index to select with
|
||||
add_docstr_all('select_scatter',
|
||||
r"""
|
||||
select_scatter(src, dim, index) -> Tensor
|
||||
|
||||
.. note::
|
||||
See :func:`torch.select_scatter`
|
||||
""")
|
||||
|
||||
:meth:`select` is equivalent to slicing. For example,
|
||||
``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and
|
||||
``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``.
|
||||
add_docstr_all('slice_scatter',
|
||||
r"""
|
||||
slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor
|
||||
|
||||
See :func:`torch.slice_scatter`
|
||||
""")
|
||||
|
||||
add_docstr_all('set_',
|
||||
|
Reference in New Issue
Block a user