Speed-up "advanced" indexing operations (#13420)

Summary:
This speeds-up "advanced" indexing (indexing a tensor by a tensor)
on CPU and GPU. There's still a bunch of work to do, including
speeding up indexing by a byte (boolean) mask and speeding up the derivative
calculation for advanced indexing.

Here's some speed comparisons to indexing on master using a little [benchmark script](https://gist.github.com/colesbury/c369db72aad594e5e032c8fda557d909) with 16 OpenMP threads and on a P100. The test cases are listed as (input shape -> output shape).

| Test case             | CPU (old vs. new)   | CUDA (old vs. new)     |
|-----------------------|---------------------|------------------------|
| 1024x1024 -> 512x1024 | 225 us vs. **57 us**  | 297 us vs. **47 us** |
| 1024x1024 -> 1024x512 | 208 us vs. **153 us** | 335 us vs. **54 us** |
| 50x50 -> 20000x50     | 617 us vs. **77 us**  | 239 us vs. **54 us** |
| 50x50 -> 50x20000     | 575 us vs. **236 us** | 262 us vs. **58 us** |
| 2x5x10 -> 10          | 65 us  vs. **18 us**  | 612 us vs. **93 us** |

See #11647
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13420

Reviewed By: soumith

Differential Revision: D13088936

Pulled By: colesbury

fbshipit-source-id: 0a5c2ee9aa54e15f96d06692d1694c3b24b924e2
This commit is contained in:
Sam Gross
2018-11-27 15:18:39 -08:00
committed by Facebook Github Bot
parent 0199d59d3a
commit 006505bb8f
21 changed files with 584 additions and 99 deletions

View File

@ -356,8 +356,8 @@ public:
Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntList signal_sizes={}) const;
Tensor index(TensorList indices) const;
Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
Tensor index_put(TensorList indices, const Tensor & values) const;
Tensor & index_put_(TensorList indices, const Tensor & values);
Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
Tensor inverse() const;
Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
bool is_distributed() const;

View File

@ -318,11 +318,11 @@ inline Tensor Tensor::index(TensorList indices) const {
inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
return type().index_copy_(*this, dim, index, source);
}
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values) const {
return type().index_put(*this, indices, values);
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
return type().index_put(*this, indices, values, accumulate);
}
inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values) {
return type().index_put_(*this, indices, values);
inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
return type().index_put_(*this, indices, values, accumulate);
}
inline Tensor Tensor::inverse() const {
return type().inverse(*this);

View File

@ -264,8 +264,8 @@ struct CAFFE2_API Type {
virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntList signal_sizes) const = 0;
virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values) const = 0;
virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values) const = 0;
virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor inverse(const Tensor & self) const = 0;
virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
virtual bool is_distributed(const Tensor & self) const = 0;

View File

