Move outplace ops to ATen (#16788)

Summary:
Based on https://github.com/pytorch/pytorch/pull/12413, with the following additional changes:

-  Inside `native_functions.yml` move those outplace operators right next to everyone's corresponding inplace operators for convenience of checking if they match when reviewing
- `matches_jit_signature: True` for them
- Add missing `scatter` with Scalar source
- Add missing `masked_fill` and `index_fill` with Tensor source.
- Add missing test for `scatter` with Scalar source
- Add missing test for `masked_fill` and `index_fill` with Tensor source by checking the gradient w.r.t source
- Add missing docs to `tensor.rst`

Differential Revision: D14069925

Pulled By: ezyang

fbshipit-source-id: bb3f0cb51cf6b756788dc4955667fead6e8796e5
This commit is contained in:
Xiang Gao
2019-02-15 15:54:50 -08:00
committed by Facebook Github Bot
parent 5737c5259c
commit 4fcab92d6c
12 changed files with 218 additions and 59 deletions

View File

@ -392,8 +392,9 @@ class CAFFE2_API Tensor {
Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntArrayRef signal_sizes={}) const;
Tensor index(TensorList indices) const;
Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
Tensor index_copy(int64_t dim, const Tensor & index, const Tensor & source) const;
Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
Tensor inverse() const;
Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
bool is_distributed() const;
@ -559,16 +560,25 @@ class CAFFE2_API Tensor {
Tensor & set_();
bool is_set_to(const Tensor & tensor) const;
Tensor & masked_fill_(const Tensor & mask, Scalar value);
Tensor masked_fill(const Tensor & mask, Scalar value) const;
Tensor & masked_fill_(const Tensor & mask, const Tensor & value);
Tensor masked_fill(const Tensor & mask, const Tensor & value) const;
Tensor & masked_scatter_(const Tensor & mask, const Tensor & source);
Tensor masked_scatter(const Tensor & mask, const Tensor & source) const;
Tensor view(IntArrayRef size) const;
Tensor & put_(const Tensor & index, const Tensor & source, bool accumulate=false);
Tensor & index_add_(int64_t dim, const Tensor & index, const Tensor & source);
Tensor index_add(int64_t dim, const Tensor & index, const Tensor & source) const;
Tensor & index_fill_(int64_t dim, const Tensor & index, Scalar value);
Tensor index_fill(int64_t dim, const Tensor & index, Scalar value) const;
Tensor & index_fill_(int64_t dim, const Tensor & index, const Tensor & value);
Tensor index_fill(int64_t dim, const Tensor & index, const Tensor & value) const;
Tensor & scatter_(int64_t dim, const Tensor & index, const Tensor & src);
Tensor scatter(int64_t dim, const Tensor & index, const Tensor & src) const;
Tensor & scatter_(int64_t dim, const Tensor & index, Scalar value);
Tensor scatter(int64_t dim, const Tensor & index, Scalar value) const;
Tensor & scatter_add_(int64_t dim, const Tensor & index, const Tensor & src);
Tensor scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const;
Tensor & lt_(Scalar other);
Tensor & lt_(const Tensor & other);
Tensor & gt_(Scalar other);

View File

@ -310,12 +310,15 @@ inline Tensor Tensor::index(TensorList indices) const {
inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
return type().index_copy_(*this, dim, index, source);
}
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
return type().index_put(*this, indices, values, accumulate);
inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const {
return type().index_copy(*this, dim, index, source);
}
inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
return type().index_put_(*this, indices, values, accumulate);
}
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
return type().index_put(*this, indices, values, accumulate);
}
inline Tensor Tensor::inverse() const {
return type().inverse(*this);
}
@ -811,12 +814,21 @@ inline bool Tensor::is_set_to(const Tensor & tensor) const {
inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) {
return type().masked_fill_(*this, mask, value);
}
inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const {
return type().masked_fill(*this, mask, value);
}
inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) {
return type().masked_fill_(*this, mask, value);
}
inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) const {
return type().masked_fill(*this, mask, value);
}
inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & source) {
return type().masked_scatter_(*this, mask, source);
}
inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const {
return type().masked_scatter(*this, mask, source);
}
inline Tensor Tensor::view(IntArrayRef size) const {
return type().view(*this, size);
}
@ -826,21 +838,39 @@ inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool a
inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tensor & source) {
return type().index_add_(*this, dim, index, source);
}
inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const {
return type().index_add(*this, dim, index, source);
}
inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar value) {
return type().index_fill_(*this, dim, index, value);
}
inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value) const {
return type().index_fill(*this, dim, index, value);
}
inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Tensor & value) {
return type().index_fill_(*this, dim, index, value);
}
inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor & value) const {
return type().index_fill(*this, dim, index, value);
}
inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor & src) {
return type().scatter_(*this, dim, index, src);
}
inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & src) const {
return type().scatter(*this, dim, index, src);
}
inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value) {
return type().scatter_(*this, dim, index, value);
}
inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) const {
return type().scatter(*this, dim, index, value);
}
inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) {
return type().scatter_add_(*this, dim, index, src);
}
inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const {
return type().scatter_add(*this, dim, index, src);
}
inline Tensor & Tensor::lt_(Scalar other) {
return type().lt_(*this, other);
}

