mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
C++ tensor indexing: add Slice / TensorIndex (#30424)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30424 `at::indexing::TensorIndex` is used for converting C++ tensor indices such as `{None, "...", Ellipsis, 0, true, {1, None, 2}, torch::tensor({1, 2})}` into its equivalent `std::vector<TensorIndex>`, so that further tensor indexing operations can be performed using the supplied indices. Test Plan: Imported from OSS Differential Revision: D18695902 Pulled By: yf225 fbshipit-source-id: d73e14a411cdbec815866b02e75ffd71a9186e89
This commit is contained in:
committed by
Facebook Github Bot
parent
638e4ad8b9
commit
b6cee03e29
@ -48,7 +48,7 @@
|
||||
//
|
||||
// where & and * represent the C-style address-of and indirection operations.
|
||||
|
||||
#include <ATen/native/Indexing.h>
|
||||
#include <ATen/native/TensorAdvancedIndexing.h>
|
||||
#include <ATen/native/IndexingUtils.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
158
aten/src/ATen/native/TensorIndexing.cpp
Normal file
158
aten/src/ATen/native/TensorIndexing.cpp
Normal file
@ -0,0 +1,158 @@
|
||||
#include <ATen/native/TensorIndexing.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace at {
|
||||
namespace indexing {
|
||||
|
||||
const EllipsisIndexType Ellipsis = EllipsisIndexType();
|
||||
|
||||
Slice::Slice() {}
|
||||
Slice::Slice(int64_t start, int64_t stop, int64_t step) : start_(start), stop_(stop), step_(step) {}
|
||||
|
||||
int64_t Slice::start() const {
|
||||
return start_;
|
||||
}
|
||||
|
||||
int64_t Slice::stop() const {
|
||||
return stop_;
|
||||
}
|
||||
|
||||
int64_t Slice::step() const {
|
||||
return step_;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
|
||||
stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
|
||||
return stream;
|
||||
}
|
||||
|
||||
// This mirrors `__PySlice_Unpack` in torch/csrc/utils/python_compat.h
|
||||
Slice unpackSlice(
|
||||
c10::optional<int64_t> start_index = at::indexing::None,
|
||||
c10::optional<int64_t> stop_index = at::indexing::None,
|
||||
c10::optional<int64_t> step_index = at::indexing::None) {
|
||||
int64_t start, stop, step;
|
||||
if (!step_index.has_value()) {
|
||||
step = 1;
|
||||
} else {
|
||||
step = step_index.value();
|
||||
if (step == 0) {
|
||||
TORCH_CHECK(false, "slice step cannot be zero");
|
||||
}
|
||||
// Here step might be -INDEX_MAX-1; in this case we replace it
|
||||
// with -INDEX_MAX. This doesn't affect the semantics, and it
|
||||
// guards against later undefined behaviour resulting from code that
|
||||
// does "step = -step" as part of a slice reversal.
|
||||
if (step < -INDEX_MAX)
|
||||
step = -INDEX_MAX;
|
||||
}
|
||||
if (!start_index.has_value()) {
|
||||
start = step < 0 ? INDEX_MAX : 0;
|
||||
} else {
|
||||
start = start_index.value();
|
||||
}
|
||||
if (!stop_index.has_value()) {
|
||||
stop = step < 0 ? INDEX_MIN : INDEX_MAX;
|
||||
} else {
|
||||
stop = stop_index.value();
|
||||
}
|
||||
return Slice(start, stop, step);
|
||||
}
|
||||
|
||||
TensorIndex::TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}
|
||||
TensorIndex::TensorIndex(at::indexing::EllipsisIndexType) : type_(TensorIndexType::Ellipsis) {}
|
||||
TensorIndex::TensorIndex(const char *str) : TensorIndex(at::indexing::Ellipsis) {
|
||||
TORCH_CHECK(
|
||||
strcmp(str, "...") == 0,
|
||||
"Expected \"...\" to represent an ellipsis index, but got \"", str, "\"");
|
||||
}
|
||||
TensorIndex::TensorIndex(int64_t integer) : integer_(integer), type_(TensorIndexType::Integer) {}
|
||||
TensorIndex::TensorIndex(int integer) : TensorIndex((int64_t)integer) {}
|
||||
TensorIndex::TensorIndex(std::initializer_list<c10::optional<int64_t>> init_list)
|
||||
: type_(TensorIndexType::Slice) {
|
||||
if (init_list.size() == 0) {
|
||||
slice_ = unpackSlice();
|
||||
} else if (init_list.size() == 2) {
|
||||
slice_ = unpackSlice(*init_list.begin(), *(init_list.begin() + 1));
|
||||
} else if (init_list.size() == 3) {
|
||||
slice_ = unpackSlice(*init_list.begin(), *(init_list.begin() + 1), *(init_list.begin() + 2));
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Expected 0 / 2 / 3 elements in the braced-init-list to represent a slice index, but got ",
|
||||
init_list.size(),
|
||||
" element(s)");
|
||||
}
|
||||
}
|
||||
TensorIndex::TensorIndex(Tensor tensor) : tensor_(tensor), type_(TensorIndexType::Tensor) {}
|
||||
|
||||
bool TensorIndex::is_none() const {
|
||||
return type_ == TensorIndexType::None;
|
||||
}
|
||||
|
||||
bool TensorIndex::is_ellipsis() const {
|
||||
return type_ == TensorIndexType::Ellipsis;
|
||||
}
|
||||
|
||||
bool TensorIndex::is_integer() const {
|
||||
return type_ == TensorIndexType::Integer;
|
||||
}
|
||||
|
||||
int64_t TensorIndex::integer() const {
|
||||
return integer_;
|
||||
}
|
||||
|
||||
bool TensorIndex::is_boolean() const {
|
||||
return type_ == TensorIndexType::Boolean;
|
||||
}
|
||||
|
||||
bool TensorIndex::boolean() const {
|
||||
return boolean_;
|
||||
}
|
||||
|
||||
bool TensorIndex::is_slice() const {
|
||||
return type_ == TensorIndexType::Slice;
|
||||
}
|
||||
|
||||
const Slice& TensorIndex::slice() const {
|
||||
return slice_;
|
||||
}
|
||||
|
||||
bool TensorIndex::is_tensor() const {
|
||||
return type_ == TensorIndexType::Tensor;
|
||||
}
|
||||
|
||||
const Tensor& TensorIndex::tensor() const {
|
||||
return tensor_;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index) {
|
||||
if (tensor_index.is_none()) {
|
||||
stream << "None";
|
||||
} else if (tensor_index.is_ellipsis()) {
|
||||
stream << "...";
|
||||
} else if (tensor_index.is_integer()) {
|
||||
stream << tensor_index.integer();
|
||||
} else if (tensor_index.is_boolean()) {
|
||||
stream << std::boolalpha << tensor_index.boolean();
|
||||
} else if (tensor_index.is_slice()) {
|
||||
stream << tensor_index.slice();
|
||||
} else if (tensor_index.is_tensor()) {
|
||||
stream << tensor_index.tensor();
|
||||
}
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
|
||||
stream << "(";
|
||||
for (size_t i = 0; i < tensor_indices.size(); i++) {
|
||||
stream << tensor_indices[i];
|
||||
if (i < tensor_indices.size() - 1) stream << ", ";
|
||||
}
|
||||
stream << ")";
|
||||
return stream;
|
||||
}
|
||||
|
||||
} // namespace indexing
|
||||
} // namespace at
|
||||
112
aten/src/ATen/native/TensorIndexing.h
Normal file
112
aten/src/ATen/native/TensorIndexing.h
Normal file
@ -0,0 +1,112 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/Optional.h>
|
||||
#include <ATen/core/TensorBody.h>
|
||||
|
||||
namespace at {
|
||||
namespace indexing {
|
||||
|
||||
const int64_t INDEX_MAX = std::numeric_limits<int64_t>::max();
|
||||
const int64_t INDEX_MIN = std::numeric_limits<int64_t>::min();
|
||||
|
||||
enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };
|
||||
|
||||
constexpr c10::nullopt_t None{c10::nullopt_t::init()};
|
||||
|
||||
struct CAFFE2_API EllipsisIndexType final { EllipsisIndexType() {} };
|
||||
CAFFE2_API extern const EllipsisIndexType Ellipsis;
|
||||
|
||||
struct CAFFE2_API Slice final {
|
||||
public:
|
||||
Slice();
|
||||
Slice(int64_t start, int64_t stop, int64_t step);
|
||||
|
||||
int64_t start() const;
|
||||
int64_t stop() const;
|
||||
int64_t step() const;
|
||||
|
||||
private:
|
||||
int64_t start_;
|
||||
int64_t stop_;
|
||||
int64_t step_;
|
||||
};
|
||||
|
||||
CAFFE2_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
|
||||
|
||||
// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
|
||||
// `{None, "...", Ellipsis, 0, true, {1, None, 2}, torch::tensor({1, 2})}`
|
||||
// into its equivalent `std::vector<TensorIndex>`, so that further tensor indexing
|
||||
// operations can be performed using the supplied indices.
|
||||
//
|
||||
// There is one-to-one correspondence between Python and C++ tensor index types:
|
||||
// Python | C++
|
||||
// -----------------------------------------------------
|
||||
// `None` | `at::indexing::None`
|
||||
// `Ellipsis` | `at::indexing::Ellipsis`
|
||||
// `...` | `"..."`
|
||||
// `123` | `123`
|
||||
// `True` / `False` | `true` / `false`
|
||||
// `:` | `{}` / `{None, None}`
|
||||
// `::` | `{}` / `{None, None, None}`
|
||||
// `1:` | `{1, None}`
|
||||
// `1::` | `{1, None, None}`
|
||||
// `:3` | `{None, 3}`
|
||||
// `:3:` | `{None, 3, None}`
|
||||
// `::2` | `{None, None, 2}`
|
||||
// `1:3` | `{1, 3}`
|
||||
// `1::2` | `{1, None, 2}`
|
||||
// `:3:2` | `{None, 3, 2}`
|
||||
// `1:3:2` | `{1, 3, 2}`
|
||||
// `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
|
||||
struct CAFFE2_API TensorIndex final {
|
||||
// Case 1: `at::indexing::None`
|
||||
TensorIndex(c10::nullopt_t);
|
||||
|
||||
// Case 2: "..." / `at::indexing::Ellipsis`
|
||||
TensorIndex(at::indexing::EllipsisIndexType);
|
||||
TensorIndex(const char *str);
|
||||
|
||||
// Case 3: Integer value
|
||||
TensorIndex(int64_t integer);
|
||||
TensorIndex(int integer);
|
||||
|
||||
// Case 4: Boolean value
|
||||
template <class T,
|
||||
class = typename std::enable_if<std::is_same<bool, T>::value>::type >
|
||||
TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
|
||||
|
||||
// Case 5: Slice represented in `{start, stop, step}` form,
|
||||
// where `start` / `stop` / `step` can be integer or `at::indexing::None`
|
||||
TensorIndex(std::initializer_list<c10::optional<int64_t>> init_list);
|
||||
|
||||
// Case 5: Tensor value
|
||||
TensorIndex(Tensor tensor);
|
||||
|
||||
bool is_none() const;
|
||||
bool is_ellipsis() const;
|
||||
|
||||
bool is_integer() const;
|
||||
int64_t integer() const;
|
||||
|
||||
bool is_boolean() const;
|
||||
bool boolean() const;
|
||||
|
||||
bool is_slice() const;
|
||||
const Slice& slice() const;
|
||||
|
||||
bool is_tensor() const;
|
||||
const Tensor& tensor() const;
|
||||
|
||||
private:
|
||||
int64_t integer_;
|
||||
bool boolean_;
|
||||
Slice slice_;
|
||||
Tensor tensor_;
|
||||
TensorIndexType type_;
|
||||
};
|
||||
|
||||
CAFFE2_API std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index);
|
||||
CAFFE2_API std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices);
|
||||
|
||||
} // namespace indexing
|
||||
} // namespace at
|
||||
@ -1,4 +1,4 @@
|
||||
#include <ATen/native/Indexing.h>
|
||||
#include <ATen/native/TensorAdvancedIndexing.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#include <ATen/native/Indexing.h>
|
||||
#include <ATen/native/TensorAdvancedIndexing.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#include <ATen/native/Indexing.h>
|
||||
#include <ATen/native/TensorAdvancedIndexing.h>
|
||||
#include <ATen/native/IndexingUtils.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
@ -24,6 +24,7 @@ set(TORCH_API_TEST_SOURCES
|
||||
${TORCH_API_TEST_DIR}/static.cpp
|
||||
${TORCH_API_TEST_DIR}/support.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_indexing.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_options_cuda.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_options.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor.cpp
|
||||
|
||||
107
test/cpp/api/tensor_indexing.cpp
Normal file
107
test/cpp/api/tensor_indexing.cpp
Normal file
@ -0,0 +1,107 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
// TODO: Move the include into `ATen/ATen.h`, once C++ tensor indexing
|
||||
// is ready to ship.
|
||||
#include <ATen/native/TensorIndexing.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
using namespace torch::indexing;
|
||||
using namespace torch::test;
|
||||
|
||||
TEST(TensorIndexingTest, Slice) {
|
||||
Slice slice(1, 2, 3);
|
||||
ASSERT_EQ(slice.start(), 1);
|
||||
ASSERT_EQ(slice.stop(), 2);
|
||||
ASSERT_EQ(slice.step(), 3);
|
||||
|
||||
ASSERT_EQ(c10::str(slice), "1:2:3");
|
||||
}
|
||||
|
||||
TEST(TensorIndexingTest, TensorIndex) {
|
||||
{
|
||||
std::vector<TensorIndex> indices = {None, "...", Ellipsis, 0, true, {1, None, 2}, torch::tensor({1, 2})};
|
||||
ASSERT_TRUE(indices[0].is_none());
|
||||
ASSERT_TRUE(indices[1].is_ellipsis());
|
||||
ASSERT_TRUE(indices[2].is_ellipsis());
|
||||
ASSERT_TRUE(indices[3].is_integer());
|
||||
ASSERT_TRUE(indices[3].integer() == 0);
|
||||
ASSERT_TRUE(indices[4].is_boolean());
|
||||
ASSERT_TRUE(indices[4].boolean() == true);
|
||||
ASSERT_TRUE(indices[5].is_slice());
|
||||
ASSERT_TRUE(indices[5].slice().start() == 1);
|
||||
ASSERT_TRUE(indices[5].slice().stop() == INDEX_MAX);
|
||||
ASSERT_TRUE(indices[5].slice().step() == 2);
|
||||
ASSERT_TRUE(indices[6].is_tensor());
|
||||
ASSERT_TRUE(torch::equal(indices[6].tensor(), torch::tensor({1, 2})));
|
||||
}
|
||||
|
||||
ASSERT_THROWS_WITH(
|
||||
TensorIndex(".."),
|
||||
"Expected \"...\" to represent an ellipsis index, but got \"..\"");
|
||||
|
||||
// NOTE: Some compilers such as Clang 5 and MSVC always treat `TensorIndex({1})` the same as
|
||||
// `TensorIndex(1)`. This is in violation of the C++ standard
|
||||
// (`https://en.cppreference.com/w/cpp/language/list_initialization`), which says:
|
||||
// ```
|
||||
// copy-list-initialization:
|
||||
//
|
||||
// U( { arg1, arg2, ... } )
|
||||
//
|
||||
// functional cast expression or other constructor invocations, where braced-init-list is used
|
||||
// in place of a constructor argument. Copy-list-initialization initializes the constructor's parameter
|
||||
// (note; the type U in this example is not the type that's being list-initialized; U's constructor's parameter is)
|
||||
// ```
|
||||
// When we call `TensorIndex({1})`, `TensorIndex`'s constructor's parameter is being list-initialized with {1}.
|
||||
// And since we have the `TensorIndex(std::initializer_list<c10::optional<int64_t>>)` constructor, the following
|
||||
// rule in the standard applies:
|
||||
// ```
|
||||
// The effects of list initialization of an object of type T are:
|
||||
//
|
||||
// if T is a specialization of std::initializer_list, the T object is direct-initialized or copy-initialized,
|
||||
// depending on context, from a prvalue of the same type initialized from the braced-init-list.
|
||||
// ```
|
||||
// Therefore, if the compiler strictly follows the standard, it should treat `TensorIndex({1})` as
|
||||
// `TensorIndex(std::initializer_list<c10::optional<int64_t>>({1}))`. However, this is not the case for
|
||||
// compilers such as Clang 5 and MSVC, and hence we skip this test for those compilers.
|
||||
#if (!defined(__clang__) || (defined(__clang__) && __clang_major__ != 5)) && !defined(_MSC_VER)
|
||||
ASSERT_THROWS_WITH(
|
||||
TensorIndex({1}),
|
||||
"Expected 0 / 2 / 3 elements in the braced-init-list to represent a slice index, but got 1 element(s)");
|
||||
#endif
|
||||
|
||||
ASSERT_THROWS_WITH(
|
||||
TensorIndex({1, 2, 3, 4}),
|
||||
"Expected 0 / 2 / 3 elements in the braced-init-list to represent a slice index, but got 4 element(s)");
|
||||
|
||||
{
|
||||
std::vector<TensorIndex> indices = {None, "...", Ellipsis, 0, true, {1, None, 2}};
|
||||
ASSERT_EQ(c10::str(indices), c10::str("(None, ..., ..., 0, true, 1:", INDEX_MAX, ":2)"));
|
||||
ASSERT_EQ(c10::str(indices[0]), "None");
|
||||
ASSERT_EQ(c10::str(indices[1]), "...");
|
||||
ASSERT_EQ(c10::str(indices[2]), "...");
|
||||
ASSERT_EQ(c10::str(indices[3]), "0");
|
||||
ASSERT_EQ(c10::str(indices[4]), "true");
|
||||
ASSERT_EQ(c10::str(indices[5]), c10::str("1:", INDEX_MAX, ":2"));
|
||||
}
|
||||
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{}})), c10::str("(0:", INDEX_MAX, ":1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, None}})), c10::str("(0:", INDEX_MAX, ":1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, None, None}})), c10::str("(0:", INDEX_MAX, ":1)"));
|
||||
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{1, None}})), c10::str("(1:", INDEX_MAX, ":1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{1, None, None}})), c10::str("(1:", INDEX_MAX, ":1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, 3}})), c10::str("(0:3:1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, 3, None}})), c10::str("(0:3:1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, None, 2}})), c10::str("(0:", INDEX_MAX, ":2)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, None, -1}})), c10::str("(", INDEX_MAX, ":", INDEX_MIN, ":-1)"));
|
||||
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{1, 3}})), c10::str("(1:3:1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{1, None, 2}})), c10::str("(1:", INDEX_MAX, ":2)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{1, None, -1}})), c10::str("(1:", INDEX_MIN, ":-1)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, 3, 2}})), c10::str("(0:3:2)"));
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{None, 3, -1}})), c10::str("(", INDEX_MAX, ":3:-1)"));
|
||||
|
||||
ASSERT_EQ(c10::str(std::vector<TensorIndex>({{1, 3, 2}})), c10::str("(1:3:2)"));
|
||||
}
|
||||
Reference in New Issue
Block a user