@ -17,6 +17,9 @@ struct OffsetCalculator {
using offset_type = at::cuda::Array<uint32_t, NARGS>;
OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
if (dims > MAX_DIMS) {
throw std::runtime_error("tensor has too many (>25) dims");
}
for (int i = 0; i < MAX_DIMS; ++i) {
if (i < dims) {
sizes_[i] = IntDivider<uint32_t>(sizes[i]);

View File

@ -3,7 +3,7 @@
// This corresponds to "advanced indexing" in NumPy. The two operations are:
//
// index(Tensor self, indices) -> Tensor
// index_put_(Tensor self, indices, value)
// index_put_(Tensor self, indices, value, accumulate=false)
//
// The index is a TensorList containg kLong or kByte tensors or nulls. Byte
// tensors (boolean masks) are expanded to long tensors via nonzero(). Null
@ -20,11 +20,40 @@
// Note 2: The behavior is more complicated when the index tensors are not all
// adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
// tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
//
// The code contains two implementations of indexing. The more efficient
// implementation treats indexing like an elementwise operation over the
// tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does
// not work for index_put_ with accumulate=True. The other implementation
// combines the indexed tensors into a single linear index that is used
// with Tensor.put_. This is used for index_put_ with accumulate=True.
//
// The more efficient implementation takes the following steps for the
// above operation:
//
// 1) Broadcast ind_1, ind_2, ind_3 together to a common shape
// 2) Record x.stride(i) for each indexed dimension `i`
// 3) Replace the indexed subspace of `x` with the shape of the corresponding
// subspace of `result` but with stride 0
// 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so
// that their shape is compatible with the result shape
//
// The CPU or CUDA kernel then computes element-wise over the broadcasted
// and restrided result, x, ind_1, ind_2, etc.:
//
// result[...] = *(&x[...] +
// ind_1[...] * x.stride(1) +
// ind_2[...] * x.stride(2) +
// ...)
//
// where & and * represent the C-style address-of and indirection operations.
#include <ATen/native/Indexing.h>
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/ExpandUtils.h"
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>
#include <algorithm>
#include <functional>
@ -33,6 +62,9 @@
namespace at { namespace native {
DEFINE_DISPATCH(index_stub);
DEFINE_DISPATCH(index_put_stub);
[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
std::stringstream ss;
@ -226,34 +258,188 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig)
return std::make_tuple(self, linearIndex);
}
static bool all_strides_match(TensorList tensors) {
AT_ASSERT(tensors.size() >= 1);
auto strides = tensors[0].strides();
for (auto& tensor : tensors.slice(1)) {
if (!strides.equals(tensor.strides())) {
return false;
}
}
return true;
}
static std::string shapes_as_str(TensorList tensors) {
std::ostringstream os;
bool first = true;
for (auto& tensor : tensors) {
if (tensor.defined()) {
if (!first) {
os << ", ";
}
os << tensor.sizes();
first = false;
}
}
return os.str();
}
struct AdvancedIndex {
AdvancedIndex(const Tensor& src, TensorList indices);
Tensor src;
std::vector<Tensor> indices;
DimVector indexed_sizes;
DimVector indexed_strides;
int64_t dims_before;
int64_t dims_after;
};
// Replace indexed dimensions in src with stride 0 and the size of the result tensor.
// The offset in these dimensions is computed by the kernel using the index tensor's
// values and the stride of src. The new shape is not meaningful. It's used to make
// the shape compatible with the result tensor.
static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed,
IntList replacement_shape) {
auto shape = DimVector(src.sizes());
auto strides = DimVector(src.strides());
int end = dims_before + dims_indexed;
shape.erase(shape.begin() + dims_before, shape.begin() + end);
strides.erase(strides.begin() + dims_before, strides.begin() + end);
shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0);
return src.as_strided(shape, strides);
}
// Add dimensions of size 1 to an index tensor so that it's can be broadcast to the result
// shape and iterated over element-wise like the result tensor and the restrided src.
static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) {
auto orig_shape = index.sizes();
auto shape = DimVector();
shape.append(dims_before, 1);
shape.append(orig_shape.begin(), orig_shape.end());
shape.append(dims_after, 1);
return index.reshape(shape);
}
AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
{
int64_t element_size_bytes = src.type().elementSizeInBytes();
int dims_before = 0, dims_after = 0, dims_indexed = 0;
IntList replacement_shape;
for (size_t dim = 0; dim < indices_list.size(); dim++) {
if (!indices_list[dim].defined()) {
if (dims_indexed == 0) {
dims_before++;
} else {
dims_after++;
}
} else {
dims_indexed++;
replacement_shape = indices_list[dim].sizes();
indexed_sizes.push_back(src.size(dim));
indexed_strides.push_back(src.stride(dim) * element_size_bytes);
}
}
this->dims_before = dims_before;
this->dims_after = dims_after;
this->src = restride_src(src, dims_before, dims_indexed, replacement_shape);
for (auto& index : indices_list) {
if (index.defined()) {
indices.push_back(reshape_indexer(index, dims_before, dims_after));
}
}
// For CUDA tensors, force all index tensors to have the same striding to
// simplify the CUDA kernel.
if (indices.size() >= 2 && this->src.type().device_type() == kCUDA) {
if (!all_strides_match(indices)) {
for (size_t i = 0; i < indices.size(); i++) {
indices[i] = indices[i].contiguous();
}
}
}
}
static AdvancedIndex make_info(Tensor self, TensorList orig) {
checkIndexTensorTypes(orig);
// first expand ByteTensor (boolean masks) into 1 or more LongTensors
auto indices = expandByteTensors(self, orig);
// next broadcast all index tensors together
try {
indices = expand_outplace(indices);
} catch (std::exception& e) {
AT_ERROR("shape mismatch: indexing tensors could not be broadcast together"
" with shapes ", shapes_as_str(indices));
}
// add missing null Tensors so that it matches self.dim()
while (indices.size() < (size_t)self.dim()) {
indices.emplace_back();
}
// if the non-null indices are not all adjacent, transpose self and indices
// together so that they're adjacent at the front
if (!hasContiguousSubspace(indices)) {
std::tie(self, indices) = transposeToFront(self, indices);
}
return AdvancedIndex(self, indices);
}
static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) {
auto builder = TensorIterator::Builder();
builder.dont_compute_common_dtype();
builder.add_output(Tensor(), &info.src.type());
builder.add_input(info.src);
for (auto& index : info.indices) {
builder.add_input(index);
}
return builder.build();
}
static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
if (!is_expandable_to(value.sizes(), info.src.sizes())) {
AT_ERROR("shape mismatch: value tensor of shape ", value.sizes(),
" cannot be broadcast to indexing result of shape ", info.src.sizes());
}
auto builder = TensorIterator::Builder();
builder.dont_compute_common_dtype();
builder.dont_resize_outputs();
builder.add_output(info.src);
builder.add_input(value, &info.src.type());
for (auto& index : info.indices) {
builder.add_input(index);
}
return builder.build();
}
Tensor index(const Tensor & self, TensorList indices) {
AT_CHECK(indices.size() <= (size_t)self.dim(),
"too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
Tensor src, linearIndex;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
return src.take(linearIndex);
auto info = make_info(self, indices);
auto iter = make_index_iterator(info);
index_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides);
return iter->output();
}
Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value) {
AT_CHECK(indices.size() <= (size_t)self.dim(),
"too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
Tensor src, linearIndex, expandedValue;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
std::tie(expandedValue) = expand_inplace(linearIndex, value);
Tensor dst = src.clone();
return dst.put_(linearIndex, expandedValue);
Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
return self.clone().index_put_(indices, value, accumulate);
}
Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value) {
Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
AT_CHECK(indices.size() <= (size_t)self.dim(),
"too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
Tensor src, linearIndex, expandedValue;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
std::tie(expandedValue) = expand_inplace(linearIndex, value);
return src.put_(linearIndex, expandedValue);
if (accumulate && self.type().device_type() == kCUDA) {
Tensor src, linearIndex, expandedValue;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
std::tie(expandedValue) = expand_inplace(linearIndex, value);
return src.put_(linearIndex, expandedValue, true);
}
auto info = make_info(self, indices);
auto iter = make_index_put_iterator(info, value);
index_put_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides, accumulate);
return self;
}
Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {

View File

@ -0,0 +1,20 @@
#pragma once
// Indexing tensors by by tensors
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
namespace at {
struct TensorIterator;
}
namespace at { namespace native {
using index_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides);
using index_put_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides, bool accumulate);
DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
}} // namespace at::native