View File

@ -272,8 +272,9 @@ struct CAFFE2_API Type {
virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntArrayRef signal_sizes) const = 0;
virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor inverse(const Tensor & self) const = 0;
virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
virtual bool is_distributed(const Tensor & self) const = 0;
@ -439,16 +440,25 @@ struct CAFFE2_API Type {
virtual Tensor & set_(Tensor & self) const = 0;
virtual bool is_set_to(const Tensor & self, const Tensor & tensor) const = 0;
virtual Tensor & masked_fill_(Tensor & self, const Tensor & mask, Scalar value) const = 0;
virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar value) const = 0;
virtual Tensor & masked_fill_(Tensor & self, const Tensor & mask, const Tensor & value) const = 0;
virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & value) const = 0;
virtual Tensor & masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source) const = 0;
virtual Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) const = 0;
virtual Tensor view(const Tensor & self, IntArrayRef size) const = 0;
virtual Tensor & put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) const = 0;
virtual Tensor & index_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
virtual Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
virtual Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
virtual Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & value) const = 0;
virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & value) const = 0;
virtual Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
virtual Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
virtual Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
virtual Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
virtual Tensor & lt_(Tensor & self, Scalar other) const = 0;
virtual Tensor & lt_(Tensor & self, const Tensor & other) const = 0;
virtual Tensor & gt_(Tensor & self, Scalar other) const = 0;

View File

@ -498,4 +498,50 @@ Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Ten
return at::legacy::th::_th_index_copy_(self, dim, index, source);
}
Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone().index_copy_(dim, index, source);
}
Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone().index_add_(dim, index, source);
}
Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
return self.clone().index_fill_(dim, index, source);
}
Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone().index_fill_(dim, index, source);
}
Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone().scatter_(dim, index, source);
}
Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
return self.clone().scatter_(dim, index, source);
}
Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
return self.clone().scatter_add_(dim, index, source);
}
Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
Tensor _mask, _self;
std::tie(_mask, _self) = expand_outplace(mask, self);
return _self.clone().masked_scatter_(_mask, source);
}
Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) {
Tensor _mask, _self;
std::tie(_mask, _self) = expand_outplace(mask, self);
return _self.clone().masked_fill_(mask, source);
}
Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) {
Tensor _mask, _self;
std::tie(_mask, _self) = expand_outplace(mask, self);
return _self.clone().masked_fill_(mask, source);
}
}} // at::native

View File

