Added indexing for bool tensors and bool Indices (#18583)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18583
ghimport-source-id: 2b1941449827f4ab632fa0f5c8cf0791a6be0845

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18583 Added indexing for bool tensors and bool Indices**
* #18505 Added numpy conversion
* #18166 Bool Tensor for CUDA

-----------
This PR enables indexing of bool tensors and indexing with bool indices. It is part of the Bool Tensor feature implementation work. The overall plan looks like this:
1. Storage Implementation [Done]
2. Tensor Creation.
    a) CPU [Done]
    b) CUDA [In review]
3. Tensor Conversions. [In review]
4. Tensor Indexing. [This PR]
5. Tensor Operations.
6. Back compatibility related changes.

TODO:
As a follow-up, we should move the nonzero method from TH to ATen to make the code cleaner.

Change:
```
v = torch.tensor([True, False, True], dtype=torch.bool)
boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
v[boolIndices]
-> tensor([True], dtype=torch.bool)

v = torch.randn(5, 7, 3)
boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
v[boolIndices]
->
tensor([[[ 0.5885, -0.3322,  0.7388],
         [ 1.1182,  0.7808, -1.1492],
         [-0.7952,  0.5255, -0.0251],
         [ 0.7128,  0.8099,  1.2689],
         [-0.7018, -1.4733, -0.3732],
         [ 0.4503,  0.4986, -1.1605],
         [ 0.3348, -1.3767, -0.2976]],

        [[-2.0303, -0.4720, -0.1448],
         [-0.1914, -0.6821,  2.0061],
         [-1.0420, -0.1872, -0.3438],
         [ 1.7587, -0.4183, -0.7577],
         [ 1.0094, -0.1950, -0.2430],
         [ 0.1174,  0.3308, -0.5700],
         [ 0.1110, -0.2714,  1.3006]],

        [[-0.1946, -1.4747, -0.4650],
         [-1.0567,  1.0110, -0.2809],
         [ 0.3729, -0.5699,  0.0815],
         [-0.7733, -0.8316,  0.1674],
         [ 1.2000, -0.3745, -1.1679],
         [ 1.7105,  0.9851, -0.1907],
         [-1.1077,  0.2086, -0.0548]]])
```
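
At the Python level, a bool mask is expanded to long indices via `nonzero()`, so bool-mask reads and writes behave like their byte-mask counterparts. A minimal sketch of the user-visible behaviour (assumes this stack applied; not part of the diff):
```
import torch

v = torch.randn(5, 7, 3)
mask = torch.tensor([True, False, True, True, False], dtype=torch.bool)

# Reading through a bool mask is equivalent to indexing with the long
# indices produced by nonzero() on that mask.
assert torch.equal(v[mask], v[torch.nonzero(mask).squeeze(1)])
assert v[mask].shape == (3, 7, 3)

# Writing through a bool mask routes through index_put_ in the same way.
w = torch.zeros(5)
w[mask] = 1.0
assert w.tolist() == [1.0, 0.0, 1.0, 1.0, 0.0]
```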

Differential Revision: D14673403

fbshipit-source-id: 2b88ec2c7eb26a4f5ef64f8707fb68068d476fc9
Author: Iurii Zdebskyi
Date: 2019-04-03 10:53:11 -07:00
Committed by: Facebook Github Bot
Parent: 65dfe1203f
Commit: 5950c1e8c4
12 changed files with 133 additions and 83 deletions


@@ -111,6 +111,8 @@
 [[
   name: _th_nonzero
   cname: nonzero
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0


@@ -5,7 +5,7 @@
 // index(Tensor self, indices) -> Tensor
 // index_put_(Tensor self, indices, value, accumulate=false)
 //
-// The index is a TensorList containg kLong or kByte tensors or nulls. Byte
+// The index is a TensorList containg kLong, kBool or kByte tensors or nulls. Byte
 // tensors (boolean masks) are expanded to long tensors via nonzero(). Null
 // tensors signify that the dimension is not indexed.
 //
@@ -79,19 +79,19 @@ static void checkIndexTensorTypes(TensorList indices) {
   for (auto& tensor : indices) {
     if (tensor.defined()) {
       auto scalarType = tensor.scalar_type();
-      if (scalarType != kLong && scalarType != kByte) {
-        AT_INDEX_ERROR("tensors used as indices must be long or byte tensors");
+      if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
+        AT_INDEX_ERROR("tensors used as indices must be long, byte or bool tensors");
       }
     }
   }
 }
 
-static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList indices) {
-  // Expands byte tensors (masks) into the equivalent indexing by LongTensors
+static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
+  // Expands ByteTensor (masks) or BoolTensor (masks) into the equivalent indexing by LongTensors
   std::vector<Tensor> result;
   for (auto & index : indices) {
-    if (index.scalar_type() == kByte) {
-      // The sizes of the ByteTensor mask must match the sizes of the
+    if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+      // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
       // corresponding dimensions in self
       for (int64_t j = 0; j < index.dim(); j++) {
         int64_t srcIdx = result.size() + j;
@@ -244,8 +244,8 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
 static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) {
   checkIndexTensorTypes(orig);
-  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
-  auto indices = expandByteTensors(self, orig);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
+  auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together
   indices = expand_outplace(indices);
   // add missing null Tensors so that it matches self.dim()
@@ -378,8 +378,8 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
 static AdvancedIndex make_info(Tensor self, TensorList orig) {
   checkIndexTensorTypes(orig);
-  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
-  auto indices = expandByteTensors(self, orig);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
+  auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together
   try {
     indices = expand_outplace(indices);
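
The net effect of `checkIndexTensorTypes` and `expandTensors` is that long, byte, and bool index tensors describing the same selection are interchangeable. A small illustrative sketch (assumes this stack applied; not part of the diff, and mirrors the new tests below):
```
import torch

v = torch.randn(5, 7, 3)
bool_mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool)
byte_mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.uint8)
long_idx  = torch.tensor([0, 2, 3])

# expandTensors() turns either mask into long indices via nonzero(), so all
# three spellings select the same sub-tensor.
assert torch.equal(v[bool_mask], v[byte_mask])
assert torch.equal(v[bool_mask], v[long_idx])
```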


@@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }
 
 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cpu", [&] {
     cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
     });
@@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
 
 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
       // this needs to be thread-safe.


@@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
 }
 
 static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cuda", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_kernel_impl<dtype>(iter, index_size, index_stride);
   });
@@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR
 
 static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_put_kernel_impl<dtype>(iter, index_size, index_stride);
   });
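
Switching from `AT_DISPATCH_ALL_TYPES_AND` to `AT_DISPATCH_ALL_TYPES_AND2` adds `Bool` to the dtypes the CPU and CUDA gather/scatter kernels are instantiated for, so a bool tensor can itself be the source of an indexing operation. A minimal sketch (assumes this stack applied; not part of the diff):
```
import torch

src = torch.tensor([True, False, True], dtype=torch.bool)
idx = torch.tensor([0, 2])

# The index kernel is now dispatched for Bool, so gathering from a bool
# source tensor works on CPU (and via the analogous CUDA kernel above).
out = src[idx]
assert out.dtype == torch.bool
assert out.tolist() == [True, True]
```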


@@ -28,6 +28,9 @@
 #include <TH/generic/THTensorMath.h>
 #include <TH/THGenerateAllTypes.h>
 
+#include <TH/generic/THTensorMath.h>
+#include <TH/THGenerateBoolType.h>
+
 /* fill and zero*/
 #include <TH/generic/THTensorFill.h>
 #include <TH/THGenerateAllTypes.h>


@@ -5,3 +5,6 @@
 #include <TH/generic/THTensorEvenMoreMath.cpp>
 #include <TH/THGenerateAllTypes.h>
+
+#include <TH/generic/THTensorEvenMoreMath.cpp>
+#include <TH/THGenerateBoolType.h>


@@ -4,6 +4,71 @@
 #include <TH/generic/THTensorApply.hpp>
 
+// Finds non-zero elements of a tensor and returns their subscripts
+void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
+{
+  ptrdiff_t numel = 0;
+  int64_t *subscript_data;
+  int64_t i = 0;
+#ifdef TH_REAL_IS_HALF
+#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
+#else
+#define IS_NONZERO(val) ((val)!=0)
+#endif
+  /* First Pass to determine size of subscripts */
+  TH_TENSOR_APPLY(scalar_t, tensor,
+                  if IS_NONZERO(*tensor_data) {
+                    ++numel;
+                  });
+#ifdef DEBUG
+  THAssert(numel <= LONG_MAX);
+#endif
+  THLongTensor_resize2d(subscript, numel, tensor->dim());
+  if (numel <= 0) {
+    return;
+  }
+  int64_t dimensions = tensor->dim();
+  // +1 faster than additional condition check inside loop
+  int64_t *sizes = new int64_t[dimensions+1];
+  int64_t *idx = new int64_t[dimensions+1];
+  int64_t *ii;
+  int64_t *ss;
+  std::fill(idx, idx+dimensions+1, 0);
+  for (i = 0; i < dimensions; ++i) {
+    sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
+  }
+  sizes[dimensions] = 0;
+  /* Second pass populates subscripts */
+  subscript_data = THLongTensor_data(subscript);
+  auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
+  subscript_strides[0] -= subscript_strides[1] * tensor->dim();
+  TH_TENSOR_APPLY(scalar_t, tensor,
+                  if IS_NONZERO(*tensor_data) {
+                    ii = idx + dimensions;
+                    for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
+                      --ii;
+                      *subscript_data = *ii;
+                      subscript_data += subscript_strides[1];
+                    }
+                    subscript_data += subscript_strides[0];
+                  }
+                  ii = idx;
+                  ss = sizes;
+                  ++(*ii);
+                  while (*ii == *ss) {
+                    *ii = 0;
+                    ++ii;
+                    ++ss;
+                    ++(*ii);
+                  }
+                  );
+  delete [] sizes;
+  delete [] idx;
+}
+
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
+
 void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
 {
 #ifdef _OPENMP
@@ -91,69 +156,6 @@ void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask
   });
 }
 
-// Finds non-zero elements of a tensor and returns their subscripts
-void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
-{
-  ptrdiff_t numel = 0;
-  int64_t *subscript_data;
-  int64_t i = 0;
-#ifdef TH_REAL_IS_HALF
-#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
-#else
-#define IS_NONZERO(val) ((val)!=0)
-#endif
-  /* First Pass to determine size of subscripts */
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ++numel;
-                  });
-#ifdef DEBUG
-  THAssert(numel <= LONG_MAX);
-#endif
-  THLongTensor_resize2d(subscript, numel, tensor->dim());
-  if (numel <= 0) {
-    return;
-  }
-  int64_t dimensions = tensor->dim();
-  // +1 faster than additional condition check inside loop
-  int64_t *sizes = new int64_t[dimensions+1];
-  int64_t *idx = new int64_t[dimensions+1];
-  int64_t *ii;
-  int64_t *ss;
-  std::fill(idx, idx+dimensions+1, 0);
-  for (i = 0; i < dimensions; ++i) {
-    sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
-  }
-  sizes[dimensions] = 0;
-  /* Second pass populates subscripts */
-  subscript_data = THLongTensor_data(subscript);
-  auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
-  subscript_strides[0] -= subscript_strides[1] * tensor->dim();
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ii = idx + dimensions;
-                    for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
-                      --ii;
-                      *subscript_data = *ii;
-                      subscript_data += subscript_strides[1];
-                    }
-                    subscript_data += subscript_strides[0];
-                  }
-                  ii = idx;
-                  ss = sizes;
-                  ++(*ii);
-                  while (*ii == *ss) {
-                    *ii = 0;
-                    ++ii;
-                    ++ss;
-                    ++(*ii);
-                  }
-                  );
-  delete [] sizes;
-  delete [] idx;
-}
-
 void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
 {
   ptrdiff_t i, numel;
@@ -959,4 +961,6 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
 #endif
 }
 
+#endif
+
 #endif /* TH_GENERIC_FILE */
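
Moving `THTensor_(nonzero)` above the `TH_REAL_IS_BOOL` guard (together with the `cpu_bool`/`cuda_bool` flags in the cwrap entry earlier in the diff) is what lets `nonzero` be generated for the bool type, which the mask expansion above depends on. A minimal sketch of the resulting behaviour (assumes this stack applied; not part of the diff):
```
import torch

mask = torch.tensor([True, False, True, True, False], dtype=torch.bool)

# nonzero() on a bool tensor returns the long subscripts of the True entries,
# which is exactly what the mask expansion uses to build the index.
idx = torch.nonzero(mask)
assert idx.dtype == torch.int64
assert idx.squeeze(1).tolist() == [0, 2, 3]
```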


@@ -2,12 +2,14 @@
 #define TH_GENERIC_FILE "TH/generic/THTensorMath.h"
 #else
 
+TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
+
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
 TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value);
 TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
 TH_API void THTensor_(maskedSelect)(THTensor *tensor, THTensor* src, THByteTensor *mask);
-TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
 TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
 TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
 TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
@@ -177,3 +179,4 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp
 #endif
 #endif
+#endif


@@ -109,6 +109,15 @@ struct NonZeroOp
   }
 };
 
+template <>
+struct NonZeroOp<bool>
+{
+  NonZeroOp() {}
+  __host__ __device__ bool operator()(bool lhs) const {
+    return lhs != false;
+  }
+};
+
 #include <THC/generic/THCTensorMath.cu>
 #include <THC/THCGenerateAllTypes.h>


@@ -242,8 +242,6 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
   }
 }
 
-#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
-
 void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
                          THCTensor *self)
 {
@@ -318,6 +316,8 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
   THCudaCheck(cudaGetLastError());
 }
 
+#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
+
 void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k){
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
   int nDimension = THCTensor_(nDimensionLegacyNoScalars)(state, src_);


@@ -6,11 +6,11 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value);
 THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
 THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
 THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
+THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
 
 #if !defined(THC_REAL_IS_BOOL) /* non bool only part */
 
-THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self);


@@ -35,6 +35,32 @@ class TestIndexing(TestCase):
         self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
         self.assertEqual(v[1:].sum(), 0)
 
+    def test_bool_indices(self):
+        v = torch.randn(5, 7, 3)
+        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
+        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
+        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
+
+        v = torch.tensor([True, False, True], dtype=torch.bool)
+        boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
+        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8)
+        self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
+        self.assertEqual(v[boolIndices], v[uint8Indices])
+        self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
+
+    def test_bool_indices_accumulate(self):
+        mask = torch.zeros(size=(10, ), dtype=torch.bool)
+        y = torch.ones(size=(10, 10))
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10)))
+
+    def test_multiple_bool_indices(self):
+        v = torch.randn(5, 7, 3)
+        # note: these broadcast together and are transposed to the first dim
+        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool)
+        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool)
+        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
+
     def test_byte_mask(self):
         v = torch.randn(5, 7, 3)
         mask = torch.ByteTensor([1, 0, 1, 1, 0])