View File

@ -97,21 +97,47 @@ compute_result_type(at::ArrayRef<OperandInfo> operands, const F& predicate) {
return std::make_tuple(result_type, backend);
}
static bool needs_cast(const Tensor& tensor, const Type& dst_type) {
if (!tensor.defined() || dst_type == tensor.type()) {
return false;
void TensorIterator::compute_types() {
bool missing_dtypes = false;
for (auto& op : operands_) {
if (!op.tensor.defined() && !op.type) {
missing_dtypes = true;
}
}
if (dst_type.device_type() == DeviceType::CUDA &&
tensor.type().device_type() == DeviceType::CPU &&
tensor.dim() == 0) {
// zero-dim CPU tensors used in CUDA operations can be used directly without
// casting
return false;
if (missing_dtypes || compute_common_dtype_) {
auto& type = compute_common_type();
for (auto& op : operands_) {
if (!op.type) {
op.type = &type;
} else if (compute_common_dtype_ && op.type != &type) {
if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
type.device_type() == kCUDA && op.tensor.type().device_type() == kCPU) {
// don't cast CPU scalars in CUDA ops that directly support them
op.type = &op.tensor.type();
} else {
op.type = &type;
}
}
}
}
for (auto& op : operands_) {
if (op.tensor.defined() && op.tensor.type() != *op.type) {
if (op.is_output) {
AT_ERROR("output with type ", op.tensor.type().toString(),
" doesn't match the desired type ", type().toString());
} else if (op.tensor.dim() == 0) {
op.tensor = op.tensor.toType(*op.type);
} else {
AT_ERROR("expected type ", type().toString(), " but got ",
op.tensor.type().toString());
}
}
}
return true;
}
void TensorIterator::compute_common_type() {
Type& TensorIterator::compute_common_type() {
// See [Result type computation] in TensorIterator.h
auto result_type = ScalarType::Undefined;
auto backend = Backend::Undefined;
@ -132,18 +158,7 @@ void TensorIterator::compute_common_type() {
AT_ASSERT(result_type != ScalarType::Undefined);
AT_ASSERT(backend != Backend::Undefined);
auto& type = at::globalContext().getNonVariableType(backend, result_type);
for (auto& op : operands_) {
if (!op.type) {
op.type = &type;
op.needs_cast = needs_cast(op.tensor, type);
if (op.needs_cast && op.tensor.dim() == 0 && !op.is_output) {
op.tensor = op.tensor.toType(type);
op.needs_cast = false;
}
}
}
return at::globalContext().getNonVariableType(backend, result_type);
}
DimVector TensorIterator::compatible_stride(int element_size) const {
@ -171,6 +186,7 @@ void TensorIterator::allocate_outputs() {
for (int i = 0; i < num_outputs_; i++) {
auto& op = operands_[i];
if (!op.tensor.defined()) {
AT_ASSERTM(op.type, "no type for operand", i);
int element_size = op.type->elementSizeInBytes();
op.stride_bytes = compatible_stride(element_size);
@ -405,7 +421,7 @@ bool TensorIterator::is_scalar(int arg) const {
}
bool TensorIterator::is_cpu_scalar(int arg) const {
return is_scalar(arg) && operands_[arg].tensor.type().backend() == at::Backend::CPU;
return is_scalar(arg) && operands_[arg].tensor.type().device_type() == kCPU;
}
void* TensorIterator::data_ptr(int arg) const {
@ -450,6 +466,7 @@ std::unique_ptr<TensorIterator> TensorIterator::binary_op(Tensor& out, const Ten
builder.add_output(out);
builder.add_input(a);
builder.add_input(b);
builder.iter_->allow_cpu_scalars_ = true;
return builder.build();
}
@ -459,6 +476,7 @@ std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out, const Ten
builder.add_output(out);
builder.add_input(a);
builder.iter_->resize_outputs_ = false;
builder.iter_->is_reduction_ = true;
return builder.build();
}
@ -485,7 +503,7 @@ void TensorIterator::compute_shape() {
// For now, don't include output tensors that are not also input tensors.
// This preserves the legacy behavior where torch.add(..., out=dst) resizes
// the destination tensor.
if (op.is_output && !op.is_read_write) continue;
if (resize_outputs_ && op.is_output && !op.is_read_write) continue;
auto shape = op.tensor.sizes();
if (shape_.empty()) {
@ -501,15 +519,17 @@ void TensorIterator::compute_shape() {
// outputs.
for (int i = 0; i < num_outputs_; i++) {
auto& tensor = operands_[i].tensor;
if (resize_outputs_ && tensor.defined() && !tensor.sizes().equals(shape_)) {
if (!operands_[i].is_read_write) {
if (tensor.defined() && !tensor.sizes().equals(shape_)) {
if (resize_outputs_ && !operands_[i].is_read_write) {
// Preserve legacy resizing behavior of out=... arguments
// TODO: issue warning
tensor.resize_(shape_);
continue;
}
AT_ERROR("output with shape ", tensor.sizes(), " doesn't match the broadcast shape ",
shape_);
if (!is_reduction_) {
AT_ERROR("output with shape ", tensor.sizes(), " doesn't match the broadcast shape ",
shape_);
}
}
}
}
@ -540,15 +560,6 @@ void TensorIterator::compute_strides() {
}
}
void TensorIterator::check_type_conversions() {
for (auto& op : operands_) {
if (op.needs_cast) {
AT_ERROR("TensorIterator expected type ", type().toString(), " but got ",
op.tensor.type().toString());
}
}
}
bool TensorIterator::can_use_32bit_indexing() const {
int64_t max_value = std::numeric_limits<int32_t>::max();
if (numel() > max_value) {
@ -612,7 +623,7 @@ std::unique_ptr<TensorIterator> TensorIterator::Builder::build() {
// re-order dimensions to improve coalescing
iter_->reorder_dimensions();
// compute the result dtype and backend
iter_->compute_common_type();
iter_->compute_types();
// allocate the output tensor if it's not provided
iter_->allocate_outputs();
// coalesce adjacent dimensions when possible
@ -623,8 +634,6 @@ std::unique_ptr<TensorIterator> TensorIterator::Builder::build() {
op.data = op.tensor.data_ptr();
}
iter_->check_type_conversions();
return std::move(iter_);
}

View File

@ -54,7 +54,12 @@ namespace at {
struct CAFFE2_API OperandInfo {
OperandInfo() {}
OperandInfo(const Tensor& t) : tensor(t) {}
OperandInfo(const Tensor& t, const Type* type=nullptr)
: tensor(t), type(const_cast<Type*>(type)) {
if (t.defined() && !type) {
this->type = &t.type();
}
}
/// Stride after broadcasting. The stride is in bytes, not number of elements.
DimVector stride_bytes;
@ -74,9 +79,6 @@ struct CAFFE2_API OperandInfo {
/// iterator is split.
void* data = nullptr;
/// True if the kernel needs to handle a cast operation for this operand.
bool needs_cast = false;
bool is_output = false;
bool is_read_write = false;
@ -210,10 +212,10 @@ protected:
void compute_strides();
void reorder_dimensions();
void permute_dimensions(IntList perm);
void compute_common_type();
void compute_types();
Type& compute_common_type();
void allocate_outputs();
void coalesce_dimensions();
void check_type_conversions();
protected:
DimVector shape_;
@ -223,6 +225,9 @@ protected:
bool has_coalesced_dimensions_ = false;
bool accumulate_ = false;
bool resize_outputs_ = true;
bool is_reduction_ = false;
bool compute_common_dtype_ = true;
bool allow_cpu_scalars_ = false;
};
struct TensorIterator::Builder {
@ -230,15 +235,21 @@ struct TensorIterator::Builder {
Builder() : iter_(new TensorIterator()) {};
Builder& add_output(const Tensor& output) {
iter_->operands_.emplace_back(output);
void add_output(const Tensor& output, const Type* type=nullptr) {
iter_->operands_.emplace_back(output, type);
iter_->num_outputs_++;
return *this;
}
Builder& add_input(const Tensor& input) {
iter_->operands_.emplace_back(input);
return *this;
void add_input(const Tensor& input, const Type* type=nullptr) {
iter_->operands_.emplace_back(input, type);
}
void dont_compute_common_dtype() {
iter_->compute_common_dtype_ = false;
}
void dont_resize_outputs() {
iter_->resize_outputs_ = false;
}
std::unique_ptr<TensorIterator> build();

View File

@ -0,0 +1,125 @@
#include <ATen/native/Indexing.h>
#include <cmath>
#include <iostream>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec256/vec256.h>
namespace at { namespace native {
namespace {
using namespace vec256;
struct Indexer {
Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
IntList original_sizes, IntList original_strides)
: num_indexers(num_indexers)
, indexers(indexers)
, indexer_strides(indexer_strides)
, original_strides(original_strides.data())
, original_sizes(original_sizes.data()) {
AT_ASSERT(original_strides.size() == num_indexers);
AT_ASSERT(original_sizes.size() == num_indexers);
}
int64_t num_indexers;
char** indexers;
const int64_t* indexer_strides;
const int64_t* original_strides;
const int64_t* original_sizes;
int64_t get(int64_t idx) {
int64_t offset = 0;
for (int j = 0; j < num_indexers; j++) {
int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
int64_t size = original_sizes[j];
if (value < -size || value >= size) {
AT_ERROR("index ", value, " is out of bounds for dim with size ", size);
}
if (value < 0) {
value += size;
}
offset += value * original_strides[j];
}
return offset;
}
};
static bool is_constant_index(int ntensor, const int64_t* strides) {
AT_ASSERT(ntensor >= 3);
for (int arg = 2; arg < ntensor; arg++) {
if (strides[arg] != 0) {
return false;
}
}
return true;
}
template <typename scalar_t, typename func_t>
void cpu_index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride,
const func_t& f, bool serial_execution=false)
{
auto loop = [&](int ntensor, char** data, const int64_t* strides, int64_t n) {
auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
char* dst = data[0];
char* src = data[1];
if (is_constant_index(ntensor, strides)) {
// specialization for when every element uses the same index
int64_t offset = indexer.get(0);
if (strides[0] == sizeof(scalar_t) && strides[1] == sizeof(scalar_t)) {
for (int64_t i = 0; i < n; i++) {
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
} else {
for (int64_t i = 0; i < n; i++) {
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
}
} else {
for (int64_t i = 0; i < n; i++) {
int64_t offset = indexer.get(i);
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
}
};
if (serial_execution) {
iter.serial_for_each(loop, {0, iter.numel()});
} else {
iter.for_each(loop);
}
}
void index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride) {
AT_DISPATCH_ALL_TYPES(iter.type(0), "index", [&] {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)dst = *(scalar_t*)(src + offset);
});
});
}
void index_put_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, bool accumulate) {
// NOTE: duplicate indices are only supported if accumulate is true.
AT_DISPATCH_ALL_TYPES(iter.type(0), "index_put", [&] {
if (accumulate) {
// TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
// this needs to be thread-safe.
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset) += *(scalar_t*)src;
}, /*serial_execution=*/true);
} else {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset) = *(scalar_t*)src;
});
}
});
}
} // anonymous namespace
REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
}} // namespace at::native

View File

@ -0,0 +1,102 @@
#include <ATen/native/Indexing.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/Array.h>
namespace at { namespace native {
template <int N>
static OffsetCalculator<N> index_make_offset_calculator(const TensorIterator& iter) {
AT_ASSERT(N <= iter.ntensors());
std::array<const int64_t*, N> strides;
for (int i = 0; i < N; i++) {
strides[i] = iter.strides(i).data();
}
return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data());
}
template <typename func_t>
void gpu_index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, const func_t& f) {
int num_indices = index_size.size();
AT_ASSERT(num_indices == index_stride.size());
AT_ASSERT(num_indices == iter.ntensors() - 2);
if (iter.numel() == 0) {
return;
}
auto sizes = cuda::Array<int64_t, 25>(0);
auto strides = cuda::Array<int64_t, 25>(0);
auto index_ptrs = cuda::Array<char*, 25>(nullptr);
for (int i = 0; i < num_indices; i++) {
sizes[i] = index_size[i];
strides[i] = index_stride[i];
index_ptrs[i] = (char*)iter.data_ptr(i + 2);
}
char* out_ptr = (char*)iter.data_ptr(0);
char* in_ptr = (char*)iter.data_ptr(1);
auto offset_calc = index_make_offset_calculator<3>(iter);
launch_kernel<128, 4>(iter.numel(), [=]__device__(int idx) {
auto offsets = offset_calc.get(idx);
char* out_data = out_ptr + offsets[0];
char* in_data = in_ptr + offsets[1];
int64_t offset = 0;
#pragma unroll
for (int i = 0; i < num_indices; i++) {
int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
assert(index >= -sizes[i] && index < sizes[i] && "index out of bounds");
if (index < 0) {
index += sizes[i];
}
offset += index * strides[i];
}
f(out_data, in_data, offset);
});
}
// The kernels are templated on an opaque, self-aligned type of the correct
// size to avoid redundant kernels for different types of the same size.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <typename scalar_t>
void index_kernel_impl(TensorIterator& iter, IntList index_size, IntList index_stride) {
gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* out_data, char* in_data, int64_t offset) {
*(scalar_t*)out_data = *(scalar_t*)(in_data + offset);
});
}
template <typename scalar_t>
void index_put_kernel_impl(TensorIterator& iter, IntList index_size, IntList index_stride) {
gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* out_data, char* in_data, int64_t offset) {
*(scalar_t*)(out_data + offset) = *(scalar_t*)in_data;
});
}
static void index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride) {
AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_kernel_impl<dtype>(iter, index_size, index_stride);
});
}
static void index_put_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, bool accumulate) {
AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index_put", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_put_kernel_impl<dtype>(iter, index_size, index_stride);
});
}
REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
}} // namespace at::native