@ -1107,10 +1107,11 @@
variants: function, method
# NB: This function is special-cased in tools/autograd/gen_variable_type.py
- func: index_copy_(Tensor(a!) self, int dim, IndexTensor index, Tensor source) -> Tensor(a!)
- func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
- func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
matches_jit_signature: True
variants: function, method
@ -1118,6 +1119,10 @@
matches_jit_signature: True
variants: function, method
- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
matches_jit_signature: True
variants: function, method
- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
matches_jit_signature: True
variants: function
@ -3056,14 +3061,26 @@
matches_jit_signature: True
variants: method
- func: masked_fill(Tensor self, Tensor mask, Scalar value) -> Tensor
matches_jit_signature: True
variants: function, method
- func: masked_fill_(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: masked_fill(Tensor self, Tensor mask, Tensor value) -> Tensor
matches_jit_signature: True
variants: function, method
- func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
matches_jit_signature: True
variants: function, method
- func: view(Tensor(a) self, int[] size) -> Tensor(a)
matches_jit_signature: True
variants: method
@ -3077,26 +3094,50 @@
matches_jit_signature: True
variants: method
- func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
matches_jit_signature: True
variants: function, method
- func: index_fill_(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: index_fill(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
matches_jit_signature: True
variants: function, method
- func: index_fill_(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: index_fill(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
matches_jit_signature: True
variants: function, method
- func: scatter_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: scatter(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
matches_jit_signature: True
variants: function, method
- func: scatter_(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: scatter(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
matches_jit_signature: True
variants: function, method
- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
matches_jit_signature: True
variants: method
- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
matches_jit_signature: True
variants: function, method
- func: lt_(Tensor(a!) self, Scalar other) -> Tensor(a!)
matches_jit_signature: True
variants: method

View File

@ -252,9 +252,13 @@ view of a storage and defines numeric operations on it.
.. automethod:: half
.. automethod:: histc
.. automethod:: index_add_
.. automethod:: index_add
.. automethod:: index_copy_
.. automethod:: index_copy
.. automethod:: index_fill_
.. automethod:: index_fill
.. automethod:: index_put_
.. automethod:: index_put
.. automethod:: index_select
.. automethod:: int
.. automethod:: inverse
@ -285,7 +289,9 @@ view of a storage and defines numeric operations on it.
.. automethod:: lt_
.. automethod:: map_
.. automethod:: masked_scatter_
.. automethod:: masked_scatter
.. automethod:: masked_fill_
.. automethod:: masked_fill
.. automethod:: masked_select
.. automethod:: matmul
.. automethod:: matrix_power
@ -346,7 +352,9 @@ view of a storage and defines numeric operations on it.
.. automethod:: rsqrt
.. automethod:: rsqrt_
.. automethod:: scatter_
.. automethod:: scatter
.. automethod:: scatter_add_
.. automethod:: scatter_add
.. automethod:: select
.. automethod:: set_
.. automethod:: share_memory_

View File

@ -728,7 +728,8 @@ def method_tests():
('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', [0]),
('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]),
('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', [0]),
('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', [0]),
('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]),
@ -741,15 +742,17 @@ def method_tests():
('masked_select', (M, M), (torch.tensor(1, dtype=torch.uint8),), 'scalar_broadcast_rhs'),
('masked_select', (), (mask_not_all_zeros((M, M)),), 'scalar_broadcast_lhs'),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), torch.tensor(10)), 'tensor'),
# no lhs or all broadcast on masked_fill or masked_scatter because it's always inplace
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), ()), 'tensor'),
('masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),
('masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10), 'broadcast_rhs'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10), 'scalar'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), torch.tensor(10)),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10), 'scalar'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), ()),
'scalar_variable'),
('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10),
('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10),
'scalar_broadcast_rhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))),
('masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)),
'broadcast_lhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)),
'broadcast_rhs'),
('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)), 'scalar'),

View File

@ -9165,7 +9165,7 @@ a")
def test_builtin_error_messsage(self):
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
with self.assertRaisesRegex(RuntimeError, "aten::masked_fill_"):
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)

View File

@ -3455,6 +3455,7 @@ class _TestTorchMixin(object):
for fn in fns:
(dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
full1d = cast(torch.randn(*dims_full).flatten().float())
small = cast(torch.randn(*dims_small).float())
large = cast(torch.randn(*dims_large).float())
small_expanded = small.expand(*dims_full)
@ -3471,8 +3472,7 @@ class _TestTorchMixin(object):
# map and map2 are not implementd on CUDA tensors
continue
# TODO: fix masked_scatter and masked_fill broadcasting
if hasattr(large_expanded, fn) and fn not in ['masked_scatter', 'masked_fill']:
if hasattr(large_expanded, fn):
# run through tensor versions of functions
# and verify fully expanded inputs give same results
expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
@ -3482,6 +3482,10 @@ class _TestTorchMixin(object):
return myfn(t1, 0.5)
elif fn == "masked_select":
return myfn(t1 < 0)
elif fn == "masked_scatter":
return myfn(t1 < 0.5, full1d)
elif fn == "masked_fill":
return myfn(t1 < 0.5, 1.0)
elif fn in fns_3_args:
return myfn(1, t1, t2)
else:
@ -3509,7 +3513,7 @@ class _TestTorchMixin(object):
elif fn == "masked_select":
return fntorch(t1, t2 < 0)
elif fn == "masked_scatter":
return fntorch(t1, t2 < 0.5, cast(torch.arange(1, t1.nelement() + 1).float()))
return fntorch(t1, t2 < 0.5, full1d)
elif fn == "masked_fill":
return fntorch(t1, t2 < 0.5, 1.0)
elif fn in fns_3_args:
@ -3540,7 +3544,7 @@ class _TestTorchMixin(object):
if fn == "lerp":
return t0_fn(t1, 0.5)
elif fn == "masked_scatter":
return t0_fn(t1 < 0.5, cast(torch.arange(1, t0.nelement() + 1).float()))
return t0_fn(t1 < 0.5, full1d)
elif fn == "masked_fill":
return t0_fn(t1 < 0.5, 1.0)
elif fn == "map":

View File

@ -80,13 +80,6 @@ class Tensor:
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
def split(self, split_size, dim=0): ...
def index_add(self, dim, index, tensor): ...
def index_copy(self, dim, index, tensor): ...
def index_fill(self, dim, index, value): ...
def scatter(self, dim, index, source): ...
def scatter_add(self, dim, index, source): ...
def masked_scatter(self, mask, tensor): ...
def masked_fill(self, mask, value): ...
def unique(self, sorted=True, return_inverse=False, dim=None): ...
${function_hints}

View File

@ -2963,6 +2963,55 @@ pinverse() -> Tensor
See :func:`torch.pinverse`
""")
add_docstr_all('index_add',
r"""
index_add(dim, index, tensor) -> Tensor
Out-of-place version of :meth:`torch.Tensor.index_add_`
""")
add_docstr_all('index_copy',
r"""
index_copy(dim, index, tensor) -> Tensor
Out-of-place version of :meth:`torch.Tensor.index_copy_`
""")
add_docstr_all('index_fill',
r"""
index_fill(dim, index, value) -> Tensor
Out-of-place version of :meth:`torch.Tensor.index_fill_`
""")
add_docstr_all('scatter',
r"""
scatter(dim, index, source) -> Tensor
Out-of-place version of :meth:`torch.Tensor.scatter_`
""")
add_docstr_all('scatter_add',
r"""
scatter_add(dim, index, source) -> Tensor
Out-of-place version of :meth:`torch.Tensor.scatter_add_`
""")
add_docstr_all('masked_scatter',
r"""
masked_scatter(mask, tensor) -> Tensor
Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
""")
add_docstr_all('masked_fill',
r"""
masked_fill(mask, value) -> Tensor
Out-of-place version of :meth:`torch.Tensor.masked_fill_`
""")
add_docstr_all('grad',
r"""
This attribute is ``None`` by default and becomes a Tensor the first time a call to

View File

@ -307,41 +307,6 @@ class Tensor(torch._C._TensorBase):
else:
return super(Tensor, self).split_with_sizes(split_size, dim)
def index_add(self, dim, index, tensor):
r"""Out-of-place version of :meth:`torch.Tensor.index_add_`
"""
return self.clone().index_add_(dim, index, tensor)
def index_copy(self, dim, index, tensor):
r"""Out-of-place version of :meth:`torch.Tensor.index_copy_`
"""
return self.clone().index_copy_(dim, index, tensor)
def index_fill(self, dim, index, value):
r"""Out-of-place version of :meth:`torch.Tensor.index_fill_`
"""
return self.clone().index_fill_(dim, index, value)
def scatter(self, dim, index, source):
r"""Out-of-place version of :meth:`torch.Tensor.scatter_`
"""
return self.clone().scatter_(dim, index, source)
def scatter_add(self, dim, index, source):
r"""Out-of-place version of :meth:`torch.Tensor.scatter_add_`
"""
return self.clone().scatter_add_(dim, index, source)
def masked_scatter(self, mask, tensor):
r"""Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
"""
return self.clone().masked_scatter_(mask, tensor)
def masked_fill(self, mask, value):
r"""Out-of-place version of :meth:`torch.Tensor.masked_fill_`
"""
return self.clone().masked_fill_(mask, value)
def unique(self, sorted=True, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.