mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Added indexing for bool tensors and bool Indices (#18583)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18583 ghimport-source-id: 2b1941449827f4ab632fa0f5c8cf0791a6be0845 Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18583 Added indexing for bool tensors and bool Indices** * #18505 Added numpy conversion * #18166 Bool Tensor for CUDA ----------- This PR enables bool tensor indexing and indexing with bool indices. This is a part of Bool Tensor feature implementation work. The whole plan looks like this: 1. Storage Implementation [Done] 2. Tensor Creation. a) CPU [Done] b) CUDA [In review] 3. Tensor Conversions. [In review] 4. Tensor Indexing. [This PR] 5. Tensor Operations. 6. Back compatibility related changes. TODO: as a follow up, we should move nonzero method from TH to Aten to make code cleaner. Change: ``` v = torch.tensor([True, False, True], dtype=torch.bool) boolIndices = torch.tensor([True, False, False], dtype=torch.bool) v[boolIndices] -> tensor([True], dtype=torch.bool) v = torch.randn(5, 7, 3) boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool) v[boolIndices] -> tensor([[[ 0.5885, -0.3322, 0.7388], [ 1.1182, 0.7808, -1.1492], [-0.7952, 0.5255, -0.0251], [ 0.7128, 0.8099, 1.2689], [-0.7018, -1.4733, -0.3732], [ 0.4503, 0.4986, -1.1605], [ 0.3348, -1.3767, -0.2976]], [[-2.0303, -0.4720, -0.1448], [-0.1914, -0.6821, 2.0061], [-1.0420, -0.1872, -0.3438], [ 1.7587, -0.4183, -0.7577], [ 1.0094, -0.1950, -0.2430], [ 0.1174, 0.3308, -0.5700], [ 0.1110, -0.2714, 1.3006]], [[-0.1946, -1.4747, -0.4650], [-1.0567, 1.0110, -0.2809], [ 0.3729, -0.5699, 0.0815], [-0.7733, -0.8316, 0.1674], [ 1.2000, -0.3745, -1.1679], [ 1.7105, 0.9851, -0.1907], [-1.1077, 0.2086, -0.0548]]]) ``` Differential Revision: D14673403 fbshipit-source-id: 2b88ec2c7eb26a4f5ef64f8707fb68068d476fc9
This commit is contained in:
committed by
Facebook Github Bot
parent
65dfe1203f
commit
5950c1e8c4
@ -111,6 +111,8 @@
|
||||
[[
|
||||
name: _th_nonzero
|
||||
cname: nonzero
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
|
@ -5,7 +5,7 @@
|
||||
// index(Tensor self, indices) -> Tensor
|
||||
// index_put_(Tensor self, indices, value, accumulate=false)
|
||||
//
|
||||
// The index is a TensorList containg kLong or kByte tensors or nulls. Byte
|
||||
// The index is a TensorList containg kLong, kBool or kByte tensors or nulls. Byte
|
||||
// tensors (boolean masks) are expanded to long tensors via nonzero(). Null
|
||||
// tensors signify that the dimension is not indexed.
|
||||
//
|
||||
@ -79,19 +79,19 @@ static void checkIndexTensorTypes(TensorList indices) {
|
||||
for (auto& tensor : indices) {
|
||||
if (tensor.defined()) {
|
||||
auto scalarType = tensor.scalar_type();
|
||||
if (scalarType != kLong && scalarType != kByte) {
|
||||
AT_INDEX_ERROR("tensors used as indices must be long or byte tensors");
|
||||
if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
|
||||
AT_INDEX_ERROR("tensors used as indices must be long, byte or bool tensors");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList indices) {
|
||||
// Expands byte tensors (masks) into the equivalent indexing by LongTensors
|
||||
static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
|
||||
// Expands ByteTensor (masks) or BoolTensor (masks) into the equivalent indexing by LongTensors
|
||||
std::vector<Tensor> result;
|
||||
for (auto & index : indices) {
|
||||
if (index.scalar_type() == kByte) {
|
||||
// The sizes of the ByteTensor mask must match the sizes of the
|
||||
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
|
||||
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
|
||||
// corresponding dimensions in self
|
||||
for (int64_t j = 0; j < index.dim(); j++) {
|
||||
int64_t srcIdx = result.size() + j;
|
||||
@ -244,8 +244,8 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
|
||||
|
||||
static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) {
|
||||
checkIndexTensorTypes(orig);
|
||||
// first expand ByteTensor (boolean masks) into 1 or more LongTensors
|
||||
auto indices = expandByteTensors(self, orig);
|
||||
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
|
||||
auto indices = expandTensors(self, orig);
|
||||
// next broadcast all index tensors together
|
||||
indices = expand_outplace(indices);
|
||||
// add missing null Tensors so that it matches self.dim()
|
||||
@ -378,8 +378,8 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
|
||||
|
||||
static AdvancedIndex make_info(Tensor self, TensorList orig) {
|
||||
checkIndexTensorTypes(orig);
|
||||
// first expand ByteTensor (boolean masks) into 1 or more LongTensors
|
||||
auto indices = expandByteTensors(self, orig);
|
||||
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
|
||||
auto indices = expandTensors(self, orig);
|
||||
// next broadcast all index tensors together
|
||||
try {
|
||||
indices = expand_outplace(indices);
|
||||
|
@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
|
||||
}
|
||||
|
||||
void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
|
||||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cpu", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cpu", [&] {
|
||||
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
|
||||
*(scalar_t*)dst = *(scalar_t*)(src + offset);
|
||||
});
|
||||
@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
|
||||
|
||||
void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
|
||||
// NOTE: duplicate indices are only supported if accumulate is true.
|
||||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
|
||||
if (accumulate) {
|
||||
// TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
|
||||
// this needs to be thread-safe.
|
||||
|
@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
|
||||
}
|
||||
|
||||
static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
|
||||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cuda", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cuda", [&] {
|
||||
using dtype = OpaqueType<sizeof(scalar_t)>;
|
||||
index_kernel_impl<dtype>(iter, index_size, index_stride);
|
||||
});
|
||||
@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR
|
||||
|
||||
static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
|
||||
AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
|
||||
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
|
||||
using dtype = OpaqueType<sizeof(scalar_t)>;
|
||||
index_put_kernel_impl<dtype>(iter, index_size, index_stride);
|
||||
});
|
||||
|
@ -28,6 +28,9 @@
|
||||
#include <TH/generic/THTensorMath.h>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
#include <TH/generic/THTensorMath.h>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
||||
/* fill and zero*/
|
||||
#include <TH/generic/THTensorFill.h>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
@ -5,3 +5,6 @@
|
||||
|
||||
#include <TH/generic/THTensorEvenMoreMath.cpp>
|
||||
#include <TH/THGenerateAllTypes.h>
|
||||
|
||||
#include <TH/generic/THTensorEvenMoreMath.cpp>
|
||||
#include <TH/THGenerateBoolType.h>
|
||||
|
@ -4,6 +4,71 @@
|
||||
|
||||
#include <TH/generic/THTensorApply.hpp>
|
||||
|
||||
// Finds non-zero elements of a tensor and returns their subscripts
|
||||
void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
|
||||
{
|
||||
ptrdiff_t numel = 0;
|
||||
int64_t *subscript_data;
|
||||
int64_t i = 0;
|
||||
#ifdef TH_REAL_IS_HALF
|
||||
#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
|
||||
#else
|
||||
#define IS_NONZERO(val) ((val)!=0)
|
||||
#endif
|
||||
|
||||
/* First Pass to determine size of subscripts */
|
||||
TH_TENSOR_APPLY(scalar_t, tensor,
|
||||
if IS_NONZERO(*tensor_data) {
|
||||
++numel;
|
||||
});
|
||||
#ifdef DEBUG
|
||||
THAssert(numel <= LONG_MAX);
|
||||
#endif
|
||||
THLongTensor_resize2d(subscript, numel, tensor->dim());
|
||||
if (numel <= 0) {
|
||||
return;
|
||||
}
|
||||
int64_t dimensions = tensor->dim();
|
||||
// +1 faster than additional condition check inside loop
|
||||
int64_t *sizes = new int64_t[dimensions+1];
|
||||
int64_t *idx = new int64_t[dimensions+1];
|
||||
int64_t *ii;
|
||||
int64_t *ss;
|
||||
std::fill(idx, idx+dimensions+1, 0);
|
||||
for (i = 0; i < dimensions; ++i) {
|
||||
sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
|
||||
}
|
||||
sizes[dimensions] = 0;
|
||||
/* Second pass populates subscripts */
|
||||
subscript_data = THLongTensor_data(subscript);
|
||||
auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
|
||||
subscript_strides[0] -= subscript_strides[1] * tensor->dim();
|
||||
TH_TENSOR_APPLY(scalar_t, tensor,
|
||||
if IS_NONZERO(*tensor_data) {
|
||||
ii = idx + dimensions;
|
||||
for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
|
||||
--ii;
|
||||
*subscript_data = *ii;
|
||||
subscript_data += subscript_strides[1];
|
||||
}
|
||||
subscript_data += subscript_strides[0];
|
||||
}
|
||||
ii = idx;
|
||||
ss = sizes;
|
||||
++(*ii);
|
||||
while (*ii == *ss) {
|
||||
*ii = 0;
|
||||
++ii;
|
||||
++ss;
|
||||
++(*ii);
|
||||
}
|
||||
);
|
||||
delete [] sizes;
|
||||
delete [] idx;
|
||||
}
|
||||
|
||||
#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
|
||||
|
||||
void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
|
||||
{
|
||||
#ifdef _OPENMP
|
||||
@ -91,69 +156,6 @@ void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask
|
||||
});
|
||||
}
|
||||
|
||||
// Finds non-zero elements of a tensor and returns their subscripts
|
||||
void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
|
||||
{
|
||||
ptrdiff_t numel = 0;
|
||||
int64_t *subscript_data;
|
||||
int64_t i = 0;
|
||||
#ifdef TH_REAL_IS_HALF
|
||||
#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
|
||||
#else
|
||||
#define IS_NONZERO(val) ((val)!=0)
|
||||
#endif
|
||||
|
||||
/* First Pass to determine size of subscripts */
|
||||
TH_TENSOR_APPLY(scalar_t, tensor,
|
||||
if IS_NONZERO(*tensor_data) {
|
||||
++numel;
|
||||
});
|
||||
#ifdef DEBUG
|
||||
THAssert(numel <= LONG_MAX);
|
||||
#endif
|
||||
THLongTensor_resize2d(subscript, numel, tensor->dim());
|
||||
if (numel <= 0) {
|
||||
return;
|
||||
}
|
||||
int64_t dimensions = tensor->dim();
|
||||
// +1 faster than additional condition check inside loop
|
||||
int64_t *sizes = new int64_t[dimensions+1];
|
||||
int64_t *idx = new int64_t[dimensions+1];
|
||||
int64_t *ii;
|
||||
int64_t *ss;
|
||||
std::fill(idx, idx+dimensions+1, 0);
|
||||
for (i = 0; i < dimensions; ++i) {
|
||||
sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
|
||||
}
|
||||
sizes[dimensions] = 0;
|
||||
/* Second pass populates subscripts */
|
||||
subscript_data = THLongTensor_data(subscript);
|
||||
auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
|
||||
subscript_strides[0] -= subscript_strides[1] * tensor->dim();
|
||||
TH_TENSOR_APPLY(scalar_t, tensor,
|
||||
if IS_NONZERO(*tensor_data) {
|
||||
ii = idx + dimensions;
|
||||
for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
|
||||
--ii;
|
||||
*subscript_data = *ii;
|
||||
subscript_data += subscript_strides[1];
|
||||
}
|
||||
subscript_data += subscript_strides[0];
|
||||
}
|
||||
ii = idx;
|
||||
ss = sizes;
|
||||
++(*ii);
|
||||
while (*ii == *ss) {
|
||||
*ii = 0;
|
||||
++ii;
|
||||
++ss;
|
||||
++(*ii);
|
||||
}
|
||||
);
|
||||
delete [] sizes;
|
||||
delete [] idx;
|
||||
}
|
||||
|
||||
void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
|
||||
{
|
||||
ptrdiff_t i, numel;
|
||||
@ -959,4 +961,6 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif /* TH_GENERIC_FILE */
|
||||
|
@ -2,12 +2,14 @@
|
||||
#define TH_GENERIC_FILE "TH/generic/THTensorMath.h"
|
||||
#else
|
||||
|
||||
TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
|
||||
|
||||
#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
|
||||
|
||||
TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value);
|
||||
TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
|
||||
TH_API void THTensor_(maskedSelect)(THTensor *tensor, THTensor* src, THByteTensor *mask);
|
||||
|
||||
TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
|
||||
|
||||
TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
|
||||
TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
|
||||
TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
|
||||
@ -177,3 +179,4 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
@ -109,6 +109,15 @@ struct NonZeroOp
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NonZeroOp<bool>
|
||||
{
|
||||
NonZeroOp() {}
|
||||
__host__ __device__ bool operator()(bool lhs) const {
|
||||
return lhs != false;
|
||||
}
|
||||
};
|
||||
|
||||
#include <THC/generic/THCTensorMath.cu>
|
||||
#include <THC/THCGenerateAllTypes.h>
|
||||
|
||||
|
@ -242,8 +242,6 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
|
||||
|
||||
void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
|
||||
THCTensor *self)
|
||||
{
|
||||
@ -318,6 +316,8 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
|
||||
THCudaCheck(cudaGetLastError());
|
||||
}
|
||||
|
||||
#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
|
||||
|
||||
void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k){
|
||||
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
|
||||
int nDimension = THCTensor_(nDimensionLegacyNoScalars)(state, src_);
|
||||
|
@ -6,11 +6,11 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value);
|
||||
THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
|
||||
THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
|
||||
THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
|
||||
THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
|
||||
THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
|
||||
|
||||
#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
|
||||
|
||||
THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
|
||||
THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
|
||||
THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
|
||||
THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self);
|
||||
|
@ -35,6 +35,32 @@ class TestIndexing(TestCase):
|
||||
self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
|
||||
self.assertEqual(v[1:].sum(), 0)
|
||||
|
||||
def test_bool_indices(self):
|
||||
v = torch.randn(5, 7, 3)
|
||||
boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
|
||||
self.assertEqual(v[boolIndices].shape, (3, 7, 3))
|
||||
self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
|
||||
|
||||
v = torch.tensor([True, False, True], dtype=torch.bool)
|
||||
boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
|
||||
uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8)
|
||||
self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
|
||||
self.assertEqual(v[boolIndices], v[uint8Indices])
|
||||
self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
|
||||
|
||||
def test_bool_indices_accumulate(self):
|
||||
mask = torch.zeros(size=(10, ), dtype=torch.bool)
|
||||
y = torch.ones(size=(10, 10))
|
||||
y.index_put_((mask, ), y[mask], accumulate=True)
|
||||
self.assertEqual(y, torch.ones(size=(10, 10)))
|
||||
|
||||
def test_multiple_bool_indices(self):
|
||||
v = torch.randn(5, 7, 3)
|
||||
# note: these broadcast together and are transposed to the first dim
|
||||
mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool)
|
||||
mask2 = torch.tensor([1, 1, 1], dtype=torch.bool)
|
||||
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
|
||||
|
||||
def test_byte_mask(self):
|
||||
v = torch.randn(5, 7, 3)
|
||||
mask = torch.ByteTensor([1, 0, 1, 1, 0])
|
||||
|
Reference in New Issue
Block a user