View File

@ -888,10 +888,10 @@
- func: index_copy_(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
variants: method
- func: index_put(Tensor self, TensorList indices, Tensor values) -> Tensor
- func: index_put(Tensor self, TensorList indices, Tensor values, bool accumulate=false) -> Tensor
variants: function, method
- func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor
- func: index_put_(Tensor self, TensorList indices, Tensor values, bool accumulate=false) -> Tensor
variants: function, method
- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) -> Tensor

View File

@ -7,6 +7,7 @@ graph(%target : Double(100)
%6 : bool = prim::Constant[value=0]()
%indices : Long(4) = aten::_cast_Long(%indices.1, %6)
%8 : Dynamic[] = prim::ListConstruct(%indices)
%9 : Double(100) = aten::index_put_(%target, %8, %5)
return (%9);
%9 : bool = prim::Constant[value=0]()
%10 : Double(100) = aten::index_put_(%target, %8, %5, %9)
return (%10);
}

View File

@ -4,6 +4,7 @@ graph(%target : Double(100)
%3 : bool = prim::Constant[value=0]()
%indices : Long(4) = aten::_cast_Long(%indices.1, %3)
%5 : Dynamic[] = prim::ListConstruct(%indices)
%6 : Double(100) = aten::index_put_(%target, %5, %rhs)
return (%6);
%6 : bool = prim::Constant[value=0]()
%7 : Double(100) = aten::index_put_(%target, %5, %rhs, %6)
return (%7);
}

