Speed-up "advanced" indexing operations (#13420)
Summary: This speeds up "advanced" indexing (indexing a tensor by a tensor) on CPU and GPU. There's still a bunch of work to do, including speeding up indexing by a byte (boolean) mask and speeding up the derivative calculation for advanced indexing.

Here are some speed comparisons against indexing on master, using a little [benchmark script](https://gist.github.com/colesbury/c369db72aad594e5e032c8fda557d909) with 16 OpenMP threads and on a P100. The test cases are listed as (input shape -> output shape).

| Test case | CPU (old vs. new) | CUDA (old vs. new) |
|-----------------------|-----------------------|----------------------|
| 1024x1024 -> 512x1024 | 225 us vs. **57 us**  | 297 us vs. **47 us** |
| 1024x1024 -> 1024x512 | 208 us vs. **153 us** | 335 us vs. **54 us** |
| 50x50 -> 20000x50     | 617 us vs. **77 us**  | 239 us vs. **54 us** |
| 50x50 -> 50x20000     | 575 us vs. **236 us** | 262 us vs. **58 us** |
| 2x5x10 -> 10          | 65 us vs. **18 us**   | 612 us vs. **93 us** |

See #11647

Pull Request resolved: https://github.com/pytorch/pytorch/pull/13420
Reviewed By: soumith
Differential Revision: D13088936
Pulled By: colesbury
fbshipit-source-id: 0a5c2ee9aa54e15f96d06692d1694c3b24b924e2
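For context (not part of the commit, and not the actual gist script — shapes and names are illustrative), a minimal sketch of the kind of operation being benchmarked:

```python
import torch

# Index a 1024x1024 tensor with a LongTensor of row indices
# (the "1024x1024 -> 512x1024" case from the table above).
x = torch.randn(1024, 1024)
rows = torch.randint(0, 1024, (512,))
out = x[rows]  # advanced indexing; dispatches to at::native::index

# index_put_ with accumulate=True adds values at (possibly repeated) indices.
dst = torch.zeros(1024, 1024)
dst.index_put_((rows,), out, accumulate=True)
```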
Committed by: Facebook Github Bot
Parent: 0199d59d3a
Commit: 006505bb8f
@@ -356,8 +356,8 @@ public:
Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntList signal_sizes={}) const;
Tensor index(TensorList indices) const;
Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
Tensor index_put(TensorList indices, const Tensor & values) const;
Tensor & index_put_(TensorList indices, const Tensor & values);
Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
Tensor inverse() const;
Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
bool is_distributed() const;
@@ -318,11 +318,11 @@ inline Tensor Tensor::index(TensorList indices) const {
inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
return type().index_copy_(*this, dim, index, source);
}
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values) const {
return type().index_put(*this, indices, values);
inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
return type().index_put(*this, indices, values, accumulate);
}
inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values) {
return type().index_put_(*this, indices, values);
inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
return type().index_put_(*this, indices, values, accumulate);
}
inline Tensor Tensor::inverse() const {
return type().inverse(*this);
@@ -264,8 +264,8 @@ struct CAFFE2_API Type {
virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntList signal_sizes) const = 0;
virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values) const = 0;
virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values) const = 0;
virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
virtual Tensor inverse(const Tensor & self) const = 0;
virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
virtual bool is_distributed(const Tensor & self) const = 0;
@@ -17,6 +17,9 @@ struct OffsetCalculator {
using offset_type = at::cuda::Array<uint32_t, NARGS>;

OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
if (dims > MAX_DIMS) {
throw std::runtime_error("tensor has too many (>25) dims");
}
for (int i = 0; i < MAX_DIMS; ++i) {
if (i < dims) {
sizes_[i] = IntDivider<uint32_t>(sizes[i]);
@@ -3,7 +3,7 @@
// This corresponds to "advanced indexing" in NumPy. The two operations are:
//
// index(Tensor self, indices) -> Tensor
// index_put_(Tensor self, indices, value)
// index_put_(Tensor self, indices, value, accumulate=false)
//
// The index is a TensorList containg kLong or kByte tensors or nulls. Byte
// tensors (boolean masks) are expanded to long tensors via nonzero(). Null
@@ -20,11 +20,40 @@
// Note 2: The behavior is more complicated when the index tensors are not all
// adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
// tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
//
// The code contains two implementations of indexing. The more efficient
// implementation treats indexing like an elementwise operation over the
// tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does
// not work for index_put_ with accumulate=True. The other implementation
// combines the indexed tensors into a single linear index that is used
// with Tensor.put_. This is used for index_put_ with accumulate=True.
//
// The more efficient implementation takes the following steps for the
// above operation:
//
// 1) Broadcast ind_1, ind_2, ind_3 together to a common shape
// 2) Record x.stride(i) for each indexed dimension `i`
// 3) Replace the indexed subspace of `x` with the shape of the corresponding
//    subspace of `result` but with stride 0
// 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so
//    that their shape is compatible with the result shape
//
// The CPU or CUDA kernel then computes element-wise over the broadcasted
// and restrided result, x, ind_1, ind_2, etc.:
//
//   result[...] = *(&x[...] +
//                   ind_1[...] * x.stride(1) +
//                   ind_2[...] * x.stride(2) +
//                   ...)
//
// where & and * represent the C-style address-of and indirection operations.

#include <ATen/native/Indexing.h>

#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/ExpandUtils.h"
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>

#include <algorithm>
#include <functional>
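Not part of the diff — a rough Python sketch (with the illustrative names `x`, `ind_1`, `ind_2` from the comment above) of what the elementwise formulation computes when the indexed dimensions are adjacent:

```python
import torch

x = torch.randn(4, 5, 6)
ind_1 = torch.tensor([[0, 1], [2, 3]])   # indexes dim 1
ind_2 = torch.tensor([[5, 4], [3, 2]])   # indexes dim 2 (broadcasts with ind_1)

# What x[:, ind_1, ind_2] computes, written out elementwise: the result keeps
# dim 0 of x and replaces the indexed subspace with the broadcast index shape.
result = torch.empty(4, 2, 2)
for i in range(4):
    for a in range(2):
        for b in range(2):
            result[i, a, b] = x[i, ind_1[a, b], ind_2[a, b]]

assert torch.equal(result, x[:, ind_1, ind_2])
```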
@@ -33,6 +62,9 @@

namespace at { namespace native {

DEFINE_DISPATCH(index_stub);
DEFINE_DISPATCH(index_put_stub);

[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
std::stringstream ss;
@@ -226,34 +258,188 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig)
return std::make_tuple(self, linearIndex);
}

static bool all_strides_match(TensorList tensors) {
AT_ASSERT(tensors.size() >= 1);
auto strides = tensors[0].strides();
for (auto& tensor : tensors.slice(1)) {
if (!strides.equals(tensor.strides())) {
return false;
}
}
return true;
}

static std::string shapes_as_str(TensorList tensors) {
std::ostringstream os;
bool first = true;
for (auto& tensor : tensors) {
if (tensor.defined()) {
if (!first) {
os << ", ";
}
os << tensor.sizes();
first = false;
}
}
return os.str();
}

struct AdvancedIndex {
AdvancedIndex(const Tensor& src, TensorList indices);

Tensor src;
std::vector<Tensor> indices;
DimVector indexed_sizes;
DimVector indexed_strides;
int64_t dims_before;
int64_t dims_after;
};

// Replace indexed dimensions in src with stride 0 and the size of the result tensor.
// The offset in these dimensions is computed by the kernel using the index tensor's
// values and the stride of src. The new shape is not meaningful. It's used to make
// the shape compatible with the result tensor.
static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed,
IntList replacement_shape) {
auto shape = DimVector(src.sizes());
auto strides = DimVector(src.strides());
int end = dims_before + dims_indexed;
shape.erase(shape.begin() + dims_before, shape.begin() + end);
strides.erase(strides.begin() + dims_before, strides.begin() + end);
shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0);
return src.as_strided(shape, strides);
}

// Add dimensions of size 1 to an index tensor so that it's can be broadcast to the result
// shape and iterated over element-wise like the result tensor and the restrided src.
static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) {
auto orig_shape = index.sizes();
auto shape = DimVector();
shape.append(dims_before, 1);
shape.append(orig_shape.begin(), orig_shape.end());
shape.append(dims_after, 1);
return index.reshape(shape);
}

AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
{
int64_t element_size_bytes = src.type().elementSizeInBytes();
int dims_before = 0, dims_after = 0, dims_indexed = 0;
IntList replacement_shape;
for (size_t dim = 0; dim < indices_list.size(); dim++) {
if (!indices_list[dim].defined()) {
if (dims_indexed == 0) {
dims_before++;
} else {
dims_after++;
}
} else {
dims_indexed++;
replacement_shape = indices_list[dim].sizes();
indexed_sizes.push_back(src.size(dim));
indexed_strides.push_back(src.stride(dim) * element_size_bytes);
}
}

this->dims_before = dims_before;
this->dims_after = dims_after;
this->src = restride_src(src, dims_before, dims_indexed, replacement_shape);

for (auto& index : indices_list) {
if (index.defined()) {
indices.push_back(reshape_indexer(index, dims_before, dims_after));
}
}

// For CUDA tensors, force all index tensors to have the same striding to
// simplify the CUDA kernel.
if (indices.size() >= 2 && this->src.type().device_type() == kCUDA) {
if (!all_strides_match(indices)) {
for (size_t i = 0; i < indices.size(); i++) {
indices[i] = indices[i].contiguous();
}
}
}
}

static AdvancedIndex make_info(Tensor self, TensorList orig) {
checkIndexTensorTypes(orig);
// first expand ByteTensor (boolean masks) into 1 or more LongTensors
auto indices = expandByteTensors(self, orig);
// next broadcast all index tensors together
try {
indices = expand_outplace(indices);
} catch (std::exception& e) {
AT_ERROR("shape mismatch: indexing tensors could not be broadcast together"
" with shapes ", shapes_as_str(indices));
}
// add missing null Tensors so that it matches self.dim()
while (indices.size() < (size_t)self.dim()) {
indices.emplace_back();
}
// if the non-null indices are not all adjacent, transpose self and indices
// together so that they're adjacent at the front
if (!hasContiguousSubspace(indices)) {
std::tie(self, indices) = transposeToFront(self, indices);
}
return AdvancedIndex(self, indices);
}

static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) {
auto builder = TensorIterator::Builder();
builder.dont_compute_common_dtype();
builder.add_output(Tensor(), &info.src.type());
builder.add_input(info.src);
for (auto& index : info.indices) {
builder.add_input(index);
}
return builder.build();
}

static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
if (!is_expandable_to(value.sizes(), info.src.sizes())) {
AT_ERROR("shape mismatch: value tensor of shape ", value.sizes(),
" cannot be broadcast to indexing result of shape ", info.src.sizes());
}
auto builder = TensorIterator::Builder();
builder.dont_compute_common_dtype();
builder.dont_resize_outputs();
builder.add_output(info.src);
builder.add_input(value, &info.src.type());
for (auto& index : info.indices) {
builder.add_input(index);
}
return builder.build();
}

Tensor index(const Tensor & self, TensorList indices) {
AT_CHECK(indices.size() <= (size_t)self.dim(),
"too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");

Tensor src, linearIndex;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
return src.take(linearIndex);
auto info = make_info(self, indices);
auto iter = make_index_iterator(info);
index_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides);
return iter->output();
}

Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value) {
AT_CHECK(indices.size() <= (size_t)self.dim(),
"too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");

Tensor src, linearIndex, expandedValue;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
std::tie(expandedValue) = expand_inplace(linearIndex, value);
Tensor dst = src.clone();
return dst.put_(linearIndex, expandedValue);
Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
return self.clone().index_put_(indices, value, accumulate);
}

Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value) {
Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
AT_CHECK(indices.size() <= (size_t)self.dim(),
"too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");

Tensor src, linearIndex, expandedValue;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
std::tie(expandedValue) = expand_inplace(linearIndex, value);
return src.put_(linearIndex, expandedValue);
if (accumulate && self.type().device_type() == kCUDA) {
Tensor src, linearIndex, expandedValue;
std::tie(src, linearIndex) = makeLinearIndex(self, indices);
std::tie(expandedValue) = expand_inplace(linearIndex, value);
return src.put_(linearIndex, expandedValue, true);
}
auto info = make_info(self, indices);
auto iter = make_index_put_iterator(info, value);
index_put_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides, accumulate);
return self;
}

Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
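As a user-level illustration (not part of the diff): with accumulate=True, values written to repeated indices are summed rather than overwritten; on CUDA this goes through the linear-index/put_ path above, while the CPU path uses a serial kernel. A small sketch:

```python
import torch

x = torch.zeros(5)
idx = torch.tensor([0, 0, 1])
vals = torch.tensor([1.0, 1.0, 3.0])

x.index_put_((idx,), vals, accumulate=True)
print(x)  # tensor([2., 3., 0., 0., 0.])

# Without accumulate (the default), writes to duplicate indices are not
# guaranteed to add up; one of the writes simply wins.
y = torch.zeros(5)
y.index_put_((idx,), vals)
```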

aten/src/ATen/native/Indexing.h (new file, 20 lines)
@@ -0,0 +1,20 @@
#pragma once

// Indexing tensors by by tensors

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at {
struct TensorIterator;
}

namespace at { namespace native {

using index_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides);
using index_put_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides, bool accumulate);

DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);

}} // namespace at::native
@@ -97,21 +97,47 @@ compute_result_type(at::ArrayRef<OperandInfo> operands, const F& predicate) {
return std::make_tuple(result_type, backend);
}

static bool needs_cast(const Tensor& tensor, const Type& dst_type) {
if (!tensor.defined() || dst_type == tensor.type()) {
return false;
void TensorIterator::compute_types() {
bool missing_dtypes = false;
for (auto& op : operands_) {
if (!op.tensor.defined() && !op.type) {
missing_dtypes = true;
}
}
if (dst_type.device_type() == DeviceType::CUDA &&
tensor.type().device_type() == DeviceType::CPU &&
tensor.dim() == 0) {
// zero-dim CPU tensors used in CUDA operations can be used directly without
// casting
return false;

if (missing_dtypes || compute_common_dtype_) {
auto& type = compute_common_type();
for (auto& op : operands_) {
if (!op.type) {
op.type = &type;
} else if (compute_common_dtype_ && op.type != &type) {
if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
type.device_type() == kCUDA && op.tensor.type().device_type() == kCPU) {
// don't cast CPU scalars in CUDA ops that directly support them
op.type = &op.tensor.type();
} else {
op.type = &type;
}
}
}
}

for (auto& op : operands_) {
if (op.tensor.defined() && op.tensor.type() != *op.type) {
if (op.is_output) {
AT_ERROR("output with type ", op.tensor.type().toString(),
" doesn't match the desired type ", type().toString());
} else if (op.tensor.dim() == 0) {
op.tensor = op.tensor.toType(*op.type);
} else {
AT_ERROR("expected type ", type().toString(), " but got ",
op.tensor.type().toString());
}
}
}
return true;
}

void TensorIterator::compute_common_type() {
Type& TensorIterator::compute_common_type() {
// See [Result type computation] in TensorIterator.h
auto result_type = ScalarType::Undefined;
auto backend = Backend::Undefined;
@@ -132,18 +158,7 @@ void TensorIterator::compute_common_type() {
AT_ASSERT(result_type != ScalarType::Undefined);
AT_ASSERT(backend != Backend::Undefined);

auto& type = at::globalContext().getNonVariableType(backend, result_type);

for (auto& op : operands_) {
if (!op.type) {
op.type = &type;
op.needs_cast = needs_cast(op.tensor, type);
if (op.needs_cast && op.tensor.dim() == 0 && !op.is_output) {
op.tensor = op.tensor.toType(type);
op.needs_cast = false;
}
}
}
return at::globalContext().getNonVariableType(backend, result_type);
}

DimVector TensorIterator::compatible_stride(int element_size) const {
@@ -171,6 +186,7 @@ void TensorIterator::allocate_outputs() {
for (int i = 0; i < num_outputs_; i++) {
auto& op = operands_[i];
if (!op.tensor.defined()) {
AT_ASSERTM(op.type, "no type for operand", i);
int element_size = op.type->elementSizeInBytes();
op.stride_bytes = compatible_stride(element_size);

@@ -405,7 +421,7 @@ bool TensorIterator::is_scalar(int arg) const {
}

bool TensorIterator::is_cpu_scalar(int arg) const {
return is_scalar(arg) && operands_[arg].tensor.type().backend() == at::Backend::CPU;
return is_scalar(arg) && operands_[arg].tensor.type().device_type() == kCPU;
}

void* TensorIterator::data_ptr(int arg) const {

@@ -450,6 +466,7 @@ std::unique_ptr<TensorIterator> TensorIterator::binary_op(Tensor& out, const Ten
builder.add_output(out);
builder.add_input(a);
builder.add_input(b);
builder.iter_->allow_cpu_scalars_ = true;
return builder.build();
}

@@ -459,6 +476,7 @@ std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out, const Ten
builder.add_output(out);
builder.add_input(a);
builder.iter_->resize_outputs_ = false;
builder.iter_->is_reduction_ = true;
return builder.build();
}

@@ -485,7 +503,7 @@ void TensorIterator::compute_shape() {
// For now, don't include output tensors that are not also input tensors.
// This preserves the legacy behavior where torch.add(..., out=dst) resizes
// the destination tensor.
if (op.is_output && !op.is_read_write) continue;
if (resize_outputs_ && op.is_output && !op.is_read_write) continue;

auto shape = op.tensor.sizes();
if (shape_.empty()) {

@@ -501,15 +519,17 @@ void TensorIterator::compute_shape() {
// outputs.
for (int i = 0; i < num_outputs_; i++) {
auto& tensor = operands_[i].tensor;
if (resize_outputs_ && tensor.defined() && !tensor.sizes().equals(shape_)) {
if (!operands_[i].is_read_write) {
if (tensor.defined() && !tensor.sizes().equals(shape_)) {
if (resize_outputs_ && !operands_[i].is_read_write) {
// Preserve legacy resizing behavior of out=... arguments
// TODO: issue warning
tensor.resize_(shape_);
continue;
}
AT_ERROR("output with shape ", tensor.sizes(), " doesn't match the broadcast shape ",
shape_);
if (!is_reduction_) {
AT_ERROR("output with shape ", tensor.sizes(), " doesn't match the broadcast shape ",
shape_);
}
}
}
}

@@ -540,15 +560,6 @@ void TensorIterator::compute_strides() {
}
}

void TensorIterator::check_type_conversions() {
for (auto& op : operands_) {
if (op.needs_cast) {
AT_ERROR("TensorIterator expected type ", type().toString(), " but got ",
op.tensor.type().toString());
}
}
}

bool TensorIterator::can_use_32bit_indexing() const {
int64_t max_value = std::numeric_limits<int32_t>::max();
if (numel() > max_value) {

@@ -612,7 +623,7 @@ std::unique_ptr<TensorIterator> TensorIterator::Builder::build() {
// re-order dimensions to improve coalescing
iter_->reorder_dimensions();
// compute the result dtype and backend
iter_->compute_common_type();
iter_->compute_types();
// allocate the output tensor if it's not provided
iter_->allocate_outputs();
// coalesce adjacent dimensions when possible

@@ -623,8 +634,6 @@ std::unique_ptr<TensorIterator> TensorIterator::Builder::build() {
op.data = op.tensor.data_ptr();
}

iter_->check_type_conversions();

return std::move(iter_);
}
@@ -54,7 +54,12 @@ namespace at {

struct CAFFE2_API OperandInfo {
OperandInfo() {}
OperandInfo(const Tensor& t) : tensor(t) {}
OperandInfo(const Tensor& t, const Type* type=nullptr)
: tensor(t), type(const_cast<Type*>(type)) {
if (t.defined() && !type) {
this->type = &t.type();
}
}

/// Stride after broadcasting. The stride is in bytes, not number of elements.
DimVector stride_bytes;

@@ -74,9 +79,6 @@ struct CAFFE2_API OperandInfo {
/// iterator is split.
void* data = nullptr;

/// True if the kernel needs to handle a cast operation for this operand.
bool needs_cast = false;

bool is_output = false;

bool is_read_write = false;

@@ -210,10 +212,10 @@ protected:
void compute_strides();
void reorder_dimensions();
void permute_dimensions(IntList perm);
void compute_common_type();
void compute_types();
Type& compute_common_type();
void allocate_outputs();
void coalesce_dimensions();
void check_type_conversions();

protected:
DimVector shape_;

@@ -223,6 +225,9 @@ protected:
bool has_coalesced_dimensions_ = false;
bool accumulate_ = false;
bool resize_outputs_ = true;
bool is_reduction_ = false;
bool compute_common_dtype_ = true;
bool allow_cpu_scalars_ = false;
};

struct TensorIterator::Builder {

@@ -230,15 +235,21 @@ struct TensorIterator::Builder {

Builder() : iter_(new TensorIterator()) {};

Builder& add_output(const Tensor& output) {
iter_->operands_.emplace_back(output);
void add_output(const Tensor& output, const Type* type=nullptr) {
iter_->operands_.emplace_back(output, type);
iter_->num_outputs_++;
return *this;
}

Builder& add_input(const Tensor& input) {
iter_->operands_.emplace_back(input);
return *this;
void add_input(const Tensor& input, const Type* type=nullptr) {
iter_->operands_.emplace_back(input, type);
}

void dont_compute_common_dtype() {
iter_->compute_common_dtype_ = false;
}

void dont_resize_outputs() {
iter_->resize_outputs_ = false;
}

std::unique_ptr<TensorIterator> build();

aten/src/ATen/native/cpu/IndexKernel.cpp (new file, 125 lines)
@@ -0,0 +1,125 @@
#include <ATen/native/Indexing.h>

#include <cmath>
#include <iostream>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec256/vec256.h>

namespace at { namespace native {
namespace {

using namespace vec256;

struct Indexer {
Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
IntList original_sizes, IntList original_strides)
: num_indexers(num_indexers)
, indexers(indexers)
, indexer_strides(indexer_strides)
, original_strides(original_strides.data())
, original_sizes(original_sizes.data()) {
AT_ASSERT(original_strides.size() == num_indexers);
AT_ASSERT(original_sizes.size() == num_indexers);
}

int64_t num_indexers;
char** indexers;
const int64_t* indexer_strides;
const int64_t* original_strides;
const int64_t* original_sizes;

int64_t get(int64_t idx) {
int64_t offset = 0;
for (int j = 0; j < num_indexers; j++) {
int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
int64_t size = original_sizes[j];
if (value < -size || value >= size) {
AT_ERROR("index ", value, " is out of bounds for dim with size ", size);
}
if (value < 0) {
value += size;
}
offset += value * original_strides[j];
}
return offset;
}
};

static bool is_constant_index(int ntensor, const int64_t* strides) {
AT_ASSERT(ntensor >= 3);
for (int arg = 2; arg < ntensor; arg++) {
if (strides[arg] != 0) {
return false;
}
}
return true;
}

template <typename scalar_t, typename func_t>
void cpu_index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride,
const func_t& f, bool serial_execution=false)
{
auto loop = [&](int ntensor, char** data, const int64_t* strides, int64_t n) {
auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
char* dst = data[0];
char* src = data[1];
if (is_constant_index(ntensor, strides)) {
// specialization for when every element uses the same index
int64_t offset = indexer.get(0);
if (strides[0] == sizeof(scalar_t) && strides[1] == sizeof(scalar_t)) {
for (int64_t i = 0; i < n; i++) {
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
} else {
for (int64_t i = 0; i < n; i++) {
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
}
} else {
for (int64_t i = 0; i < n; i++) {
int64_t offset = indexer.get(i);
f(dst + strides[0] * i, src + strides[1] * i, offset);
}
}
};
if (serial_execution) {
iter.serial_for_each(loop, {0, iter.numel()});
} else {
iter.for_each(loop);
}
}

void index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride) {
AT_DISPATCH_ALL_TYPES(iter.type(0), "index", [&] {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)dst = *(scalar_t*)(src + offset);
});
});
}

void index_put_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, bool accumulate) {
// NOTE: duplicate indices are only supported if accumulate is true.
AT_DISPATCH_ALL_TYPES(iter.type(0), "index_put", [&] {
if (accumulate) {
// TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
// this needs to be thread-safe.
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset) += *(scalar_t*)src;
}, /*serial_execution=*/true);
} else {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset) = *(scalar_t*)src;
});
}
});
}

} // anonymous namespace

REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);

}} // namespace at::native
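Not part of the diff — a rough Python rendering of the per-element offset computation that Indexer::get performs; `gather_offset` is a hypothetical helper, and strides here are in elements rather than the bytes the C++ kernel uses:

```python
import torch

def gather_offset(index_values, sizes, strides):
    # Mirrors Indexer::get: sum index * stride over the indexed dims,
    # wrapping negative indices and checking bounds.
    offset = 0
    for value, size, stride in zip(index_values, sizes, strides):
        if value < -size or value >= size:
            raise IndexError(f"index {value} is out of bounds for dim with size {size}")
        if value < 0:
            value += size
        offset += value * stride
    return offset

x = torch.arange(12).reshape(3, 4)
flat = x.reshape(-1)
# Element x[2, -1] via the manual offset computation:
off = gather_offset([2, -1], sizes=x.shape, strides=x.stride())
assert flat[off] == x[2, -1]
```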

aten/src/ATen/native/cuda/IndexKernel.cu (new file, 102 lines)
@@ -0,0 +1,102 @@
#include <ATen/native/Indexing.h>

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/Array.h>

namespace at { namespace native {

template <int N>
static OffsetCalculator<N> index_make_offset_calculator(const TensorIterator& iter) {
AT_ASSERT(N <= iter.ntensors());
std::array<const int64_t*, N> strides;
for (int i = 0; i < N; i++) {
strides[i] = iter.strides(i).data();
}
return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data());
}

template <typename func_t>
void gpu_index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, const func_t& f) {
int num_indices = index_size.size();
AT_ASSERT(num_indices == index_stride.size());
AT_ASSERT(num_indices == iter.ntensors() - 2);

if (iter.numel() == 0) {
return;
}

auto sizes = cuda::Array<int64_t, 25>(0);
auto strides = cuda::Array<int64_t, 25>(0);
auto index_ptrs = cuda::Array<char*, 25>(nullptr);
for (int i = 0; i < num_indices; i++) {
sizes[i] = index_size[i];
strides[i] = index_stride[i];
index_ptrs[i] = (char*)iter.data_ptr(i + 2);
}

char* out_ptr = (char*)iter.data_ptr(0);
char* in_ptr = (char*)iter.data_ptr(1);

auto offset_calc = index_make_offset_calculator<3>(iter);
launch_kernel<128, 4>(iter.numel(), [=]__device__(int idx) {
auto offsets = offset_calc.get(idx);
char* out_data = out_ptr + offsets[0];
char* in_data = in_ptr + offsets[1];

int64_t offset = 0;
#pragma unroll
for (int i = 0; i < num_indices; i++) {
int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
assert(index >= -sizes[i] && index < sizes[i] && "index out of bounds");
if (index < 0) {
index += sizes[i];
}
offset += index * strides[i];
}

f(out_data, in_data, offset);
});
}

// The kernels are templated on an opaque, self-aligned type of the correct
// size to avoid redundant kernels for different types of the same size.
template <int N> struct alignas(N) OpaqueType { char data[N]; };

template <typename scalar_t>
void index_kernel_impl(TensorIterator& iter, IntList index_size, IntList index_stride) {
gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* out_data, char* in_data, int64_t offset) {
*(scalar_t*)out_data = *(scalar_t*)(in_data + offset);
});
}

template <typename scalar_t>
void index_put_kernel_impl(TensorIterator& iter, IntList index_size, IntList index_stride) {
gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* out_data, char* in_data, int64_t offset) {
*(scalar_t*)(out_data + offset) = *(scalar_t*)in_data;
});
}

static void index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride) {
AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_kernel_impl<dtype>(iter, index_size, index_stride);
});
}

static void index_put_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, bool accumulate) {
AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index_put", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_put_kernel_impl<dtype>(iter, index_size, index_stride);
});
}

REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);

}} // namespace at::native

@@ -888,10 +888,10 @@
- func: index_copy_(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
  variants: method

- func: index_put(Tensor self, TensorList indices, Tensor values) -> Tensor
- func: index_put(Tensor self, TensorList indices, Tensor values, bool accumulate=false) -> Tensor
  variants: function, method

- func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor
- func: index_put_(Tensor self, TensorList indices, Tensor values, bool accumulate=false) -> Tensor
  variants: function, method

- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) -> Tensor

@@ -7,6 +7,7 @@ graph(%target : Double(100)
%6 : bool = prim::Constant[value=0]()
%indices : Long(4) = aten::_cast_Long(%indices.1, %6)
%8 : Dynamic[] = prim::ListConstruct(%indices)
%9 : Double(100) = aten::index_put_(%target, %8, %5)
return (%9);
%9 : bool = prim::Constant[value=0]()
%10 : Double(100) = aten::index_put_(%target, %8, %5, %9)
return (%10);
}

@@ -4,6 +4,7 @@ graph(%target : Double(100)
%3 : bool = prim::Constant[value=0]()
%indices : Long(4) = aten::_cast_Long(%indices.1, %3)
%5 : Dynamic[] = prim::ListConstruct(%indices)
%6 : Double(100) = aten::index_put_(%target, %5, %rhs)
return (%6);
%6 : bool = prim::Constant[value=0]()
%7 : Double(100) = aten::index_put_(%target, %5, %rhs, %6)
return (%7);
}

@@ -28,6 +28,7 @@ TESTS = [
'distributions',
'expecttest',
'indexing',
'indexing_cuda',
'jit',
'multiprocessing',
'multiprocessing_spawn',

@@ -926,7 +926,7 @@ class TestCuda(TestCase):

self.assertEqual(x * y, 4.5)
self.assertEqual(y * x, 4.5)
with self.assertRaisesRegex(RuntimeError, 'expected type'):
with self.assertRaisesRegex(RuntimeError, "doesn't match the desired type"):
y *= x
x *= y
self.assertEqual(x, 4.5)

@@ -2,6 +2,7 @@ from common_utils import TestCase, run_tests
import torch
import warnings
from torch import tensor
import unittest


class TestIndexing(TestCase):

@@ -448,9 +449,9 @@ class NumpyTests(TestCase):
def f(a, v):
a[a > -1] = tensor(v)

self.assertRaisesRegex(Exception, "expand", f, a, [])
self.assertRaisesRegex(Exception, 'expand', f, a, [1, 2, 3])
self.assertRaisesRegex(Exception, 'expand', f, a[:1], [1, 2, 3])
self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [])
self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [1, 2, 3])
self.assertRaisesRegex(Exception, 'shape mismatch', f, a[:1], [1, 2, 3])

def test_boolean_indexing_twodim(self):
# Indexing a 2-dimensional array with

@@ -503,12 +504,14 @@ class NumpyTests(TestCase):

def test_broaderrors_indexing(self):
a = torch.zeros(5, 5)
self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2]))
self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)

def test_trivial_fancy_out_of_bounds(self):
a = torch.zeros(5)
ind = torch.ones(20, dtype=torch.int64)
if a.is_cuda:
raise unittest.SkipTest('CUDA asserts instead of raising an exception')
ind[-1] = 10
self.assertRaises(RuntimeError, a.__getitem__, ind)
self.assertRaises(RuntimeError, a.__setitem__, ind, 0)

test/test_indexing_cuda.py (new file, 10 lines)
@@ -0,0 +1,10 @@
import torch
from test_indexing import *


if __name__ == '__main__':
    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        run_tests()
    else:
        print("Skipping test_indexing_cuda.py")

@@ -360,6 +360,10 @@
- name: histc(Tensor self, int64_t bins, Scalar min, Scalar max)
  self: not_implemented("histc")

- name: index(Tensor self, TensorList indices)
  self: zeros_like(self).index_put_(indices, grad, true)
  indices: TensorList()

- name: index_add_(Tensor self, int64_t dim, Tensor index, Tensor source)
  self: grad
  source: grad.index_select(dim, index)

@@ -375,6 +379,10 @@
  self: grad.clone().index_fill_(dim, index, 0)
  value: grad.index_select(dim, index).sum()

- name: index_put_(Tensor self, TensorList indices, Tensor values, bool accumulate)
  self: grad.clone().index_put_(indices, zeros_like(values), accumulate)
  values: grad.index(indices)

- name: index_select(Tensor self, int64_t dim, Tensor index)
  self: at::zeros(self.sizes(), grad.options()).index_add_(dim, index, grad)
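Not part of the diff — a quick sanity check of the index backward formula above: the gradient is scattered back with accumulate=true, so repeated indices sum their gradients.

```python
import torch

x = torch.arange(4.0, requires_grad=True)
idx = torch.tensor([1, 1, 3])

y = x[idx]          # forward: advanced indexing (at::native::index)
y.sum().backward()  # backward: zeros_like(x).index_put_((idx,), grad, accumulate=True)

print(x.grad)  # tensor([0., 2., 0., 1.]) -- the repeated index 1 accumulates
```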

@@ -222,8 +222,7 @@ std::vector<at::Tensor> VariableType::unpack(at::TensorList tl, const char *name
for (size_t i = 0; i < tl.size(); ++i) {
const auto &t = tl[i];
if (!t.defined()) {
AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor at position #", i, " "
"for iterable argument #", pos, " '", name, "'");
continue;
}
if (!isVariableType(t.type())) {
AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " at position #", i, " "

@@ -135,6 +135,12 @@ inline void check_no_requires_grad(const Tensor& tensor, const char* name) {
}
}

inline void check_no_requires_grad(TensorList tensors, const char* name) {
for (auto& tensor : tensors) {
check_no_requires_grad(tensor, name);
}
}

// Assumed that saved tensor lists are never inplace outputs
inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {

@@ -811,9 +811,9 @@ def index_select(g, self, dim, index):
return g.op("Gather", self, index, axis_i=dim)


def index_put(g, self, indices_list_value, values):
def index_put(g, self, indices_list_value, values, accumulate):
indices_list = _unpack_list(indices_list_value)
args = [self] + indices_list + [values]
args = [self] + indices_list + [values, accumulate]
return g.op("ATen", *args, operator_s='index_put')