Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)
Move outplace ops to ATen (#16788)
Summary: Based on https://github.com/pytorch/pytorch/pull/12413, with the following additional changes:
- In `native_functions.yml`, move the outplace operators right next to their corresponding inplace operators, so reviewers can easily check that the two match.
- Mark them with `matches_jit_signature: True`.
- Add the missing `scatter` overload with a Scalar source.
- Add the missing `masked_fill` and `index_fill` overloads with a Tensor source.
- Add a missing test for `scatter` with a Scalar source.
- Add missing tests for `masked_fill` and `index_fill` with a Tensor source, checking the gradient w.r.t. the source.
- Add the missing docs to `tensor.rst`.

Differential Revision: D14069925
Pulled By: ezyang
fbshipit-source-id: bb3f0cb51cf6b756788dc4955667fead6e8796e5
Committed by: Facebook Github Bot
Parent: 5737c5259c
Commit: 4fcab92d6c
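The pattern throughout this commit is that each in-place method `foo_` gains an out-of-place counterpart `foo` that leaves its input untouched. As a quick orientation before the diff, here is a minimal illustrative sketch of that relationship from the Python side (shapes and values are arbitrary and not taken from this commit):

    import torch

    x = torch.zeros(3, 4)
    index = torch.tensor([0, 2])

    # In-place variant: mutates x directly.
    x.index_fill_(0, index, 1.0)

    # Out-of-place variant: returns a new tensor, leaving the input unchanged.
    y = torch.zeros(3, 4)
    z = y.index_fill(0, index, 1.0)
    assert (y == 0).all() and (z[0] == 1).all()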
@@ -392,8 +392,9 @@ class CAFFE2_API Tensor {
  Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntArrayRef signal_sizes={}) const;
  Tensor index(TensorList indices) const;
  Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
  Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
  Tensor index_copy(int64_t dim, const Tensor & index, const Tensor & source) const;
  Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
  Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
  Tensor inverse() const;
  Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
  bool is_distributed() const;
@@ -559,16 +560,25 @@ class CAFFE2_API Tensor {
  Tensor & set_();
  bool is_set_to(const Tensor & tensor) const;
  Tensor & masked_fill_(const Tensor & mask, Scalar value);
  Tensor masked_fill(const Tensor & mask, Scalar value) const;
  Tensor & masked_fill_(const Tensor & mask, const Tensor & value);
  Tensor masked_fill(const Tensor & mask, const Tensor & value) const;
  Tensor & masked_scatter_(const Tensor & mask, const Tensor & source);
  Tensor masked_scatter(const Tensor & mask, const Tensor & source) const;
  Tensor view(IntArrayRef size) const;
  Tensor & put_(const Tensor & index, const Tensor & source, bool accumulate=false);
  Tensor & index_add_(int64_t dim, const Tensor & index, const Tensor & source);
  Tensor index_add(int64_t dim, const Tensor & index, const Tensor & source) const;
  Tensor & index_fill_(int64_t dim, const Tensor & index, Scalar value);
  Tensor index_fill(int64_t dim, const Tensor & index, Scalar value) const;
  Tensor & index_fill_(int64_t dim, const Tensor & index, const Tensor & value);
  Tensor index_fill(int64_t dim, const Tensor & index, const Tensor & value) const;
  Tensor & scatter_(int64_t dim, const Tensor & index, const Tensor & src);
  Tensor scatter(int64_t dim, const Tensor & index, const Tensor & src) const;
  Tensor & scatter_(int64_t dim, const Tensor & index, Scalar value);
  Tensor scatter(int64_t dim, const Tensor & index, Scalar value) const;
  Tensor & scatter_add_(int64_t dim, const Tensor & index, const Tensor & src);
  Tensor scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const;
  Tensor & lt_(Scalar other);
  Tensor & lt_(const Tensor & other);
  Tensor & gt_(Scalar other);
@@ -310,12 +310,15 @@ inline Tensor Tensor::index(TensorList indices) const {
inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
  return type().index_copy_(*this, dim, index, source);
}
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
  return type().index_put(*this, indices, values, accumulate);
inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const {
  return type().index_copy(*this, dim, index, source);
}
inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
  return type().index_put_(*this, indices, values, accumulate);
}
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
  return type().index_put(*this, indices, values, accumulate);
}
inline Tensor Tensor::inverse() const {
  return type().inverse(*this);
}
@@ -811,12 +814,21 @@ inline bool Tensor::is_set_to(const Tensor & tensor) const {
inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) {
  return type().masked_fill_(*this, mask, value);
}
inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const {
  return type().masked_fill(*this, mask, value);
}
inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) {
  return type().masked_fill_(*this, mask, value);
}
inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) const {
  return type().masked_fill(*this, mask, value);
}
inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & source) {
  return type().masked_scatter_(*this, mask, source);
}
inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const {
  return type().masked_scatter(*this, mask, source);
}
inline Tensor Tensor::view(IntArrayRef size) const {
  return type().view(*this, size);
}
@@ -826,21 +838,39 @@ inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool a
inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tensor & source) {
  return type().index_add_(*this, dim, index, source);
}
inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const {
  return type().index_add(*this, dim, index, source);
}
inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar value) {
  return type().index_fill_(*this, dim, index, value);
}
inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value) const {
  return type().index_fill(*this, dim, index, value);
}
inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Tensor & value) {
  return type().index_fill_(*this, dim, index, value);
}
inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor & value) const {
  return type().index_fill(*this, dim, index, value);
}
inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor & src) {
  return type().scatter_(*this, dim, index, src);
}
inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & src) const {
  return type().scatter(*this, dim, index, src);
}
inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value) {
  return type().scatter_(*this, dim, index, value);
}
inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) const {
  return type().scatter(*this, dim, index, value);
}
inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) {
  return type().scatter_add_(*this, dim, index, src);
}
inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const {
  return type().scatter_add(*this, dim, index, src);
}
inline Tensor & Tensor::lt_(Scalar other) {
  return type().lt_(*this, other);
}
@@ -272,8 +272,9 @@ struct CAFFE2_API Type {
  virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntArrayRef signal_sizes) const = 0;
  virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
  virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
  virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
  virtual Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
  virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
  virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
  virtual Tensor inverse(const Tensor & self) const = 0;
  virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
  virtual bool is_distributed(const Tensor & self) const = 0;
@@ -439,16 +440,25 @@ struct CAFFE2_API Type {
  virtual Tensor & set_(Tensor & self) const = 0;
  virtual bool is_set_to(const Tensor & self, const Tensor & tensor) const = 0;
  virtual Tensor & masked_fill_(Tensor & self, const Tensor & mask, Scalar value) const = 0;
  virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar value) const = 0;
  virtual Tensor & masked_fill_(Tensor & self, const Tensor & mask, const Tensor & value) const = 0;
  virtual Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & value) const = 0;
  virtual Tensor & masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source) const = 0;
  virtual Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) const = 0;
  virtual Tensor view(const Tensor & self, IntArrayRef size) const = 0;
  virtual Tensor & put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate) const = 0;
  virtual Tensor & index_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
  virtual Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
  virtual Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
  virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
  virtual Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & value) const = 0;
  virtual Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & value) const = 0;
  virtual Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
  virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
  virtual Tensor & scatter_(Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
  virtual Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar value) const = 0;
  virtual Tensor & scatter_add_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
  virtual Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & src) const = 0;
  virtual Tensor & lt_(Tensor & self, Scalar other) const = 0;
  virtual Tensor & lt_(Tensor & self, const Tensor & other) const = 0;
  virtual Tensor & gt_(Tensor & self, Scalar other) const = 0;
@@ -498,4 +498,50 @@ Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Ten
  return at::legacy::th::_th_index_copy_(self, dim, index, source);
}

Tensor index_copy(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  return self.clone().index_copy_(dim, index, source);
}

Tensor index_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  return self.clone().index_add_(dim, index, source);
}

Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
  return self.clone().index_fill_(dim, index, source);
}

Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  return self.clone().index_fill_(dim, index, source);
}

Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  return self.clone().scatter_(dim, index, source);
}

Tensor scatter(const Tensor & self, int64_t dim, const Tensor & index, Scalar source) {
  return self.clone().scatter_(dim, index, source);
}

Tensor scatter_add(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
  return self.clone().scatter_add_(dim, index, source);
}

Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
  Tensor _mask, _self;
  std::tie(_mask, _self) = expand_outplace(mask, self);
  return _self.clone().masked_scatter_(_mask, source);
}

Tensor masked_fill(const Tensor & self, const Tensor & mask, Scalar source) {
  Tensor _mask, _self;
  std::tie(_mask, _self) = expand_outplace(mask, self);
  return _self.clone().masked_fill_(mask, source);
}

Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) {
  Tensor _mask, _self;
  std::tie(_mask, _self) = expand_outplace(mask, self);
  return _self.clone().masked_fill_(mask, source);
}

}} // at::native
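Each out-of-place native above follows the same clone-then-in-place pattern: copy `self` (expanding the mask against `self` first for the masked variants) and run the existing in-place kernel on the copy. A rough Python illustration of that pattern, assuming a PyTorch build with these ops; the helper name `scatter_outplace` is mine, the real dispatch happens in C++:

    import torch

    def scatter_outplace(self, dim, index, src):
        # Illustrative mirror of: return self.clone().scatter_(dim, index, src);
        return self.clone().scatter_(dim, index, src)

    x = torch.zeros(3, 4)
    index = torch.tensor([[0, 1, 2, 0]])
    src = torch.ones(1, 4)

    out = scatter_outplace(x, 0, index, src)
    assert x.sum() == 0    # the original tensor is untouched
    assert out.sum() == 4  # four positions were scattered into the copy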
@@ -1107,10 +1107,11 @@
  variants: function, method
  # NB: This function is special-cased in tools/autograd/gen_variable_type.py

- func: index_copy_(Tensor(a!) self, int dim, IndexTensor index, Tensor source) -> Tensor(a!)
- func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
- func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
  matches_jit_signature: True
  variants: function, method

@@ -1118,6 +1119,10 @@
  matches_jit_signature: True
  variants: function, method

- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
  matches_jit_signature: True
  variants: function
@@ -3056,14 +3061,26 @@
  matches_jit_signature: True
  variants: method

- func: masked_fill(Tensor self, Tensor mask, Scalar value) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: masked_fill_(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: masked_fill(Tensor self, Tensor mask, Tensor value) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: view(Tensor(a) self, int[] size) -> Tensor(a)
  matches_jit_signature: True
  variants: method
@@ -3077,26 +3094,50 @@
  matches_jit_signature: True
  variants: method

- func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: index_fill_(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: index_fill(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: index_fill_(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: index_fill(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: scatter_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: scatter(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: scatter_(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: scatter(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
  matches_jit_signature: True
  variants: method

- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
  matches_jit_signature: True
  variants: function, method

- func: lt_(Tensor(a!) self, Scalar other) -> Tensor(a!)
  matches_jit_signature: True
  variants: method
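Because the new entries are declared with `variants: function, method`, each out-of-place op is generated both as a `torch.*` namespace function and as a `Tensor` method. A small sanity-check sketch of that, assuming a PyTorch build that includes this change:

    import torch

    x = torch.zeros(3)
    index = torch.tensor([0, 2])

    a = x.index_fill(0, index, 5.0)          # method variant
    b = torch.index_fill(x, 0, index, 5.0)   # function variant
    assert torch.equal(a, b) and x.sum() == 0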
@@ -252,9 +252,13 @@ view of a storage and defines numeric operations on it.
.. automethod:: half
.. automethod:: histc
.. automethod:: index_add_
.. automethod:: index_add
.. automethod:: index_copy_
.. automethod:: index_copy
.. automethod:: index_fill_
.. automethod:: index_fill
.. automethod:: index_put_
.. automethod:: index_put
.. automethod:: index_select
.. automethod:: int
.. automethod:: inverse
@@ -285,7 +289,9 @@ view of a storage and defines numeric operations on it.
.. automethod:: lt_
.. automethod:: map_
.. automethod:: masked_scatter_
.. automethod:: masked_scatter
.. automethod:: masked_fill_
.. automethod:: masked_fill
.. automethod:: masked_select
.. automethod:: matmul
.. automethod:: matrix_power
@@ -346,7 +352,9 @@ view of a storage and defines numeric operations on it.
.. automethod:: rsqrt
.. automethod:: rsqrt_
.. automethod:: scatter_
.. automethod:: scatter
.. automethod:: scatter_add_
.. automethod:: scatter_add
.. automethod:: select
.. automethod:: set_
.. automethod:: share_memory_
@@ -728,7 +728,8 @@ def method_tests():
('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', [0]),
('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]),
('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', [0]),
('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', [0]),
('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]),
@@ -741,15 +742,17 @@ def method_tests():
('masked_select', (M, M), (torch.tensor(1, dtype=torch.uint8),), 'scalar_broadcast_rhs'),
('masked_select', (), (mask_not_all_zeros((M, M)),), 'scalar_broadcast_lhs'),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), torch.tensor(10)), 'tensor'),
# no lhs or all broadcast on masked_fill or masked_scatter because it's always inplace
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), ()), 'tensor'),
('masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),
('masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10), 'broadcast_rhs'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10), 'scalar'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), torch.tensor(10)),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10), 'scalar'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), ()),
 'scalar_variable'),
('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10),
('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10),
 'scalar_broadcast_rhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))),
('masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)),
 'broadcast_lhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)),
 'broadcast_rhs'),
('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)), 'scalar'),
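The new `method_tests` entries above exercise gradients flowing into the `source`/`value` tensor of the out-of-place ops. The same property can be checked in isolation with `torch.autograd.gradcheck`; the snippet below is an illustrative sketch only (the shapes and the 0-dim `value` tensor are my choices, not taken from the test suite):

    import torch
    from torch.autograd import gradcheck

    x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    mask = torch.rand(4, 4) > 0.5
    value = torch.randn((), dtype=torch.double, requires_grad=True)  # 0-dim tensor value

    # Out-of-place masked_fill should be differentiable w.r.t. both the
    # input tensor and the tensor value being filled in.
    assert gradcheck(lambda t, v: t.masked_fill(mask, v), (x, value))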
@@ -9165,7 +9165,7 @@ a")
    def test_builtin_error_messsage(self):
        from torch.nn.modules.utils import _single, _pair, _triple, _quadruple

        with self.assertRaisesRegex(RuntimeError, "aten::masked_fill_"):
        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
            @torch.jit.script
            def close_match(x):
                return x.masked_fill(True)
@@ -3455,6 +3455,7 @@ class _TestTorchMixin(object):

        for fn in fns:
            (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
            full1d = cast(torch.randn(*dims_full).flatten().float())
            small = cast(torch.randn(*dims_small).float())
            large = cast(torch.randn(*dims_large).float())
            small_expanded = small.expand(*dims_full)
@@ -3471,8 +3472,7 @@ class _TestTorchMixin(object):
                # map and map2 are not implementd on CUDA tensors
                continue

            # TODO: fix masked_scatter and masked_fill broadcasting
            if hasattr(large_expanded, fn) and fn not in ['masked_scatter', 'masked_fill']:
            if hasattr(large_expanded, fn):
                # run through tensor versions of functions
                # and verify fully expanded inputs give same results
                expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
@@ -3482,6 +3482,10 @@ class _TestTorchMixin(object):
                    return myfn(t1, 0.5)
                elif fn == "masked_select":
                    return myfn(t1 < 0)
                elif fn == "masked_scatter":
                    return myfn(t1 < 0.5, full1d)
                elif fn == "masked_fill":
                    return myfn(t1 < 0.5, 1.0)
                elif fn in fns_3_args:
                    return myfn(1, t1, t2)
                else:
@@ -3509,7 +3513,7 @@ class _TestTorchMixin(object):
                elif fn == "masked_select":
                    return fntorch(t1, t2 < 0)
                elif fn == "masked_scatter":
                    return fntorch(t1, t2 < 0.5, cast(torch.arange(1, t1.nelement() + 1).float()))
                    return fntorch(t1, t2 < 0.5, full1d)
                elif fn == "masked_fill":
                    return fntorch(t1, t2 < 0.5, 1.0)
                elif fn in fns_3_args:
@@ -3540,7 +3544,7 @@ class _TestTorchMixin(object):
                if fn == "lerp":
                    return t0_fn(t1, 0.5)
                elif fn == "masked_scatter":
                    return t0_fn(t1 < 0.5, cast(torch.arange(1, t0.nelement() + 1).float()))
                    return t0_fn(t1 < 0.5, full1d)
                elif fn == "masked_fill":
                    return t0_fn(t1 < 0.5, 1.0)
                elif fn == "map":
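The removed TODO above reflects that the out-of-place `masked_fill`/`masked_scatter` variants accept a mask that broadcasts against `self` (via `expand_outplace` in the native implementations earlier in the diff), so the broadcast test no longer needs to skip them. A short illustrative example of that broadcasting, assuming a build with this change:

    import torch

    x = torch.zeros(3, 4)
    col_mask = torch.tensor([True, False, True, False])  # shape (4,), broadcasts across rows

    y = x.masked_fill(col_mask, 7.0)   # out-of-place: x is left untouched
    assert (x == 0).all()
    assert (y[:, 0] == 7).all() and (y[:, 1] == 0).all()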
@@ -80,13 +80,6 @@ class Tensor:
    def stft(self, n_fft, hop_length=None, win_length=None, window=None,
             center=True, pad_mode='reflect', normalized=False, onesided=True): ...
    def split(self, split_size, dim=0): ...
    def index_add(self, dim, index, tensor): ...
    def index_copy(self, dim, index, tensor): ...
    def index_fill(self, dim, index, value): ...
    def scatter(self, dim, index, source): ...
    def scatter_add(self, dim, index, source): ...
    def masked_scatter(self, mask, tensor): ...
    def masked_fill(self, mask, value): ...
    def unique(self, sorted=True, return_inverse=False, dim=None): ...

    ${function_hints}
@@ -2963,6 +2963,55 @@ pinverse() -> Tensor
See :func:`torch.pinverse`
""")

add_docstr_all('index_add',
               r"""
index_add(dim, index, tensor) -> Tensor

Out-of-place version of :meth:`torch.Tensor.index_add_`
""")

add_docstr_all('index_copy',
               r"""
index_copy(dim, index, tensor) -> Tensor

Out-of-place version of :meth:`torch.Tensor.index_copy_`
""")

add_docstr_all('index_fill',
               r"""
index_fill(dim, index, value) -> Tensor

Out-of-place version of :meth:`torch.Tensor.index_fill_`
""")

add_docstr_all('scatter',
               r"""
scatter(dim, index, source) -> Tensor

Out-of-place version of :meth:`torch.Tensor.scatter_`
""")

add_docstr_all('scatter_add',
               r"""
scatter_add(dim, index, source) -> Tensor

Out-of-place version of :meth:`torch.Tensor.scatter_add_`
""")

add_docstr_all('masked_scatter',
               r"""
masked_scatter(mask, tensor) -> Tensor

Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
""")

add_docstr_all('masked_fill',
               r"""
masked_fill(mask, value) -> Tensor

Out-of-place version of :meth:`torch.Tensor.masked_fill_`
""")

add_docstr_all('grad',
               r"""
This attribute is ``None`` by default and becomes a Tensor the first time a call to
@@ -307,41 +307,6 @@ class Tensor(torch._C._TensorBase):
        else:
            return super(Tensor, self).split_with_sizes(split_size, dim)

    def index_add(self, dim, index, tensor):
        r"""Out-of-place version of :meth:`torch.Tensor.index_add_`
        """
        return self.clone().index_add_(dim, index, tensor)

    def index_copy(self, dim, index, tensor):
        r"""Out-of-place version of :meth:`torch.Tensor.index_copy_`
        """
        return self.clone().index_copy_(dim, index, tensor)

    def index_fill(self, dim, index, value):
        r"""Out-of-place version of :meth:`torch.Tensor.index_fill_`
        """
        return self.clone().index_fill_(dim, index, value)

    def scatter(self, dim, index, source):
        r"""Out-of-place version of :meth:`torch.Tensor.scatter_`
        """
        return self.clone().scatter_(dim, index, source)

    def scatter_add(self, dim, index, source):
        r"""Out-of-place version of :meth:`torch.Tensor.scatter_add_`
        """
        return self.clone().scatter_add_(dim, index, source)

    def masked_scatter(self, mask, tensor):
        r"""Out-of-place version of :meth:`torch.Tensor.masked_scatter_`
        """
        return self.clone().masked_scatter_(mask, tensor)

    def masked_fill(self, mask, value):
        r"""Out-of-place version of :meth:`torch.Tensor.masked_fill_`
        """
        return self.clone().masked_fill_(mask, value)

    def unique(self, sorted=True, return_inverse=False, dim=None):
        r"""Returns the unique scalar elements of the tensor as a 1-D tensor.