View File

@ -28,6 +28,7 @@ TESTS = [
'distributions',
'expecttest',
'indexing',
'indexing_cuda',
'jit',
'multiprocessing',
'multiprocessing_spawn',

View File

@ -926,7 +926,7 @@ class TestCuda(TestCase):
self.assertEqual(x * y, 4.5)
self.assertEqual(y * x, 4.5)
with self.assertRaisesRegex(RuntimeError, 'expected type'):
with self.assertRaisesRegex(RuntimeError, "doesn't match the desired type"):
y *= x
x *= y
self.assertEqual(x, 4.5)

View File

@ -2,6 +2,7 @@ from common_utils import TestCase, run_tests
import torch
import warnings
from torch import tensor
import unittest
class TestIndexing(TestCase):
@ -448,9 +449,9 @@ class NumpyTests(TestCase):
def f(a, v):
a[a > -1] = tensor(v)
self.assertRaisesRegex(Exception, "expand", f, a, [])
self.assertRaisesRegex(Exception, 'expand', f, a, [1, 2, 3])
self.assertRaisesRegex(Exception, 'expand', f, a[:1], [1, 2, 3])
self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [])
self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [1, 2, 3])
self.assertRaisesRegex(Exception, 'shape mismatch', f, a[:1], [1, 2, 3])
def test_boolean_indexing_twodim(self):
# Indexing a 2-dimensional array with
@ -503,12 +504,14 @@ class NumpyTests(TestCase):
def test_broaderrors_indexing(self):
a = torch.zeros(5, 5)
self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2]))
self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
def test_trivial_fancy_out_of_bounds(self):
a = torch.zeros(5)
ind = torch.ones(20, dtype=torch.int64)
if a.is_cuda:
raise unittest.SkipTest('CUDA asserts instead of raising an exception')
ind[-1] = 10
self.assertRaises(RuntimeError, a.__getitem__, ind)
self.assertRaises(RuntimeError, a.__setitem__, ind, 0)

View File

@ -0,0 +1,10 @@
import torch
from test_indexing import *
if __name__ == '__main__':
if torch.cuda.is_available():
torch.set_default_tensor_type(torch.cuda.FloatTensor)
run_tests()
else:
print("Skipping test_indexing_cuda.py")

View File

@ -360,6 +360,10 @@
- name: histc(Tensor self, int64_t bins, Scalar min, Scalar max)
self: not_implemented("histc")
- name: index(Tensor self, TensorList indices)
self: zeros_like(self).index_put_(indices, grad, true)
indices: TensorList()
- name: index_add_(Tensor self, int64_t dim, Tensor index, Tensor source)
self: grad
source: grad.index_select(dim, index)
@ -375,6 +379,10 @@
self: grad.clone().index_fill_(dim, index, 0)
value: grad.index_select(dim, index).sum()
- name: index_put_(Tensor self, TensorList indices, Tensor values, bool accumulate)
self: grad.clone().index_put_(indices, zeros_like(values), accumulate)
values: grad.index(indices)
- name: index_select(Tensor self, int64_t dim, Tensor index)
self: at::zeros(self.sizes(), grad.options()).index_add_(dim, index, grad)

View File

@ -222,8 +222,7 @@ std::vector<at::Tensor> VariableType::unpack(at::TensorList tl, const char *name
for (size_t i = 0; i < tl.size(); ++i) {
const auto &t = tl[i];
if (!t.defined()) {
AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor at position #", i, " "
"for iterable argument #", pos, " '", name, "'");
continue;
}
if (!isVariableType(t.type())) {
AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " at position #", i, " "

View File

@ -135,6 +135,12 @@ inline void check_no_requires_grad(const Tensor& tensor, const char* name) {
}
}
inline void check_no_requires_grad(TensorList tensors, const char* name) {
for (auto& tensor : tensors) {
check_no_requires_grad(tensor, name);
}
}
// Assumed that saved tensor lists are never inplace outputs
inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {

View File

@ -811,9 +811,9 @@ def index_select(g, self, dim, index):
return g.op("Gather", self, index, axis_i=dim)
def index_put(g, self, indices_list_value, values):
def index_put(g, self, indices_list_value, values, accumulate):
indices_list = _unpack_list(indices_list_value)
args = [self] + indices_list + [values]
args = [self] + indices_list + [values, accumulate]
return g.op("ATen", *args, operator_s='index_put')